├── .gitignore ├── LICENSE ├── README.md ├── conda_env.yml ├── docker-gym-mujocov2 ├── Dockerfile └── README.md ├── docker-gym-mujocov4 ├── Dockerfile └── README.md ├── experiments ├── grid_sample_hpc_script.sh ├── grid_train_redq_sac.py ├── grid_utils.py ├── sample_hpc_script.sh └── train_redq_sac.py ├── mujoco_download.sh ├── plot_utils ├── plot_REDQ.py └── redq_plot_helper.py ├── redq ├── __init__.py ├── algos │ ├── __init__.py │ ├── core.py │ └── redq_sac.py ├── user_config.py └── utils │ ├── __init__.py │ ├── bias_utils.py │ ├── logx.py │ ├── run_utils.py │ └── serialization_utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore data and extra python libraries 2 | data/ 3 | extra-python-libs/ 4 | 5 | # ignore experiment files that are not useful to track 6 | hpc_experiments/ 7 | 8 | .idea/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 watchernyu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REDQ source code 2 | Author's PyTorch implementation of Randomized Ensembled Double Q-Learning (REDQ) algorithm. Paper link: https://arxiv.org/abs/2101.05982 3 | 4 | 5 | 6 | ## Table of Contents 7 | - [Table of contents](#table-of-contents) 8 | - [Code structure explained](#code-structure) 9 | - [Implementation video tutorial](#video-tutorial) 10 | - [Data and reproducing figures in REDQ](#reproduce-figures) 11 | - [Train an REDQ agent](#train-redq) 12 | - [Implement REDQ](#implement-redq) 13 | - [Reproduce the results](#reproduce-results) 14 | - [Newest Docker + Singularity setup for Gym + MuJoCo v2 and v4](#setup-dockersing) 15 | - [(outdated) Environment setup MuJoCo 2.1, v4 tasks, NYU HPC 18.04](#setup-nyuhpc-new) 16 | - [(outdated) Environment setup MuJoCo 2.1, Ubuntu 18.04](#setup-ubuntu) 17 | - [(outdated) Environment setup MuJoCo 2.1, NYU Shanghai HPC](#setup-nyuhpc) 18 | - [(outdated) Environment setup](#setup-old) 19 | - [Acknowledgement](#acknowledgement) 20 | - [Previous updates](#prevupdates) 21 | 22 | Feb 20, 2023: An updated [docker + singularity setup](#setup-dockersing) is now available. This is probably the easiest set up ever and allows you to use docker to start running your DRL experiments with just 3 commands. 
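For reference, those commands (repeated from the Docker section later on this page) are roughly:
```
docker pull cwatcherw/gym-mujocov2:1.0
docker run -it --rm cwatcherw/gym-mujocov2:1.0
# then, inside the container:
cd /workspace/REDQ/experiments/ && python train_redq_sac.py
```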
We have also released new dockerfiles for gym + mujoco v2 and v4 environments (in the newest version of the repo, you will see 2 folders (`docker-gym-mujocov2`, `docker-gym-mujocov4`) each containing a dockerfile). 23 | 24 | 25 | 26 | ## Code structure explained 27 | The code structure is pretty simple and should be easy to follow. 28 | 29 | In `experiments/train_redq_sac.py` you will find the main training loop. Here we set up the environment, initialize an instance of the `REDQSACAgent` class, specifying all the hyperparameters and train the agent. You can run this file to train a REDQ agent. 30 | 31 | In `redq/algos/redq_sac.py` we provide code for the `REDQSACAgent` class. If you are trying to take a look at how the core components of REDQ are implemented, the most important function is the `train()` function. 32 | 33 | In `redq/algos/core.py` we provide code for some basic classes (Q network, policy network, replay buffer) and some helper functions. These classes and functions are used by the REDQ agent class. 34 | 35 | In `redq/utils` there are some utility classes (such as a logger) and helper functions that largely have nothing to do with REDQ's core components. In `redq/utils/bias_utils.py` you can find utility functions to get bias estimation (bias estimate is computed roughly as: Monte Carlo return - current Q estimate). In `experiments/train_redq_sac.py` you can decide whether you want bias evaluation when running the experiment by setting the `evaluate_bias` flag (this will lead to some minor computation overhead). 36 | 37 | In `plot_utils` there are some utility functions to reproduce the figures we presented in the paper. (See the section on "Data and reproducing figures in REDQ") 38 | 39 | 40 | 41 | 42 | ## Implementation video tutorial 43 | Here is the link to a video tutorial we created that explains the REDQ implementation in detail: 44 | 45 | [REDQ code explained video tutorial (Google Drive Link)](https://drive.google.com/file/d/1QabNOz8VDyKhI0K7tPbIkmx4LKw6N6Te/view?usp=sharing) 46 | 47 | 48 | 49 | ## Data and reproducing figures in REDQ 50 | The data used to produce the figures in the REDQ paper can be downloaded here: 51 | [REDQ DATA download link](https://drive.google.com/file/d/11mjDYCzp3T1MaICGruySrWd-akr-SVJp/view?usp=sharing) (Google Drive Link, ~80 MB) 52 | 53 | To reproduce the figures, first download the data, and then extract the zip file to `REDQ/data`. So now a folder called `REDQ_ICLR21` should be at this path: `REDQ/data/REDQ_ICLR21`. 54 | 55 | Then you can go into the `plot_utils` folder, and run the `plot_REDQ.py` program there. You will need `seaborn==0.8.1` to run it correctly. We might update the code later so that it works for newer versions but currently seaborn newer than 0.8.1 is not supported. If you don't want to mess up existing conda or python virtual environments, you can create a new environment and simply install seaborn 0.8.1 there and use it to run the program. 56 | 57 | If you encounter any problem or cannot access the data (can't use google or can't download), please open an issue to let us know! Thanks! 58 | 59 | 60 | 61 | ## Environment setup (old guide, for the newest guide, see end of this page) 62 | 63 | **VERY IMPORTANT**: because MuJoCo is now free, the setup guide here is slightly outdated (this is the setup we used when we run our experiments for the REDQ paper), we now provide a newer updated setup guide that uses the newest MuJoCo, please see the end of the this page. 
64 | 65 | Note: you don't need to exactly follow the tutorial here if you know well about how to install python packages. 66 | 67 | First create a conda environment and activate it: 68 | ``` 69 | conda create -n redq python=3.6 70 | conda activate redq 71 | ``` 72 | 73 | Install PyTorch (or you can follow the tutorial on PyTorch official website). 74 | On Ubuntu (might also work on Windows but is not fully tested): 75 | ``` 76 | conda install pytorch==1.3.1 torchvision==0.4.2 cudatoolkit=10.1 -c pytorch 77 | ``` 78 | On OSX: 79 | ``` 80 | conda install pytorch==1.3.1 torchvision==0.4.2 -c pytorch 81 | ``` 82 | 83 | Install gym (0.17.2): 84 | ``` 85 | git clone https://github.com/openai/gym.git 86 | cd gym 87 | git checkout b2727d6 88 | pip install -e . 89 | cd .. 90 | ``` 91 | 92 | Install mujoco_py (2.0.2.1): 93 | ``` 94 | git clone https://github.com/openai/mujoco-py 95 | cd mujoco-py 96 | git checkout 379bb19 97 | pip install -e . --no-cache 98 | cd .. 99 | ``` 100 | 101 | For gym and mujoco_py, depending on your system, you might need to install some other packages, if you run into such problems, please refer to their official sites for guidance. 102 | If you want to test on Mujoco environments, you will also need to get Mujoco files and license from Mujoco website. Please refer to the Mujoco website for how to do this correctly. 103 | 104 | Clone and install this repository (Although even if you don't install it you might still be able to use the code): 105 | ``` 106 | git clone https://github.com/watchernyu/REDQ.git 107 | cd REDQ 108 | pip install -e . 109 | ``` 110 | 111 | 112 | 113 | ## Train an REDQ agent 114 | To train an REDQ agent, run: 115 | ``` 116 | python experiments/train_redq_sac.py 117 | ``` 118 | On a 2080Ti GPU, running Hopper to 125K will approximately take 10-12 hours. Running Humanoid to 300K will approximately take 26 hours. 119 | 120 | 121 | 122 | ## Implement REDQ 123 | As discussed in the paper, we obtain REDQ by making minimal changes to a Soft Actor-Critic (SAC) baseline. You can easily modify your SAC code to get REDQ: (a) use an update-to-data (UTD) ratio > 1, (b) have > 2 Q networks, (c) when computing the Q target, randomly select a subset of Q target networks, take their min. 124 | 125 | If you intend to implement REDQ on your codebase, please refer to the paper and the [video tutorial](#video-tutorial) for guidance. In particular, in Appendix B of the paper, we discussed hyperparameters and some additional implementation details. One important detail is in the beginning of the training, for the first 5000 data points, we sample random action from the action space and do not perform any updates. If you perform a large number of updates with a very small amount of data, it can lead to severe bias accumulation and can negatively affect the performance. 126 | 127 | For REDQ-OFE, as mentioned in the paper, for some reason adding PyTorch batch norm to OFENet will lead to divergence. So in the end we did not use batch norm in our code. 128 | 129 | 130 | 131 | ## Reproduce the results 132 | If you use a different PyTorch version, it might still work, however, it might be better if your version is close to the ones we used. We have found that for example, on Ant environment, PyTorch 1.3 and 1.2 give quite different results. The reason is not entirely clear. 133 | 134 | Other factors such as versions of other packages (for example numpy) or environment (mujoco/gym) or even types of hardware (cpu/gpu) can also affect the final results. 
Thus reproducing exactly the same results can be difficult. However, if the package versions are the same, when averaged over a large number of random seeds, the overall performance should be similar to those reported in the paper. 135 | 136 | As of Mar. 29, 2021, we have used the installation guide on this page to re-setup a conda environment and run the code hosted on this repo and the reproduced results are similar to what we have in the paper (though not exactly the same, in some environments, performance are a bit stronger and others a bit weaker). 137 | 138 | Please open an issue if you find any problems in the code, thanks! 139 | 140 | 141 | 142 | ## Environment setup with MuJoCo and OpenAI Gym v2/V4 tasks, with Docker or Singularity 143 | This is a new 2023 Guide that is based on Docker and Singularity. (currently under more testing) 144 | 145 | Local setup: simply build a docker container with the dockerfile (either the v2 or the v4 version, depending on your need) provided in this repo (it basically specifies what you need to do to install all dependencies starting with a ubuntu18 system. You can also easily modify it to your needs). 146 | 147 | To get things to run very quick (with v2 gym-mujoco environments), simply pull this from my dockerhub: `docker pull cwatcherw/gym-mujocov2:1.0` 148 | 149 | After you pull the docker container, you can quickly test it: 150 | ``` 151 | docker run -it --rm cwatcherw/gym-mujocov2:1.0 152 | ``` 153 | Once you are inside the container, run: 154 | 155 | ``` 156 | cd /workspace/REDQ/experiments/ 157 | python train_redq_sac.py 158 | ``` 159 | 160 | (Alternatively, remove `--rm` flag so the container is kept after shutting down, or add `--gpus all` to use GPU. ) 161 | 162 | If you want to modify the REDQ codebase to test new ideas, you can clone (a fork of) REDQ repo to a local directory, and then mount it to `/workspace/REDQ`. For example: 163 | 164 | ``` 165 | docker run -it --rm --mount type=bind,source=$(pwd)/REDQ,target=/workspace/REDQ cwatcherw/gym-mujocov2:1.0 166 | ``` 167 | 168 | Example setup if you want to run on a Slurm HPC with singularity (you might need to make changes, depending on your HPC settings): 169 | 170 | First time setup: 171 | ``` 172 | mkdir /scratch/$USER/.sing_cache 173 | export SINGULARITY_CACHEDIR=/scratch/$USER/.sing_cache 174 | echo "export SINGULARITY_CACHEDIR=/scratch/$USER/.sing_cache" >> ~/.bashrc 175 | mkdir /scratch/$USER/sing 176 | cd /scratch/$USER/sing 177 | git clone https://github.com/watchernyu/REDQ.git 178 | ``` 179 | 180 | Build singularity and run singularity: 181 | ``` 182 | module load singularity 183 | cd /scratch/$USER/sing/ 184 | singularity build --sandbox mujoco-sandbox docker://cwatcherw/gym-mujocov2:1.0 185 | singularity exec -B /scratch/$USER/sing/REDQ:/workspace/REDQ -B /scratch/$USER/sing/mujoco-sandbox/opt/conda/lib/python3.8/site-packages/mujoco_py/:/opt/conda/lib/python3.8/site-packages/mujoco_py/ /scratch/$USER/sing/mujoco-sandbox bash 186 | 187 | singularity exec -B REDQ/:/workspace/REDQ/ mujoco-sandbox bash 188 | ``` 189 | 190 | Don't forget to use following command to tell your mujoco to use egl headless rendering if you need to do rendering or use visual input. (this is required when you run on the hpc or other headless machines.) 
191 | ``` 192 | export MUJOCO_GL=egl 193 | ``` 194 | 195 | Sample command to open an interactive session for debugging: 196 | ``` 197 | srun -p parallel --pty --mem 12000 -t 0-05:00 bash 198 | ``` 199 | 200 | 201 | 202 | 203 | ## Environment setup with newest MuJoCo and OpenAI Gym V4 tasks, on the NYU Shanghai hpc cluster (system is CentOS Linux release 7.4.1708, hpc management is Slurm) 204 | This one is the newest guide (2022 Summer) that helps you set up for Gym V4 MuJoCo tasks. And this newer version of Gym MuJoCo tasks is much easier to set up compared to previous versions. If you have limited CS background, when following these steps, make sure you don't perform extra commands in between steps. 205 | 206 | 1. to avoid storage space issues, we will work under the scratch partition (change the `netid` to your netid), we first clone the REDQ repo and `cd` into the REDQ folder. 207 | ``` 208 | cd /scratch/NETID/ 209 | git clone https://github.com/watchernyu/REDQ.git 210 | cd REDQ 211 | ``` 212 | 213 | 2. download MuJoCo files by running the provided script. 214 | ``` 215 | bash mujoco_download.sh 216 | ``` 217 | 218 | 3. load the anaconda module (since we are on the HPC, we need to use `module load` to get certain software, instead of installing them ourselves), create and activate a conda environment using the yaml file provided. (Again don't forget to change the `netid` part, and the second line is to avoid overly long text on the terminal) 219 | ``` 220 | module load anaconda3 221 | conda config --set env_prompt '({name})' 222 | conda env create -f conda_env.yml --prefix /scratch/NETID/redq_env 223 | conda activate /scratch/NETID/redq_env 224 | ``` 225 | 226 | 4. install redq 227 | ``` 228 | pip install -e . 229 | ``` 230 | 231 | 5. run a test script, but make sure you filled in your actual netid in `experiments/sample_hpc_script.sh` so that the script can correctly locate your conda environment. 232 | ``` 233 | cd experiments 234 | sbatch sample_hpc_script.sh 235 | ``` 236 | 237 | 238 | 239 | ## Environment setup with newest MuJoCo 2.1, on a Ubuntu 18.04 local machine 240 | First download MuJoCo files, on a linux machine, we put them under ~/.mujoco: 241 | ``` 242 | cd ~ 243 | mkdir ~/.mujoco 244 | cd ~/.mujoco 245 | curl -O https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 246 | tar -xf mujoco210-linux-x86_64.tar.gz 247 | ``` 248 | 249 | Now we create a conda environment (you will need anaconda), and install pytorch (if you just want mujoco+gym and don't want pytorch, then skip this step) 250 | ``` 251 | conda create -y -n redq python=3.8 252 | conda activate redq 253 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 254 | ``` 255 | 256 | Now install mujoco_py: 257 | ``` 258 | cd ~ 259 | mkdir rl_course 260 | cd rl_course 261 | git clone https://github.com/openai/mujoco-py.git 262 | cd mujoco-py/ 263 | pip install -e . --no-cache 264 | ``` 265 | A list of packages that need to be installed for linux is here https://github.com/openai/mujoco-py/blob/master/Dockerfile 266 | 267 | Now test by running python and then `import mujoco_py`, typically you will run into some error message, check that Dockerfile to see if you are missing any of the required packages (either python package or system package). 268 | 269 | If mujoco works, then install REDQ: 270 | 271 | ``` 272 | cd ~ 273 | cd rl_course 274 | git clone https://github.com/watchernyu/REDQ.git 275 | cd REDQ/ 276 | pip install -e . 
277 | ``` 278 | 279 | Now test REDQ by running: 280 | ``` 281 | python experiments/train_redq_sac.py --debug 282 | ``` 283 | If you see training logs, then the environment should be setup correctly! 284 | 285 | 286 | 287 | 288 | ## Environment setup with newest MuJoCo 2.1, on the NYU Shanghai hpc cluster (system is Linux, hpc management is Slurm) 289 | This guide helps you set up MuJoCo and then OpenAI Gym, and then REDQ. (You can also follow the guide if you just want OpenAI Gym + MuJoCo and not REDQ, REDQ is only the last step). This likely also works for NYU NY hpc cluster, and might also works for hpc cluster in other schools, assuming your hpc is linux and is using Slurm. 290 | 291 | ### conda init 292 | First we need to login to the hpc. 293 | ``` 294 | ssh netid@hpc.shanghai.nyu.edu 295 | ``` 296 | After this you should be on the login node of the hpc. Note the login node is different from a compute node, we will set up environment on the login node, then when we submit actual jobs, they are in fact run on the compute node. 297 | 298 | Note that on the hpc, students typically don't have admin privileges (which means you cannot install things that require `sudo`), so for some of the required system packages, we will not use the typical `sudo apt install` command, instead, we will use `module avail` to check if they are available, and then use `module load` to load them. If on your hpc cluster, a system package is not there, check with your hpc admin and ask them to help you. 299 | 300 | On the NYU Shanghai hpc (after you ssh, you get to the login node), first we want to set up conda correctly (typically need to do this for new accounts): 301 | ``` 302 | module load anaconda3 303 | conda init bash 304 | ``` 305 | Now use Ctrl + D to logout, and then login again. 306 | 307 | ### set up MuJoCo 308 | 309 | Now we are again on the hpc login node simply run this to load all required packages: 310 | ``` 311 | module load anaconda3 cuda/11.3.1 312 | ``` 313 | 314 | Now download MuJoCo files, on a linux machine, we put them under ~/.mujoco: 315 | ``` 316 | cd ~ 317 | mkdir ~/.mujoco 318 | cd ~/.mujoco 319 | curl -O https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 320 | tar -xf mujoco210-linux-x86_64.tar.gz 321 | ``` 322 | 323 | ### set up conda environment 324 | Then set up a conda virtualenv, and activate it (you can give it a different name, the env name does not matter) 325 | ``` 326 | conda create -y -n redq python=3.8 327 | conda activate redq 328 | ``` 329 | 330 | Install Pytorch (skip this step if you don't need pytorch) 331 | ``` 332 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 333 | ``` 334 | 335 | After you installed Pytorch, check if it works by running `python`, then `import torch` in the python interpreter. 336 | If Pytorch works, then either run `quit()` or Ctrl + D to exit the python interpreter. 337 | 338 | ### set up `mujoco_py` 339 | 340 | The next is to install MuJoCo, let's first run these 341 | ``` 342 | echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc 343 | echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia' >> ~/.bashrc 344 | ``` 345 | (it might be easier to do this step here. However, you can also set this up later when you test `import mujoco_py`) 346 | 347 | Now we want the `.bashrc` file to take effect, we need to ctrl+D to logout, and then login again, after we logout and login, we need to reload all the modules , and then also activate the conda env again. 
(each time you login, the loaded modules and activated environment will be reset, but files on disk will persist) After you login to the hpc again: 348 | ``` 349 | module load anaconda3 cuda/11.3.1 350 | conda activate redq 351 | ``` 352 | 353 | Now we can install `mujoco_py`. (You might want to know why do we need this? We already have the MuJoCo files, but they do not work directly with python, so `mujoco_py` is needed for us to use MuJoCo in python): 354 | ``` 355 | cd ~ 356 | mkdir rl_course 357 | cd rl_course 358 | git clone https://github.com/openai/mujoco-py.git 359 | cd mujoco-py/ 360 | pip install -e . --no-cache 361 | ``` 362 | 363 | Somewhere during installation, you might find that some libraries are missing (an error message might show up saying sth is missing, for example `GL/glew.h: No such file or directory #include `), on the HPC, since we are not super user, we cannot install system libraries, but can use `module load` to load them. This command works on some machines: `module load glfw/3.3 gcc/7.3 mesa/19.0.5 llvm/7.0.1`. You can try this and then see if the error goes away. 364 | 365 | Now we want to test whether it works, run `python`, and then in the python interpreter, run `import mujoco_py`. The first time you run it, it will give a ton of log, if you can run it again and get no log, then it should be working. (Summary: try run `import mujoco_py` twice, if the second time you do it, you get no log and no error message, then it should be working). After this testing is complete, quit the python interpreter with either `quit()` or Ctrl + D. 366 | 367 | Note: if you got an error when import `mujoco_py`, sometimes the error will tell you to add some `export` text to a file (typically `~/.bashrc`) on your system, if you see that, then likely this is because you are installing on a system that configured things slightly differently from NYU Shanghai hpc cluster, in that case, just follow the error message to do whatever it tells you, then logout and login, and test it again. Check https://github.com/openai/mujoco-py for more info. 368 | 369 | ### set up gym 370 | 371 | Now we install OpenAI gym: 372 | `pip install gym` 373 | 374 | After this step, test gym by again run `python`, in the python interpreter, run: 375 | ``` 376 | import gym 377 | e = gym.make('Ant-v2') 378 | e.reset() 379 | ``` 380 | If you see a large numpy array (which is the initial state, or initial observation for the Ant-v2 environment), then gym is working. 381 | 382 | ### set up REDQ 383 | After this step, you can install REDQ. 384 | ``` 385 | cd ~ 386 | cd rl_course 387 | git clone https://github.com/watchernyu/REDQ.git 388 | cd REDQ/ 389 | pip install -e . 390 | ``` 391 | 392 | ### test on login node (sometimes things work on login but not on compute, we will first test login) 393 | ``` 394 | cd ~/rl_course/REDQ 395 | python experiments/train_redq_sac.py --debug 396 | ``` 397 | 398 | ### test on compute node (will be updated soon) 399 | Now we test whether REDQ runs. 
We will first log in to an interactive compute node (note that a login node is not a compute node; don't do intensive computation on the login node): 400 | ``` 401 | srun -p aquila --pty --mem 5000 -t 0-05:00 bash 402 | ``` 403 | Now, don't forget that we are on a new node and need to load the modules and activate the conda env again: 404 | ``` 405 | module load anaconda3 cuda/11.3.1 406 | conda deactivate 407 | conda activate redq 408 | ``` 409 | Now test the REDQ algorithm: 410 | ``` 411 | cd ~/rl_course/REDQ 412 | python experiments/train_redq_sac.py --debug 413 | ``` 414 | 415 | 416 | ## Other HPC issues 417 | ### missing patchelf 418 | Simply install patchelf. Make sure your conda environment is activated, then try this command: `conda install -c anaconda patchelf`. 419 | 420 | If you see warnings or a message telling you to update conda, ignore them. If it asks whether you want to install, choose yes and wait for the installation to finish. 421 | 422 | ### quota exceeded 423 | If your home quota is exceeded, you can contact the current HPC admin to extend your quota. Alternatively, you can install all required packages under `/scratch`, which has plenty of space (but your data under scratch will be removed if you don't use your account for too long). You might need a bit more Python experience to do this correctly. 424 | 425 | 426 | 427 | ## Previous updates 428 | 429 | June 23, 2022: added a guide for setting up with OpenAI Gym MuJoCo v4 tasks on Slurm HPCs (not fully tested yet). Currently it seems this newer version of MuJoCo is much easier to set up compared to previous ones. 430 | 431 | Nov 14, 2021: **MuJoCo** is now free (thanks DeepMind!) and we now have a guide on setting up MuJoCo 2.1 + OpenAI Gym + REDQ on a Linux machine (see the end of this page for the newest setup guide). 432 | 433 | Aug 18, 2021: **VERY IMPORTANT BUG FIX** in `experiments/train_redq_sac.py`: the done signal was not being used correctly. The done signal value should be `False` when the episode terminates due to the environment time limit, but in the earlier version of the code, 434 | the agent put the transition in the buffer before this value was corrected. This can affect performance, especially for environments where termination due to a bad action is rare. This is now fixed and we might do some more testing. If you use this file to run experiments, **please check immediately or pull the latest version** of the code. 435 | Sorry for the bug! Please don't hesitate to open an issue if you have any questions. 436 | 437 | July 2021: data and the functions to reproduce all figures in the paper are now available; see the `Data and reproducing figures in REDQ` section for details. 438 | 439 | Mar 23, 2021: We have reorganized the code to make it cleaner and more readable, and the first version is now released! 440 | 441 | Mar 29, 2021: We tested the installation process and ran the code, and everything seems to be working correctly. We are now working on the implementation video tutorial, which will be released soon. 442 | 443 | May 3, 2021: We uploaded a video tutorial (shared via Google Drive); please see the link in the video tutorial section. Hope it helps! 444 | 445 | Code for REDQ-OFE is still being cleaned up and will be released soon (essentially the same code but with additional input from an OFENet). 446 | 447 | 448 | 449 | ## Acknowledgement 450 | 451 | Our code for REDQ-SAC is partly based on the SAC implementation in OpenAI Spinup (https://github.com/openai/spinningup).
The current code structure is inspired by the super clean TD3 source code by Scott Fujimoto (https://github.com/sfujim/TD3). 452 | 453 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: redq 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.8 6 | - pip=21.2.4 7 | - pytorch::pytorch=1.11.0 8 | - pytorch::torchvision=0.12.0 9 | - nvidia::cudatoolkit=11.3.1 10 | - joblib=1.1.0 11 | - pip: 12 | - gym[mujoco,robotics]==0.24.1 13 | -------------------------------------------------------------------------------- /docker-gym-mujocov2/Dockerfile: -------------------------------------------------------------------------------- 1 | # cudagl with miniconda and python 3.8, pytorch, mujoco and gym v2 envs, REDQ 2 | 3 | FROM nvidia/cudagl:11.0-base-ubuntu18.04 4 | WORKDIR /workspace 5 | ENV HOME=/workspace 6 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 7 | ENV PATH /opt/conda/bin:$PATH 8 | 9 | # idea: start with a nvidia docker with gl support (guess this one also has cuda?) 10 | # then install miniconda, borrowing docker command from miniconda's Dockerfile (https://hub.docker.com/r/continuumio/anaconda/dockerfile/) 11 | # need to make sure the miniconda python version is what we need (https://docs.conda.io/en/latest/miniconda.html for the right version) 12 | # then install other dependencies we need 13 | 14 | # nvidia GPG key alternative fix (https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772) 15 | # sudo apt-key del 7fa2af80 16 | # wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.0-1_all.deb 17 | # sudo dpkg -i cuda-keyring_1.0-1_all.deb 18 | 19 | RUN \ 20 | # Update nvidia GPG key 21 | rm /etc/apt/sources.list.d/cuda.list \ 22 | && rm /etc/apt/sources.list.d/nvidia-ml.list \ 23 | && apt-key del 7fa2af80 \ 24 | && apt-get update && apt-get install -y --no-install-recommends wget \ 25 | && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb \ 26 | && dpkg -i cuda-keyring_1.0-1_all.deb \ 27 | && apt-get update 28 | 29 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \ 30 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 31 | git mercurial subversion 32 | 33 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh -O ~/miniconda.sh \ 34 | && /bin/bash ~/miniconda.sh -b -p /opt/conda \ 35 | && rm ~/miniconda.sh \ 36 | && ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh \ 37 | && echo ". 
/opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc \ 38 | && echo "conda activate base" >> ~/.bashrc 39 | 40 | RUN apt-get install -y curl grep sed dpkg \ 41 | # TINI_VERSION=`curl https://github.com/krallin/tini/releases/latest | grep -o "/v.*\"" | sed 's:^..\(.*\).$:\1:'` && \ 42 | # curl -L "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \ 43 | # dpkg -i tini.deb && \ 44 | # rm tini.deb && \ 45 | && apt-get clean 46 | 47 | # Install some basic utilities 48 | RUN apt-get update && apt-get install -y software-properties-common \ 49 | && add-apt-repository -y ppa:redislabs/redis && apt-get update \ 50 | && apt-get install -y sudo ssh libx11-6 gcc iputils-ping \ 51 | libxrender-dev graphviz tmux htop build-essential wget cmake libgl1-mesa-glx redis \ 52 | && rm -rf /var/lib/apt/lists/* 53 | 54 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \ 55 | && apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0 56 | # && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so 57 | # ---------- now we should have all major dependencies ------------ 58 | 59 | # --------- now we have cudagl + python38 --------- 60 | RUN pip install --no-cache-dir torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 61 | # pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 62 | 63 | RUN pip install --no-cache-dir scikit-learn pandas imageio 64 | 65 | # adroit env does not work with newer version of mujoco, so we have to use the old version... 66 | # setting MUJOCO_GL to egl is needed if run on headless machine 67 | ENV MUJOCO_GL=egl 68 | ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/workspace/.mujoco/mujoco210/bin 69 | RUN mkdir -p /workspace/.mujoco \ 70 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco210.tar.gz \ 71 | && tar -xvzf mujoco210.tar.gz -C /workspace/.mujoco \ 72 | && rm mujoco210.tar.gz 73 | 74 | # RL env: get mujoco and gym (v2 version environments) 75 | RUN pip install --no-cache-dir patchelf==0.17.2.0 mujoco-py==2.1.2.14 gym==0.21.0 76 | 77 | # first time import mujoco_py will take some time to run, do it here to avoid the wait later 78 | RUN echo "import mujoco_py" >> /workspace/import_mujoco.py \ 79 | && python /workspace/import_mujoco.py 80 | 81 | # RL algorithm: install REDQ. If you don't want to use REDQ, simply remove these lines 82 | RUN cd /workspace/ \ 83 | && git clone https://github.com/watchernyu/REDQ.git \ 84 | && cd REDQ \ 85 | && git checkout ac840198f143d10bb22425ed2105a49d01b383fa \ 86 | && pip install -e . 87 | 88 | CMD [ "/bin/bash" ] 89 | 90 | # # build docker container: 91 | # docker build . -t name:tag 92 | 93 | # # example docker command to run interactive container, enable gpu, and remove it when shutdown: 94 | # docker run -it --rm --gpus all name:tag 95 | -------------------------------------------------------------------------------- /docker-gym-mujocov2/README.md: -------------------------------------------------------------------------------- 1 | This dockerfile contains commands to setup the older gym mujoco v2 environments. The results in the REDQ paper were run with the v2 environments. If you use this dockerfile, you will be able to run REDQ code right away. 
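For example, a minimal build-and-run sequence might look like this (following the comments at the end of the Dockerfile above; the image tag `redq-gym-mujocov2:local` is just an illustrative name):
```
cd docker-gym-mujocov2
docker build . -t redq-gym-mujocov2:local
docker run -it --rm --gpus all redq-gym-mujocov2:local
```
Alternatively, pull the prebuilt image `cwatcherw/gym-mujocov2:1.0` as described in the main README.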
-------------------------------------------------------------------------------- /docker-gym-mujocov4/Dockerfile: -------------------------------------------------------------------------------- 1 | # cudagl with miniconda and python 3.8, pytorch, mujoco and gym v4 envs 2 | 3 | FROM nvidia/cudagl:11.0-base-ubuntu18.04 4 | WORKDIR /workspace 5 | ENV HOME=/workspace 6 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 7 | ENV PATH /opt/conda/bin:$PATH 8 | 9 | # idea: start with a nvidia docker with gl support (guess this one also has cuda?) 10 | # then install miniconda, borrowing docker command from miniconda's Dockerfile (https://hub.docker.com/r/continuumio/anaconda/dockerfile/) 11 | # need to make sure the miniconda python version is what we need (https://docs.conda.io/en/latest/miniconda.html for the right version) 12 | # then install other dependencies we need 13 | 14 | # nvidia GPG key alternative fix (https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772) 15 | # sudo apt-key del 7fa2af80 16 | # wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.0-1_all.deb 17 | # sudo dpkg -i cuda-keyring_1.0-1_all.deb 18 | 19 | RUN \ 20 | # Update nvidia GPG key 21 | rm /etc/apt/sources.list.d/cuda.list && \ 22 | rm /etc/apt/sources.list.d/nvidia-ml.list && \ 23 | apt-key del 7fa2af80 && \ 24 | apt-get update && apt-get install -y --no-install-recommends wget && \ 25 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb && \ 26 | dpkg -i cuda-keyring_1.0-1_all.deb && \ 27 | apt-get update 28 | 29 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \ 30 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 31 | git mercurial subversion 32 | 33 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh -O ~/miniconda.sh && \ 34 | /bin/bash ~/miniconda.sh -b -p /opt/conda && \ 35 | rm ~/miniconda.sh && \ 36 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 37 | echo ". 
/opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 38 | echo "conda activate base" >> ~/.bashrc 39 | 40 | RUN apt-get install -y curl grep sed dpkg && \ 41 | # TINI_VERSION=`curl https://github.com/krallin/tini/releases/latest | grep -o "/v.*\"" | sed 's:^..\(.*\).$:\1:'` && \ 42 | # curl -L "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \ 43 | # dpkg -i tini.deb && \ 44 | # rm tini.deb && \ 45 | apt-get clean 46 | 47 | # Install some basic utilities 48 | RUN apt-get update && apt-get install -y software-properties-common && \ 49 | add-apt-repository -y ppa:redislabs/redis && apt-get update && \ 50 | apt-get install -y sudo ssh libx11-6 gcc iputils-ping \ 51 | libxrender-dev graphviz tmux htop build-essential wget cmake libgl1-mesa-glx redis && \ 52 | rm -rf /var/lib/apt/lists/* 53 | 54 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \ 55 | && apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0 56 | # && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so 57 | # ---------- now we should have all major dependencies ------------ 58 | 59 | # --------- now we have cudagl + python38 --------- 60 | RUN pip install --no-cache-dir torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 61 | # pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 62 | 63 | RUN pip install --no-cache-dir scikit-learn pandas imageio 64 | 65 | # RL env: get mujoco and gym 66 | RUN pip install --no-cache-dir mujoco==2.2.2 gym==0.26.2 67 | # change to headless rendering (if you want rendering on a machine's monitor see https://github.com/deepmind/dm_control) 68 | ENV MUJOCO_GL=egl 69 | # import mujoco now to avoid long waiting time at runtime (note: this image installs the mujoco package, not mujoco_py) 70 | RUN echo "import mujoco" >> /workspace/import_mujoco.py \ 71 | && python /workspace/import_mujoco.py 72 | 73 | # RL algorithm: install REDQ 74 | RUN cd /workspace/ \ 75 | && git clone https://github.com/watchernyu/REDQ.git \ 76 | && cd REDQ \ 77 | # && git checkout ac840198f143d10bb22425ed2105a49d01b383fa \ 78 | && pip install -e . 79 | 80 | CMD [ "/bin/bash" ] 81 | 82 | # example docker command to run interactive container, enable gpu, remove when shutdown: 83 | # docker run -it --rm --gpus all name:tag 84 | 85 | -------------------------------------------------------------------------------- /docker-gym-mujocov4/README.md: -------------------------------------------------------------------------------- 1 | This dockerfile contains commands to set up a newer version of mujoco and gym (the gym mujoco v4 environments). However, many things in gym have changed from v2 to v4 (e.g. the environment now returns more outputs, and seeding has been significantly modified). The v4 environments are easier to set up and work with, but if you are using a codebase designed for the v2 environments, you will need to modify the code to work with the new v4 environments.
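For illustration, here is a rough sketch (not code from this repo) of how the basic interaction loop differs between the old v2-style gym API and the gym 0.26 / v4-style API; the environment name and seed are arbitrary:
```
import gym

# old v2-style API (gym <= 0.21):
#   env.seed(0)
#   obs = env.reset()
#   obs, reward, done, info = env.step(action)

# new v4-style API (gym 0.26):
env = gym.make('Hopper-v4')
obs, info = env.reset(seed=0)  # reset now takes the seed and also returns an info dict
done = False
while not done:
    action = env.action_space.sample()
    # step now returns five values instead of four
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # time-limit truncation is reported separately from true termination
```
The terminated/truncated distinction matters for the training loop: as noted in the Aug 18, 2021 update in the main README, the done value stored in the replay buffer should be `False` when an episode ends only because of the time limit.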
2 | -------------------------------------------------------------------------------- /experiments/grid_sample_hpc_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --verbose 3 | #SBATCH -p aquila,parallel 4 | #SBATCH --time=168:00:00 5 | #SBATCH --nodes=1 6 | #SBATCH --mem=12GB 7 | #SBATCH --mail-type=ALL # select which email types will be sent 8 | #SBATCH --mail-user=NETID@nyu.edu # NOTE: put your netid here if you want emails 9 | 10 | #SBATCH --array=0-5 # here the number depends on number of tasks in the array, e.g. 0-59 will create 60 tasks 11 | #SBATCH --output=r_%A_%a.out # %A is SLURM_ARRAY_JOB_ID, %a is SLURM_ARRAY_TASK_ID 12 | #SBATCH --error=r_%A_%a.err 13 | 14 | # ##################################################### 15 | # #SBATCH --gres=gpu:1 # uncomment this line to request a gpu 16 | #SBATCH --constraint=cpu # specify constraint features ('cpu' means only use nodes that have the 'cpu' feature) check features with the showcluster command 17 | 18 | sleep $(( (RANDOM%10) + 1 )) # to avoid issues when submitting large amounts of jobs 19 | 20 | echo "SLURM_JOBID: " $SLURM_JOBID 21 | echo "SLURM_ARRAY_JOB_ID: " $SLURM_ARRAY_JOB_ID 22 | echo "SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 23 | 24 | module load anaconda3 cuda/9.0 glfw/3.3 gcc/7.3 mesa/19.0.5 llvm/7.0.1 # load modules that we might use 25 | conda init bash # this and the following line are important to avoid hpc issues 26 | source ~/.bashrc 27 | conda activate /scratch/NETID/redq_env # NOTE: remember to change to your actual netid 28 | 29 | echo ${SLURM_ARRAY_TASK_ID} 30 | python grid_train_redq_sac.py --setting ${SLURM_ARRAY_TASK_ID} -------------------------------------------------------------------------------- /experiments/grid_train_redq_sac.py: -------------------------------------------------------------------------------- 1 | from train_redq_sac import redq_sac as function_to_run ## here make sure you import correct function 2 | import time 3 | from redq.utils.run_utils import setup_logger_kwargs 4 | from grid_utils import get_setting_and_exp_name 5 | 6 | if __name__ == '__main__': 7 | import argparse 8 | start_time = time.time() 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--setting', type=int, default=0) 11 | args = parser.parse_args() 12 | data_dir = '../data' 13 | 14 | exp_prefix = 'trial' 15 | settings = ['env_name','',['Hopper-v4', 'Ant-v4'], 16 | 'seed','',[0, 1, 2], 17 | 'epochs','e',[20], 18 | 'num_Q','q',[2], 19 | 'utd_ratio','uf',[1], 20 | 'policy_update_delay','pd',[20], 21 | ] 22 | 23 | indexes, actual_setting, total, exp_name_full = get_setting_and_exp_name(settings, args.setting, exp_prefix) 24 | print("##### TOTAL NUMBER OF VARIANTS: %d #####" % total) 25 | 26 | logger_kwargs = setup_logger_kwargs(exp_name_full, actual_setting['seed'], data_dir) 27 | function_to_run(logger_kwargs=logger_kwargs, **actual_setting) 28 | print("Total time used: %.3f hours." 
% ((time.time() - start_time)/3600)) 29 | -------------------------------------------------------------------------------- /experiments/grid_utils.py: -------------------------------------------------------------------------------- 1 | from train_redq_sac import redq_sac as function_to_run ## here make sure you import correct function 2 | import time 3 | import numpy as np 4 | from redq.utils.run_utils import setup_logger_kwargs 5 | 6 | def get_setting_and_exp_name(settings, setting_number, exp_prefix, random_setting_seed=0, random_order=True): 7 | np.random.seed(random_setting_seed) 8 | hypers, lognames, values_list = [], [], [] 9 | hyper2logname = {} 10 | n_settings = int(len(settings)/3) 11 | for i in range(n_settings): 12 | hypers.append(settings[i*3]) 13 | lognames.append(settings[i*3+1]) 14 | values_list.append(settings[i*3+2]) 15 | hyper2logname[hypers[-1]] = lognames[-1] 16 | 17 | total = 1 18 | for values in values_list: 19 | total *= len(values) 20 | max_job = total 21 | 22 | new_indexes = np.random.choice(total, total, replace=False) if random_order else np.arange(total) 23 | new_index = new_indexes[setting_number] 24 | 25 | indexes = [] ## this says which hyperparameter we use 26 | remainder = new_index 27 | for values in values_list: 28 | division = int(total / len(values)) 29 | index = int(remainder / division) 30 | remainder = remainder % division 31 | indexes.append(index) 32 | total = division 33 | actual_setting = {} 34 | for j in range(len(indexes)): 35 | actual_setting[hypers[j]] = values_list[j][indexes[j]] 36 | 37 | exp_name_full = exp_prefix 38 | for hyper, value in actual_setting.items(): 39 | if hyper not in ['env_name', 'seed']: 40 | exp_name_full = exp_name_full + '_%s' % (hyper2logname[hyper] + str(value)) 41 | exp_name_full = exp_name_full + '_%s' % actual_setting['env_name'] 42 | 43 | return indexes, actual_setting, max_job, exp_name_full 44 | 45 | -------------------------------------------------------------------------------- /experiments/sample_hpc_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --verbose 3 | #SBATCH -p aquila,parallel 4 | #SBATCH --time=168:00:00 5 | #SBATCH --nodes=1 6 | #SBATCH --mem=12GB 7 | #SBATCH --mail-type=ALL # select which email types will be sent 8 | #SBATCH --mail-user=NETID@nyu.edu # NOTE: put your netid here if you want emails 9 | 10 | #SBATCH --array=0-0 # here the number depends on number of tasks in the array, e.g. 
0-59 will create 60 tasks 11 | #SBATCH --output=r_%A_%a.out # %A is SLURM_ARRAY_JOB_ID, %a is SLURM_ARRAY_TASK_ID 12 | #SBATCH --error=r_%A_%a.err 13 | 14 | # ##################################################### 15 | # #SBATCH --gres=gpu:1 # uncomment this line to request a gpu 16 | #SBATCH --constraint=cpu # specify constraint features ('cpu' means only use nodes that have the 'cpu' feature) check features with the showcluster command 17 | 18 | sleep $(( (RANDOM%10) + 1 )) # to avoid issues when submitting large amounts of jobs 19 | 20 | echo "SLURM_JOBID: " $SLURM_JOBID 21 | echo "SLURM_ARRAY_JOB_ID: " $SLURM_ARRAY_JOB_ID 22 | echo "SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 23 | 24 | module load anaconda3 cuda/9.0 glfw/3.3 gcc/7.3 mesa/19.0.5 llvm/7.0.1 # load modules that we might use 25 | conda init bash # this and the following line are important to avoid hpc issues 26 | source ~/.bashrc 27 | conda activate /scratch/NETID/redq_env # NOTE: remember to change to your actual netid 28 | 29 | echo ${SLURM_ARRAY_TASK_ID} 30 | python train_redq_sac.py --debug --epochs 20 --env Hopper-v4 # use v4 version tasks if you followed the newest Gym+MuJoCo installation guide -------------------------------------------------------------------------------- /experiments/train_redq_sac.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import time 5 | import sys 6 | from redq.algos.redq_sac import REDQSACAgent 7 | from redq.algos.core import mbpo_epoches, test_agent 8 | from redq.utils.run_utils import setup_logger_kwargs 9 | from redq.utils.bias_utils import log_bias_evaluation 10 | from redq.utils.logx import EpochLogger 11 | 12 | def redq_sac(env_name, seed=0, epochs='mbpo', steps_per_epoch=1000, 13 | max_ep_len=1000, n_evals_per_epoch=1, 14 | logger_kwargs=dict(), debug=False, 15 | # following are agent related hyperparameters 16 | hidden_sizes=(256, 256), replay_size=int(1e6), batch_size=256, 17 | lr=3e-4, gamma=0.99, polyak=0.995, 18 | alpha=0.2, auto_alpha=True, target_entropy='mbpo', 19 | start_steps=5000, delay_update_steps='auto', 20 | utd_ratio=20, num_Q=10, num_min=2, q_target_mode='min', 21 | policy_update_delay=20, 22 | # following are bias evaluation related 23 | evaluate_bias=True, n_mc_eval=1000, n_mc_cutoff=350, reseed_each_epoch=True 24 | ): 25 | """ 26 | :param env_name: name of the gym environment 27 | :param seed: random seed 28 | :param epochs: number of epochs to run 29 | :param steps_per_epoch: number of timestep (datapoints) for each epoch 30 | :param max_ep_len: max timestep until an episode terminates 31 | :param n_evals_per_epoch: number of evaluation runs for each epoch 32 | :param logger_kwargs: arguments for logger 33 | :param debug: whether to run in debug mode 34 | :param hidden_sizes: hidden layer sizes 35 | :param replay_size: replay buffer size 36 | :param batch_size: mini-batch size 37 | :param lr: learning rate for all networks 38 | :param gamma: discount factor 39 | :param polyak: hyperparameter for polyak averaged target networks 40 | :param alpha: SAC entropy hyperparameter 41 | :param auto_alpha: whether to use adaptive SAC 42 | :param target_entropy: used for adaptive SAC 43 | :param start_steps: the number of random data collected in the beginning of training 44 | :param delay_update_steps: after how many data collected should we start updates 45 | :param utd_ratio: the update-to-data ratio 46 | :param num_Q: number of Q networks in the Q ensemble 47 | :param 
num_min: number of sampled Q values to take minimal from 48 | :param q_target_mode: 'min' for minimal, 'ave' for average, 'rem' for random ensemble mixture 49 | :param policy_update_delay: how many updates until we update policy network 50 | """ 51 | if debug: # use --debug for very quick debugging 52 | hidden_sizes = [2,2] 53 | batch_size = 2 54 | utd_ratio = 2 55 | num_Q = 3 56 | max_ep_len = 100 57 | start_steps = 100 58 | steps_per_epoch = 100 59 | 60 | # use gpu if available 61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 62 | # set number of epoch 63 | if epochs == 'mbpo' or epochs < 0: 64 | epochs = mbpo_epoches[env_name] 65 | total_steps = steps_per_epoch * epochs + 1 66 | 67 | """set up logger""" 68 | logger = EpochLogger(**logger_kwargs) 69 | logger.save_config(locals()) 70 | 71 | """set up environment and seeding""" 72 | env_fn = lambda: gym.make(env_name) 73 | env, test_env, bias_eval_env = env_fn(), env_fn(), env_fn() 74 | # seed torch and numpy 75 | torch.manual_seed(seed) 76 | np.random.seed(seed) 77 | 78 | # seed environment along with env action space so that everything is properly seeded for reproducibility 79 | def seed_all(epoch): 80 | seed_shift = epoch * 9999 81 | mod_value = 999999 82 | env_seed = (seed + seed_shift) % mod_value 83 | test_env_seed = (seed + 10000 + seed_shift) % mod_value 84 | bias_eval_env_seed = (seed + 20000 + seed_shift) % mod_value 85 | torch.manual_seed(env_seed) 86 | np.random.seed(env_seed) 87 | env.seed(env_seed) 88 | env.action_space.np_random.seed(env_seed) 89 | test_env.seed(test_env_seed) 90 | test_env.action_space.np_random.seed(test_env_seed) 91 | bias_eval_env.seed(bias_eval_env_seed) 92 | bias_eval_env.action_space.np_random.seed(bias_eval_env_seed) 93 | seed_all(epoch=0) 94 | 95 | """prepare to init agent""" 96 | # get obs and action dimensions 97 | obs_dim = env.observation_space.shape[0] 98 | act_dim = env.action_space.shape[0] 99 | # if environment has a smaller max episode length, then use the environment's max episode length 100 | max_ep_len = env._max_episode_steps if max_ep_len > env._max_episode_steps else max_ep_len 101 | # Action limit for clamping: critically, assumes all dimensions share the same bound! 
102 | # we need .item() to convert it from numpy float to python float 103 | act_limit = env.action_space.high[0].item() 104 | # keep track of run time 105 | start_time = time.time() 106 | # flush logger (optional) 107 | sys.stdout.flush() 108 | ################################################################################################# 109 | 110 | """init agent and start training""" 111 | agent = REDQSACAgent(env_name, obs_dim, act_dim, act_limit, device, 112 | hidden_sizes, replay_size, batch_size, 113 | lr, gamma, polyak, 114 | alpha, auto_alpha, target_entropy, 115 | start_steps, delay_update_steps, 116 | utd_ratio, num_Q, num_min, q_target_mode, 117 | policy_update_delay) 118 | 119 | o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0 120 | 121 | for t in range(total_steps): 122 | # get action from agent 123 | a = agent.get_exploration_action(o, env) 124 | # Step the env, get next observation, reward and done signal 125 | o2, r, d, _ = env.step(a) 126 | 127 | # Very important: before we let agent store this transition, 128 | # Ignore the "done" signal if it comes from hitting the time 129 | # horizon (that is, when it's an artificial terminal signal 130 | # that isn't based on the agent's state) 131 | ep_len += 1 132 | d = False if ep_len == max_ep_len else d 133 | 134 | # give new data to agent 135 | agent.store_data(o, a, r, o2, d) 136 | # let agent update 137 | agent.train(logger) 138 | # set obs to next obs 139 | o = o2 140 | ep_ret += r 141 | 142 | 143 | if d or (ep_len == max_ep_len): 144 | # store episode return and length to logger 145 | logger.store(EpRet=ep_ret, EpLen=ep_len) 146 | # reset environment 147 | o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0 148 | 149 | # End of epoch wrap-up 150 | if (t+1) % steps_per_epoch == 0: 151 | epoch = t // steps_per_epoch 152 | 153 | # Test the performance of the deterministic version of the agent. 
154 | test_agent(agent, test_env, max_ep_len, logger) # add logging here 155 | if evaluate_bias: 156 | log_bias_evaluation(bias_eval_env, agent, logger, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff) 157 | 158 | # reseed should improve reproducibility (should make results the same whether bias evaluation is on or not) 159 | if reseed_each_epoch: 160 | seed_all(epoch) 161 | 162 | """logging""" 163 | # Log info about epoch 164 | logger.log_tabular('Epoch', epoch) 165 | logger.log_tabular('TotalEnvInteracts', t) 166 | logger.log_tabular('Time', time.time()-start_time) 167 | logger.log_tabular('EpRet', with_min_and_max=True) 168 | logger.log_tabular('EpLen', average_only=True) 169 | logger.log_tabular('TestEpRet', with_min_and_max=True) 170 | logger.log_tabular('TestEpLen', average_only=True) 171 | logger.log_tabular('Q1Vals', with_min_and_max=True) 172 | logger.log_tabular('LossQ1', average_only=True) 173 | logger.log_tabular('LogPi', with_min_and_max=True) 174 | logger.log_tabular('LossPi', average_only=True) 175 | logger.log_tabular('Alpha', with_min_and_max=True) 176 | logger.log_tabular('LossAlpha', average_only=True) 177 | logger.log_tabular('PreTanh', with_min_and_max=True) 178 | 179 | if evaluate_bias: 180 | logger.log_tabular("MCDisRet", with_min_and_max=True) 181 | logger.log_tabular("MCDisRetEnt", with_min_and_max=True) 182 | logger.log_tabular("QPred", with_min_and_max=True) 183 | logger.log_tabular("QBias", with_min_and_max=True) 184 | logger.log_tabular("QBiasAbs", with_min_and_max=True) 185 | logger.log_tabular("NormQBias", with_min_and_max=True) 186 | logger.log_tabular("QBiasSqr", with_min_and_max=True) 187 | logger.log_tabular("NormQBiasSqr", with_min_and_max=True) 188 | logger.dump_tabular() 189 | 190 | # flush logged information to disk 191 | sys.stdout.flush() 192 | 193 | if __name__ == '__main__': 194 | import argparse 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument('--env', type=str, default='Hopper-v2') 197 | parser.add_argument('--seed', '-s', type=int, default=0) 198 | parser.add_argument('--epochs', type=int, default=-1) # -1 means use mbpo epochs 199 | parser.add_argument('--exp_name', type=str, default='redq_sac') 200 | parser.add_argument('--data_dir', type=str, default='../data/') 201 | parser.add_argument('--debug', action='store_true') 202 | args = parser.parse_args() 203 | 204 | # modify the code here if you want to use a different naming scheme 205 | exp_name_full = args.exp_name + '_%s' % args.env 206 | 207 | # specify experiment name, seed and data_dir. 208 | # for example, for seed 0, the progress.txt will be saved under data_dir/exp_name/exp_name_s0 209 | logger_kwargs = setup_logger_kwargs(exp_name_full, args.seed, args.data_dir) 210 | 211 | redq_sac(args.env, seed=args.seed, epochs=args.epochs, 212 | logger_kwargs=logger_kwargs, debug=args.debug) 213 | -------------------------------------------------------------------------------- /mujoco_download.sh: -------------------------------------------------------------------------------- 1 | cd ~ 2 | mkdir ~/.mujoco 3 | cd ~/.mujoco 4 | curl -O https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 5 | tar -xf mujoco210-linux-x86_64.tar.gz -------------------------------------------------------------------------------- /plot_utils/plot_REDQ.py: -------------------------------------------------------------------------------- 1 | """ 2 | this one program will be used to basically generate all REDQ related figures for ICLR 2021. 
3 | NOTE: currently only works with seaborn 0.8.1 the tsplot function is deprecated in the newer version 4 | the plotting function is originally based on the plot function in OpenAI spinningup 5 | """ 6 | import os 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import pandas as pd 10 | import seaborn as sns 11 | from redq_plot_helper import * 12 | 13 | # the path leading to where the experiment file are located 14 | base_path = '../data/REDQ_ICLR21' 15 | 16 | # map experiment name to folder's name 17 | exp2path_main = { 18 | 'MBPO-ant': 'MBPO-Ant', 19 | 'MBPO-hopper': 'MBPO-Hopper', 20 | 'MBPO-walker2d': 'MBPO-Walker2d', 21 | 'MBPO-humanoid': 'MBPO-Humanoid', 22 | 'REDQ-n10-ant':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_ant-v2', 23 | 'REDQ-n10-humanoid':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_humanoid-v2', 24 | 'REDQ-n10-hopper':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_hopper-v2', 25 | 'REDQ-n10-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_walker2d-v2', 26 | 'SAC-1-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_ant-v2', 27 | 'SAC-1-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_humanoid-v2', 28 | 'SAC-1-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_hopper-v2', 29 | 'SAC-1-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_walker2d-v2', 30 | 'SAC-20-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_ant-v2', 31 | 'SAC-20-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_walker2d-v2', 32 | 'SAC-20-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_hopper-v2', 33 | 'SAC-20-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_humanoid-v2', 34 | } 35 | exp2path_N_ablation = { 36 | 'REDQ-n15-ant':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_ant-v2', 37 | 'REDQ-n15-humanoid':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_humanoid-v2', 38 | 'REDQ-n15-hopper':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_hopper-v2', 39 | 'REDQ-n15-walker2d':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_walker2d-v2', 40 | 'REDQ-n5-ant': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_ant-v2', 41 | 'REDQ-n5-humanoid': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_humanoid-v2', 42 | 'REDQ-n5-hopper': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_hopper-v2', 43 | 'REDQ-n5-walker2d': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_walker2d-v2', 44 | 'REDQ-n3-ant': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_ant-v2', 45 | 'REDQ-n3-humanoid': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_humanoid-v2', 46 | 'REDQ-n3-hopper': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_hopper-v2', 47 | 'REDQ-n3-walker2d': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_walker2d-v2', 48 | 'REDQ-n2-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_ant-v2', 49 | 'REDQ-n2-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_humanoid-v2', 50 | 'REDQ-n2-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_hopper-v2', 51 | 'REDQ-n2-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_walker2d-v2', 52 | } 53 | exp2path_redq_variants = { 54 | 'REDQ-weighted-hopper': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_hopper-v2', 55 | 'REDQ-minpair-hopper': 'REDQ_embpo_qminpair_piave_n10_m2_uf20_pd20_hopper-v2', 56 | 'REDQ-weighted-walker2d': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_walker2d-v2', 57 | 'REDQ-minpair-walker2d': 'REDQ_embpo_qminpair_piave_n10_m2_uf20_pd20_walker2d-v2', 58 | 'REDQ-weighted-ant': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_ant-v2', 59 | 'REDQ-minpair-ant': 'REDQ_embpo_qminpair_piave_n10_m2_uf20_pd20_ant-v2', 60 | 'REDQ-rem-hopper': 'REDQ_embpo_qrem_piave_n10_m2_uf20_pd20_hopper-v2', 61 | 'REDQ-ave-hopper': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_hopper-v2', 62 | 'REDQ-min-hopper': 
'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_hopper-v2', 63 | 'REDQ-rem-walker2d': 'REDQ_embpo_qrem_piave_n10_m2_uf20_pd20_walker2d-v2', 64 | 'REDQ-ave-walker2d': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_walker2d-v2', 65 | 'REDQ-min-walker2d': 'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_walker2d-v2', 66 | 'REDQ-rem-ant': 'REDQ_embpo_qrem_piave_n10_m2_uf20_pd20_ant-v2', 67 | 'REDQ-ave-ant': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_ant-v2', 68 | 'REDQ-min-ant': 'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_ant-v2', 69 | 'REDQ-ave-humanoid': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_humanoid-v2', 70 | 'REDQ-weighted-humanoid': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_humanoid-v2', 71 | 'REDQ-m1-hopper':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_hopper-v2', 72 | 'REDQ-m1-walker2d':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_walker2d-v2', 73 | 'REDQ-m1-ant':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_ant-v2', 74 | 'REDQ-m1-humanoid':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_humanoid-v2', 75 | } 76 | 77 | exp2path_M_ablation = { 78 | 'REDQ-n10-m1-5-ant': 'REDQ_embpo_qmin_piave_n10_m1-5_uf20_pd20_ant-v2', 79 | 'REDQ-n10-m2-5-ant': 'REDQ_embpo_qmin_piave_n10_m2-5_uf20_pd20_ant-v2', 80 | 'REDQ-n10-m3-ant': 'REDQ_embpo_qmin_piave_n10_m3_uf20_pd20_ant-v2', 81 | 'REDQ-n10-m5-ant': 'REDQ_embpo_qmin_piave_n10_m5_uf20_pd20_ant-v2', 82 | 'REDQ-n10-m1-5-walker2d': 'REDQ_embpo_qmin_piave_n10_m1-5_uf20_pd20_walker2d-v2', 83 | 'REDQ-n10-m2-5-walker2d': 'REDQ_embpo_qmin_piave_n10_m2-5_uf20_pd20_walker2d-v2', 84 | 'REDQ-n10-m3-walker2d': 'REDQ_embpo_qmin_piave_n10_m3_uf20_pd20_walker2d-v2', 85 | 'REDQ-n10-m5-walker2d': 'REDQ_embpo_qmin_piave_n10_m5_uf20_pd20_walker2d-v2', 86 | 'REDQ-n10-m1-5-hopper': 'REDQ_embpo_qmin_piave_n10_m1-5_uf20_pd20_hopper-v2', 87 | 'REDQ-n10-m2-5-hopper': 'REDQ_embpo_qmin_piave_n10_m2-5_uf20_pd20_hopper-v2', 88 | 'REDQ-n10-m3-hopper': 'REDQ_embpo_qmin_piave_n10_m3_uf20_pd20_hopper-v2', 89 | 'REDQ-n10-m5-hopper': 'REDQ_embpo_qmin_piave_n10_m5_uf20_pd20_hopper-v2', 90 | } 91 | 92 | exp2path_ablations_app = { 93 | 'REDQ-ave-n15-ant': 'REDQ_embpo_qave_piave_n15_m2_uf20_pd20_ant-v2', 94 | 'REDQ-ave-n5-ant': 'REDQ_embpo_qave_piave_n5_m2_uf20_pd20_ant-v2', 95 | 'REDQ-ave-n3-ant': 'REDQ_embpo_qave_piave_n3_m2_uf20_pd20_ant-v2', 96 | 'REDQ-ave-n2-ant': 'REDQ_embpo_qave_piave_n2_m2_uf20_pd20_ant-v2', 97 | 'REDQ-ave-n15-walker2d': 'REDQ_embpo_qave_piave_n15_m2_uf20_pd20_walker2d-v2', 98 | 'REDQ-ave-n5-walker2d': 'REDQ_embpo_qave_piave_n5_m2_uf20_pd20_walker2d-v2', 99 | 'REDQ-ave-n3-walker2d': 'REDQ_embpo_qave_piave_n3_m2_uf20_pd20_walker2d-v2', 100 | 'REDQ-ave-n2-walker2d': 'REDQ_embpo_qave_piave_n2_m2_uf20_pd20_walker2d-v2', 101 | } 102 | 103 | exp2path_utd = { 104 | 'REDQ-n10-utd1-ant':'REDQ_embpo_qmin_piave_n10_m2_uf1_pd20_ant-v2', 105 | 'REDQ-n10-utd5-ant':'REDQ_embpo_qmin_piave_n10_m2_uf5_pd20_ant-v2', 106 | 'REDQ-n10-utd10-ant':'REDQ_embpo_qmin_piave_n10_m2_uf10_pd20_ant-v2', 107 | 'REDQ-n10-utd1-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf1_pd20_walker2d-v2', 108 | 'REDQ-n10-utd5-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf5_pd20_walker2d-v2', 109 | 'REDQ-n10-utd10-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf10_pd20_walker2d-v2', 110 | 'SAC-5-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd20_ant-v2', # SAC with policy delay 111 | 'SAC-5-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd20_walker2d-v2', 112 | 'SAC-10-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd20_ant-v2', 113 | 'SAC-10-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd20_walker2d-v2', 114 | 'REDQ-min-n3-utd1-ant':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_ant-v2', 
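    # (the folder names seem to encode hyperparameters: n = ensemble size N, m = in-target minimization
    #  parameter M (msame = M equal to N, i.e. Maxmin), uf = update-to-data (UTD) ratio, pd = policy update delay)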
115 | 'REDQ-min-n3-utd5-ant':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_ant-v2', 116 | 'REDQ-min-n3-utd10-ant':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_ant-v2', 117 | 'REDQ-min-n3-utd20-ant':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_ant-v2', 118 | 'REDQ-min-n3-utd1-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_walker2d-v2', 119 | 'REDQ-min-n3-utd5-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_walker2d-v2', 120 | 'REDQ-min-n3-utd10-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_walker2d-v2', 121 | 'REDQ-min-n3-utd20-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_walker2d-v2', 122 | 'REDQ-min-n3-utd1-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_hopper-v2', 123 | 'REDQ-min-n3-utd5-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_hopper-v2', 124 | 'REDQ-min-n3-utd10-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_hopper-v2', 125 | 'REDQ-min-n3-utd20-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_hopper-v2', 126 | 'REDQ-min-n3-utd1-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_humanoid-v2', 127 | 'REDQ-min-n3-utd5-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_humanoid-v2', 128 | 'REDQ-min-n3-utd10-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_humanoid-v2', 129 | 'REDQ-min-n3-utd20-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_humanoid-v2', 130 | 'REDQ-ave-utd1-ant':'REDQ_embpo_qave_piave_n10_m2_uf1_pd20_ant-v2', 131 | 'REDQ-ave-utd5-ant':'REDQ_embpo_qave_piave_n10_m2_uf5_pd20_ant-v2', 132 | 'REDQ-ave-utd1-walker2d':'REDQ_embpo_qave_piave_n10_m2_uf1_pd20_walker2d-v2', 133 | 'REDQ-ave-utd5-walker2d':'REDQ_embpo_qave_piave_n10_m2_uf5_pd20_walker2d-v2', 134 | } 135 | 136 | exp2path_pd = { 137 | 'SAC-pd1-5-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd1_ant-v2', # no policy delay SAC 138 | 'SAC-pd1-5-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd1_walker2d-v2', 139 | 'SAC-pd1-10-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd1_ant-v2', 140 | 'SAC-pd1-10-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd1_walker2d-v2', 141 | 'SAC-pd1-20-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_ant-v2', 142 | 'SAC-pd1-20-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_walker2d-v2', 143 | 'SAC-pd1-20-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_hopper-v2', 144 | 'SAC-pd1-20-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_humanoid-v2', 145 | 'REDQ-n10-pd1-ant':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_ant-v2', 146 | 'REDQ-n10-pd1-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_walker2d-v2', 147 | 'REDQ-n10-pd1-hopper':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_hopper-v2', 148 | 'REDQ-n10-pd1-humanoid':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_humanoid-v2', 149 | } 150 | 151 | exp2path_ofe = { 152 | 'REDQ-ofe-humanoid': 'REDQ_e20_uf20_pd_ofe20000_p100000_d5_lr0-0003_wd5e-05_w1_humanoid-v2', 153 | 'REDQ-ofe-ant': 'REDQ_e20_uf20_pd_ofe20000_p100000_d5_lr0-0003_wd5e-05_w1_ant-v2', 154 | } 155 | 156 | def generate_exp2path(name_prefix, list_values, list_name_suffix, pre_string, env_names): 157 | exp2path_dict = {} 158 | for e in env_names: 159 | for i in range(len(list_values)): 160 | new_name = '%s-%s-%s' % (name_prefix, list_name_suffix[i], e) 161 | new_string = pre_string % (list_values[i], e) 162 | exp2path_dict[new_name] = new_string 163 | return exp2path_dict 164 | 165 | figsize=(10, 7) 166 | exp2dataset = {} 167 | ############# finalizing plots #################################### 168 | default_smooth = 10 169 | all_4_envs = ['ant', 'humanoid', 'hopper', 'walker2d'] 170 | just_2_envs = ['ant', 'walker2d'] 171 | y_types_all = ['Performance', 'AverageQBias', 'StdQBias', 
'AllNormalizedAverageQBias', 'AllNormalizedStdQBias', 172 | 'LossQ1', 'NormLossQ1', 'AverageQBiasSqr', 'MaxPreTanh', 'AveragePreTanh', 'AverageQ1Vals', 173 | 'MaxAlpha', 'AverageAlpha', 'AverageQPred'] 174 | y_types_appendix_6 = ['AverageQBias', 'StdQBias', 'AverageQBiasSqr', 'AllNormalizedAverageQBiasSqr', 175 | 'LossQ1', 'AverageQPred'] 176 | y_types_main_paper = ['Performance', 'AllNormalizedAverageQBias', 'AllNormalizedStdQBias'] 177 | y_types_appendix_9 = y_types_main_paper + y_types_appendix_6 178 | 179 | # main paper 180 | plot_result_sec = True 181 | plot_analysis_sec = True 182 | plot_ablations = True 183 | plot_variants = True 184 | plot_ablations_revision = True 185 | # appendix ones 186 | plot_analysis_app = True 187 | plot_ablations_app = True 188 | plot_variants_app = True 189 | plot_utd_app = True 190 | plot_pd_app = True 191 | plot_ofe = True 192 | 193 | if plot_ablations_app: 194 | exp2dataset.update(get_exp2dataset(exp2path_ablations_app, base_path)) 195 | if plot_utd_app: 196 | exp2dataset.update(get_exp2dataset(exp2path_utd, base_path)) 197 | if plot_pd_app: 198 | exp2dataset.update(get_exp2dataset(exp2path_pd, base_path)) 199 | if plot_ofe: 200 | exp2dataset.update(get_exp2dataset(exp2path_ofe, base_path)) 201 | 202 | exp2dataset.update(get_exp2dataset(exp2path_main, base_path)) 203 | exp2dataset.update(get_exp2dataset(exp2path_N_ablation, base_path)) 204 | exp2dataset.update(get_exp2dataset(exp2path_M_ablation, base_path)) 205 | exp2dataset.update(get_exp2dataset(exp2path_redq_variants, base_path)) 206 | 207 | if plot_result_sec: 208 | # 1. redq-mbpo figure in results section, score only 209 | exp_base_to_plot = ['REDQ-n10', 'SAC-1', 'MBPO'] 210 | save_path, prefix = 'results', 'redq-mbpo' 211 | label_list = ['REDQ', 'SAC', 'MBPO'] 212 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 213 | all_4_envs, ['Performance'], default_smooth, figsize, label_list=label_list, legend_y_types=['Performance'], 214 | legend_es=['humanoid']) 215 | 216 | if plot_analysis_sec: 217 | # 2. analysis sec, REDQ, SAC-20, ave comparison, argue naive methods don't work 218 | exp_base_to_plot = ['REDQ-n10', 'SAC-20', 'REDQ-ave'] 219 | save_path, prefix = 'analysis', 'redq-sac' 220 | label_list = ['REDQ', 'SAC-20', 'AVG'] 221 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 222 | all_4_envs, y_types_all, default_smooth, figsize, label_list=label_list, 223 | legend_y_types=['AllNormalizedStdQBias'], legend_es='all', legend_loc='upper right') 224 | 225 | if plot_ablations: 226 | # 3. abaltion on N 227 | exp_base_to_plot = ['REDQ-n15', 'REDQ-n10', 'REDQ-n5', 'REDQ-n3', 'REDQ-n2'] 228 | save_path, prefix = 'ablations', 'redq-N' 229 | label_list = ['REDQ-N15', 'REDQ-N10', 'REDQ-N5', 'REDQ-N3', 'REDQ-N2'] 230 | overriding_ylimit_dict = { 231 | ('ant', 'AllNormalizedAverageQBias'): (-1, 1), 232 | ('ant', 'AllNormalizedStdQBias'): (0, 1), 233 | ('walker2d', 'AllNormalizedAverageQBias'): (-0.5, 0.75), 234 | ('walker2d', 'AllNormalizedStdQBias'):(0, 0.75), 235 | } 236 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 237 | just_2_envs, y_types_main_paper, default_smooth, figsize, 238 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es=['ant'], legend_loc='upper right', 239 | overriding_ylimit_dict=overriding_ylimit_dict) 240 | 241 | # 4. 
abaltion on M 242 | exp_base_to_plot = ['REDQ-n10-m1-5', 'REDQ-n10', 'REDQ-n10-m2-5', 'REDQ-n10-m3', 'REDQ-n10-m5'] 243 | save_path, prefix = 'ablations', 'redq-M' 244 | label_list = ['REDQ-M1.5', 'REDQ-M2', 'REDQ-M2.5', 'REDQ-M3', 'REDQ-M5'] 245 | overriding_ylimit_dict = { 246 | ('ant', 'AllNormalizedAverageQBias'): (-4, 4), 247 | ('walker2d', 'AllNormalizedAverageQBias'): (-1, 2.5), 248 | } 249 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 250 | just_2_envs, y_types_main_paper, default_smooth, figsize, 251 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es=['ant'], legend_loc='upper right', 252 | overriding_ylimit_dict=overriding_ylimit_dict) 253 | 254 | if plot_ablations_revision: 255 | # 5. different variant of Q target computation comparison (here maybe we should do... ) 256 | exp_base_to_plot = ['REDQ-n10', 'REDQ-weighted'] 257 | save_path, prefix = 'revision', 'redq-weighted' 258 | label_list = ['REDQ', 'Weighted'] 259 | overriding_ylimit_dict = { 260 | ('ant', 'AllNormalizedAverageQBias'): (-1, 1), 261 | # ('hopper', 'AllNormalizedAverageQBias'): (-0.5, 1.5), 262 | # ('humanoid', 'AllNormalizedAverageQBias'): (-0.5, 0), 263 | # ('walker2d', 'AllNormalizedAverageQBias'): (-0.5, 0.5), 264 | } 265 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 266 | all_4_envs, y_types_main_paper, default_smooth, figsize, 267 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es='all', legend_loc='upper right', 268 | overriding_ylimit_dict=overriding_ylimit_dict) 269 | 270 | if plot_variants: 271 | # 5. different variant of Q target computation comparison (here maybe we should do... ) 272 | exp_base_to_plot = ['REDQ-n10', 'REDQ-rem', 'REDQ-min', 'REDQ-weighted', 'REDQ-minpair'] 273 | save_path, prefix = 'variants', 'redq-var' 274 | label_list = ['REDQ', 'REM', 'Maxmin', 'Weighted', 'MinPair'] 275 | overriding_ylimit_dict = { 276 | ('ant', 'AllNormalizedAverageQBias'): (-2, 4), 277 | } 278 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 279 | just_2_envs, y_types_main_paper, default_smooth, figsize, 280 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es=['ant'], legend_loc='upper right', 281 | overriding_ylimit_dict=overriding_ylimit_dict) 282 | 283 | def plot_grid(save_path, save_name_prefix, exp2dataset, exp_base_to_plot, envs, y_types, smooth, figsize, label_list=None, 284 | legend_y_types='all', legend_es='all', overriding_ylimit_dict=None, legend_loc='best', longxaxis=False, 285 | linestyle_list=None, color_list=None, override_xlimit=None): 286 | for y_i, y_type in enumerate(y_types): 287 | y_save_name = y2savename_dict[y_type] 288 | for e in envs: 289 | no_legend = decide_no_legend(legend_y_types, legend_es, y_type, e) 290 | exp_to_plot = [] 291 | for exp_base in exp_base_to_plot: 292 | exp_to_plot.append(exp_base + '-' + e) 293 | if override_xlimit is None: 294 | xlimit = 125000 if e == 'hopper' else int(3e5) 295 | else: 296 | xlimit = override_xlimit 297 | if xlimit is None and longxaxis: 298 | xlimit = int(2e6) 299 | ylimit = get_ylimit_from_env_ytype(e, y_type, overriding_ylimit_dict) 300 | assert len(label_list) == len(exp_base_to_plot) 301 | if save_path: 302 | save_name = save_name_prefix + '-%s-%s' %(e, y_save_name) 303 | else: 304 | save_name = None 305 | if color_list is None: 306 | color_list = get_default_colors(exp_base_to_plot) 307 | plot_from_data(None, exp2dataset, exp_to_plot, figsize, value_list=[y_type,], 308 | ylabel='auto', save_path=save_path, 
label_list=label_list, 309 | save_name=save_name, smooth=smooth, xlimit=xlimit, y_limit = ylimit, 310 | color_list=color_list, no_legend=no_legend, legend_loc=legend_loc, 311 | linestyle_list=linestyle_list) 312 | 313 | if plot_ofe: 314 | exp_base_to_plot = ['REDQ-n10', 'SAC-1', 'MBPO', 'REDQ-ofe'] 315 | save_path, prefix = 'extra', 'redq-ofe' 316 | label_list = ['REDQ', 'SAC', 'MBPO', 'REDQ-OFE'] 317 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 318 | ['humanoid', 'ant'], ['Performance'], default_smooth, figsize, 319 | label_list=label_list, legend_y_types='all', legend_es=['humanoid'], legend_loc='lower right') 320 | 321 | """#################################### REST IS APPENDIX #################################### """ 322 | if plot_analysis_app: 323 | # appendix: extra figures for 2. analysis 324 | exp_base_to_plot = ['REDQ-n10', 'SAC-20', 'REDQ-ave'] 325 | save_path, prefix = 'analysis_app', 'redq-sac' 326 | label_list = ['REDQ', 'SAC-20', 'AVG'] 327 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 328 | just_2_envs, y_types_all, default_smooth, figsize, label_list=label_list, 329 | legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left') 330 | 331 | if plot_ablations_app: 332 | # appendix: extra figures for N, M comparisons 333 | # app abaltion on N 334 | exp_base_to_plot = ['REDQ-n15', 'REDQ-n10', 'REDQ-n5', 'REDQ-n3', 'REDQ-n2'] 335 | save_path, prefix = 'ablations_app', 'redq-N' 336 | label_list = ['REDQ-N15', 'REDQ-N10', 'REDQ-N5', 'REDQ-N3', 'REDQ-N2'] 337 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 338 | just_2_envs, y_types_appendix_6, default_smooth, figsize, 339 | label_list=label_list, legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left',) 340 | 341 | # app abaltion on M 342 | exp_base_to_plot = ['REDQ-n10-m1-5', 'REDQ-n10', 'REDQ-n10-m2-5', 'REDQ-n10-m3', 'REDQ-n10-m5'] 343 | save_path, prefix = 'ablations_app', 'redq-M' 344 | label_list = ['REDQ-M1.5', 'REDQ-M2', 'REDQ-M2.5', 'REDQ-M3', 'REDQ-M5'] 345 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 346 | just_2_envs, y_types_appendix_6, default_smooth, figsize, 347 | label_list=label_list, legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left',) 348 | 349 | if plot_variants_app: 350 | # appendix: extra figures for variants 351 | exp_base_to_plot = ['REDQ-n10', 'REDQ-rem', 'REDQ-min', 'REDQ-weighted', 'REDQ-minpair'] 352 | save_path, prefix = 'variants_app', 'redq-var' 353 | label_list = ['REDQ', 'REM', 'MIN', 'Weighted', 'MinPair'] 354 | overriding_ylimit_dict = { 355 | ('ant', 'AllNormalizedAverageQBias'): (-2, 4), 356 | } 357 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 358 | just_2_envs, y_types_appendix_6, default_smooth, figsize, 359 | label_list=label_list, legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left', 360 | overriding_ylimit_dict=overriding_ylimit_dict) 361 | 362 | if plot_utd_app: 363 | # REDQ different utd 364 | exp_base_to_plot = [ 'REDQ-n10', 'REDQ-n10-utd10', 'REDQ-n10-utd5', 'REDQ-n10-utd1', ] 365 | save_path, prefix = 'utd_app', 'utd-redq' 366 | label_list = [ 'REDQ-UTD20','REDQ-UTD10','REDQ-UTD5' , 'REDQ-UTD1',] 367 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 368 | just_2_envs, y_types_all, default_smooth, figsize, 369 | label_list=label_list, legend_y_types=['Performance'], legend_es='all', legend_loc='upper left',) 370 | # SAC different utd 371 | exp_base_to_plot = ['SAC-20', 'SAC-10', 'SAC-5', 'SAC-1'] 372 | save_path, prefix 
= 'utd_app', 'utd-sac' 373 | label_list = ['SAC-UTD20', 'SAC-UTD10', 'SAC-UTD5', 'SAC-UTD1',] 374 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 375 | just_2_envs, y_types_all, default_smooth, figsize, 376 | label_list=label_list, legend_y_types=['Performance'], legend_es='all', legend_loc='upper left',) 377 | 378 | if plot_pd_app: 379 | exp_base_to_plot = ['REDQ-n10', 'REDQ-n10-pd1', 'SAC-20', 'SAC-pd1-20'] 380 | save_path, prefix = 'pd_app', 'pd-redq' 381 | label_list = ['REDQ','REDQ-NPD','SAC-20', 'SAC-20-NPD',] 382 | linestyle_list = ['solid', 'dashed', 'solid', 'dashed'] 383 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot, 384 | all_4_envs, y_types_all, default_smooth, figsize, 385 | label_list=label_list, legend_y_types=['Performance'], legend_es='all', legend_loc='upper left', 386 | linestyle_list=linestyle_list) 387 | -------------------------------------------------------------------------------- /plot_utils/redq_plot_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | NOTE: currently only works with seaborn 0.8.1 the tsplot function is deprecated in the newer version 3 | the plotting function is originally based on the plot function in OpenAI spinningup 4 | """ 5 | import os 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import seaborn as sns 10 | import json 11 | from packaging import version 12 | 13 | DIV_LINE_WIDTH = 50 14 | 15 | # Global vars for tracking and labeling data at load time. 16 | exp_idx = 0 17 | units = dict() 18 | 19 | def get_datasets(logdir, condition=None): 20 | """ 21 | Recursively look through logdir for output files produced by 22 | spinup.logx.Logger. 23 | 24 | Assumes that any file "progress.txt" is a valid hit. 
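    (each progress.txt is written by the training logger, typically under data_dir/exp_name/exp_name_s<seed>/)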
25 | the "condition" here can be a string, when plotting, can be used as label on the legend 26 | """ 27 | global exp_idx 28 | global units 29 | datasets = [] 30 | for root, _, files in os.walk(logdir): 31 | if 'progress.txt' in files: 32 | try: 33 | exp_data = pd.read_table(os.path.join(root, 'progress.txt')) 34 | exp_name = None 35 | try: 36 | config_path = open(os.path.join(root, 'config.json')) 37 | config = json.load(config_path) 38 | if 'exp_name' in config: 39 | exp_name = config['exp_name'] 40 | except: 41 | print('No file named config.json') 42 | condition1 = condition or exp_name or 'exp' 43 | condition2 = condition1 + '-' + str(exp_idx) 44 | exp_idx += 1 45 | if condition1 not in units: 46 | units[condition1] = 0 47 | unit = units[condition1] 48 | units[condition1] += 1 49 | 50 | print(os.path.join(root, 'progress.txt')) 51 | 52 | # exp_data = pd.read_table(os.path.join(root, 'progress.txt')) 53 | performance = 'AverageTestEpRet' if 'AverageTestEpRet' in exp_data else 'AverageEpRet' 54 | exp_data.insert(len(exp_data.columns), 'Unit', unit) 55 | exp_data.insert(len(exp_data.columns), 'Condition1', condition1) 56 | exp_data.insert(len(exp_data.columns), 'Condition2', condition2) 57 | if performance in exp_data: 58 | exp_data.insert(len(exp_data.columns), 'Performance', exp_data[performance]) 59 | datasets.append(exp_data) 60 | except Exception as e: 61 | print(e) 62 | 63 | return datasets 64 | 65 | 66 | y2savename_dict = { 67 | 'Performance':'score', 68 | 'AverageNormQBias':'bias-ave-n', 69 | 'StdNormQBias':'bias-std-n', 70 | 'LossQ1':'qloss', 71 | 'NormLossQ1':'qloss-n', 72 | 'MaxPreTanh':'pretanh-max', 73 | 'AveragePreTanh':'pretanh-ave', 74 | 'AverageQBias':'bias-ave', 75 | 'StdQBias':'bias-std', 76 | 'AverageQPred':'qpred', 77 | 'AverageQBiasSqr':'biassqr-ave', 78 | 'AverageNormQBiasSqr':'biassqr-ave-n', 79 | 'AverageAlpha':'alpha-ave', 80 | 'MaxAlpha':'alpha', 81 | 'AverageQ1Vals':'q1', 82 | 'AverageLogPi':'average-logpi', 83 | 'MaxLogPi': 'max-logpi', 84 | 'AllNormalizedAverageQBias':'bias-ave-alln', 85 | 'AllNormalizedStdQBias':'bias-std-alln', 86 | 'AllNormalizedAverageQBiasSqr':'biassqr-ave-alln', 87 | 'Time':'time' 88 | } 89 | 90 | y2ylabel_dict = { 91 | 'Performance':'Average return', 92 | 'AverageNormQBias':'Average normalized bias', 93 | 'StdNormQBias':'Std of normalized bias', 94 | 'LossQ1':'Q loss', 95 | 'NormLossQ1':'Normalized Q loss', 96 | 'MaxPreTanh':'Max pretanh', 97 | 'AveragePreTanh':'Average pretanh', 98 | 'AverageQBias':'Average bias', 99 | 'StdQBias':'Std of bias', 100 | 'AverageQPred':'Average Q value', 101 | 'AverageQBiasSqr':'Average MSE', 102 | 'AverageNormQBiasSqr':'Average normalized MSE', 103 | 'AverageAlpha':'Average alpha', 104 | 'MaxAlpha':'Max alpha', 105 | 'AverageQ1Vals':'Q value', 106 | 'AverageLogPi': 'Average logPi', 107 | 'MaxLogPi': 'Max logPi', 108 | 'AllNormalizedAverageQBias': 'Average normalized bias', 109 | 'AllNormalizedStdQBias': 'Std of normalized bias', 110 | 'AllNormalizedAverageQBiasSqr': 'Average normalized MSE', 111 | 'Time':'Time' 112 | } 113 | 114 | # we can use strict mapping from exp base name to color 115 | expbase2color = { 116 | 'SAC-20': 'grey', 117 | 'SAC-10': 'slateblue', 118 | 'SAC-5': 'blue', 119 | 'SAC-1': 'skyblue', # blue-black for SAC, MBPO 120 | 'SAC-hs1':'black', 121 | 'SAC-hs2':'brown', 122 | 'SAC-hs3':'purple', 123 | 'SAC-pd1-5': 'black', 124 | 'SAC-pd1-10': 'brown', 125 | 'SAC-pd1-20': 'grey', 126 | 'MBPO':'tab:blue', 127 | 'REDQ-n15':'tab:orange', # dark red to light purple 128 | 'REDQ-n10':'tab:red', 
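    # (colors are keyed by the experiment base name, i.e. the name without the trailing env suffix;
    #  see get_default_colors below)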
129 | 'REDQ-n5': 'tab:cyan', 130 | 'REDQ-n3': 'tab:grey', 131 | 'REDQ-n2': 'black', 132 | 'REDQ-n10-m1-5':'tab:orange', # red-brown for M variants? 133 | 'REDQ-n10-m2-5':'tab:cyan', 134 | 'REDQ-n10-m3':'tab:grey', 135 | 'REDQ-n10-m5': 'black', 136 | 'REDQ-n10-hs1': 'tab:orange', 137 | 'REDQ-n10-hs2': 'violet', 138 | 'REDQ-n10-hs3': 'lightblue', 139 | 'REDQ-weighted': 'indigo', # then for redq q target variants, some random stuff 140 | 'REDQ-minpair': 'royalblue', 141 | 'REDQ-ave': 'peru', 142 | 'REDQ-rem': 'yellow', 143 | 'REDQ-min': 'slategrey', 144 | 'REDQ-n10-utd10': 'tab:cyan', 145 | 'REDQ-n10-utd5': 'tab:grey', 146 | 'REDQ-n10-utd1': 'black', 147 | 'REDQ-min-n3-utd1':'blue', 148 | 'REDQ-min-n3-utd5':'lightblue', 149 | 'REDQ-min-n3-utd10':'grey', 150 | 'REDQ-min-n3-utd20':'black', 151 | 'REDQ-ave-utd1':'blue', 152 | 'REDQ-ave-utd5':'grey', 153 | 'REDQ-ofe':'purple', 154 | 'REDQ-ofe-long':'purple', 155 | 'REDQ-dense': 'deeppink', 156 | 'SAC-dense': 'sandybrown', 157 | 'REDQ-ave-n15':'black', 158 | 'REDQ-ave-n5':'tab:orange', 159 | 'REDQ-ave-n3':'violet', 160 | 'REDQ-ave-n2':'lightblue', 161 | 'SAC-long':'skyblue', 162 | 'REDQ-n10-pd1':'tab:red', 163 | 'REDQ-fine':'tab:pink', 164 | 'REDQ-ss10k':'tab:pink', 165 | 'REDQ-ss15k': 'tab:orange', 166 | 'REDQ-ss20k': 'yellow', 167 | 'REDQ-weighted-ss10k':'lightblue', 168 | 'REDQ-weighted-ss15k': 'royalblue', 169 | 'REDQ-weighted-ss20k': 'black', 170 | 'REDQ-m1':'grey', 171 | 'REDQ-more':'tab:red', 172 | 'REDQ-weighted-more':'indigo', 173 | 'REDQ-anneal25k': 'tab:pink', 174 | 'REDQ-anneal50k': 'tab:orange', 175 | 'REDQ-anneal100k': 'yellow', 176 | 'REDQ-weighted-anneal25k': 'lightblue', 177 | 'REDQ-weighted-anneal50k': 'royalblue', 178 | 'REDQ-weighted-anneal100k': 'black', 179 | 'REDQ-init0-4': 'tab:pink', 180 | 'REDQ-init0-3': 'tab:orange', 181 | 'REDQ-init0-2': 'yellow', 182 | 'REDQ-weighted-init0-4': 'lightblue', 183 | 'REDQ-weighted-init0-3': 'royalblue', 184 | 'REDQ-weighted-init0-2': 'black', 185 | 'REDQ-weighted-a0-1':'lightblue', 186 | 'REDQ-weighted-a0-15': 'royalblue', 187 | 'REDQ-weighted-a0-2':'black', 188 | 'REDQ-weighted-a0-25':'pink', 189 | 'REDQ-weighted-a0-3':'yellow', 190 | 'REDQ-a0-1': 'lightblue', 191 | 'REDQ-a0-15': 'royalblue', 192 | 'REDQ-a0-2': 'black', 193 | 'REDQ-a0-25': 'pink', 194 | 'REDQ-a0-3': 'yellow', 195 | 'REDQ-weighted-wns0-1': 'lightblue', 196 | 'REDQ-weighted-wns0-3': 'royalblue', 197 | 'REDQ-weighted-wns0-5': 'black', 198 | 'REDQ-weighted-wns0-8': 'pink', 199 | 'REDQ-weighted-wns1-2': 'yellow', 200 | } 201 | 202 | def get_ylimit_from_env_ytype(e, ytype, overriding_dict=None): 203 | # will map env and ytype to ylimit 204 | # so that the plot looks better 205 | # the plot should give a reasonable y range so that the outlier won't dominate 206 | # can use a overriding dictionary to override the default values here 207 | default_dict = { 208 | ('ant', 'AverageQBias'): (-200, 300), 209 | ('ant', 'AverageNormQBias'): (-1, 4), 210 | ('ant', 'AllNormalizedAverageQBias'): (-1, 4), 211 | ('ant', 'StdQBias'): (0, 150), 212 | ('ant', 'StdNormQBias'): (0, 5), 213 | ('ant', 'AllNormalizedStdQBias'): (0, 2), 214 | ('ant', 'LossQ1'): (0, 60), 215 | ('ant', 'NormLossQ1'): (0, 0.4), 216 | ('ant', 'AverageQBiasSqr'): (0, 50000), 217 | ('ant', 'AllNormalizedAverageQBiasSqr'): (0, 400), 218 | ('ant', 'AverageNormQBiasSqr'): (0, 1200), 219 | ('ant', 'AverageQPred'): (-200, 1000), 220 | ('walker2d', 'AverageQBias'): (-150, 250), 221 | ('walker2d', 'AverageNormQBias'): (-0.5, 4), 222 | ('walker2d', 'AllNormalizedAverageQBias'): 
(-0.5, 2.5), 223 | ('walker2d', 'StdQBias'): (0, 150), 224 | ('walker2d', 'StdNormQBias'): (0, 4), 225 | ('walker2d', 'AllNormalizedStdQBias'):(0, 1.5), 226 | ('walker2d', 'LossQ1'): (0, 60), 227 | ('walker2d', 'NormLossQ1'): (0, 0.5), 228 | ('walker2d', 'AverageQBiasSqr'): (0, 50000), 229 | ('walker2d', 'AverageNormQBiasSqr'): (0, 1200), 230 | ('walker2d', 'AllNormalizedAverageQBiasSqr'): (0, 1200), 231 | ('walker2d', 'AverageQPred'): (-100, 600), 232 | ('humanoid', 'AllNormalizedAverageQBias'): (-1, 2.5), 233 | ('humanoid', 'AllNormalizedStdQBias'): (0, 1.5), 234 | } 235 | key = (e, ytype) 236 | if overriding_dict and key in overriding_dict: 237 | return overriding_dict[key] 238 | else: 239 | if key in default_dict: 240 | return default_dict[key] 241 | else: 242 | return None 243 | 244 | def select_data_list_for_plot(exp2dataset, exp_to_plot): 245 | data_list = [] 246 | for exp in exp_to_plot: 247 | # use append here, since we expect exp2dataset[exp] to be a single pandas df 248 | data_list.append(exp2dataset[exp]) 249 | return data_list # now a list of list 250 | 251 | def plot_from_data(data_list, exp2dataset=None, exp_to_plot=None, figsize=None, xaxis='TotalEnvInteracts', 252 | value_list=['Performance',], color_list=None, linestyle_list=None, label_list=None, 253 | count=False, 254 | font_scale=1.5, smooth=1, estimator='mean', no_legend=False, 255 | legend_loc='best', title=None, save_name=None, save_path = None, 256 | xlimit=-1, y_limit=None, label_font_size=24, xlabel=None, ylabel=None, 257 | y_log_scale=False): 258 | """ 259 | either give a data list 260 | or give a exp2dataset dictionary, and then specify the exp_to_plot 261 | will plot each experiment setting in data_list to the same figure. 262 | the label will be basically 'Condition1' column by default 263 | causal_plot will not care about order and will be messy and each time the order can be different. 264 | for value list if it contains only one value then we will use that for all experiments 265 | """ 266 | if data_list is not None: 267 | data_list = data_list 268 | else: 269 | data_list = select_data_list_for_plot(exp2dataset, exp_to_plot) 270 | 271 | n_curves = len(data_list) 272 | if len(value_list) == 1: 273 | value_list = [value_list[0] for _ in range(n_curves)] 274 | default_colors = ['tab:blue','tab:orange','tab:green','tab:red', 275 | 'tab:purple','tab:brown','tab:pink','tab:grey','tab:olive','tab:cyan',] 276 | color_list = default_colors if color_list is None else color_list 277 | condition = 'Condition2' if count else 'Condition1' 278 | estimator = getattr(np, estimator) # choose what to show on main curve: mean? max? min? 279 | plt.figure(figsize=figsize) if figsize else plt.figure() 280 | ########################## 281 | value_list_smooth_temp = [] 282 | """ 283 | smooth data with moving window average. 
284 | that is, 285 | smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k]) 286 | where the "smooth" param is width of that window (2k+1) 287 | IF WE MODIFY DATA DIRECTLY, THEN CAN LEAD TO PLOTTING BUG WHERE 288 | IT'S MODIFIED MULTIPLE TIMES 289 | """ 290 | y = np.ones(smooth) 291 | for i, data_seeds in enumerate(data_list): 292 | temp_value_name = value_list[i] + '__smooth_temp' 293 | value_list_smooth_temp.append(temp_value_name) 294 | for data in data_seeds: 295 | x = np.asarray(data[value_list[i]].copy()) 296 | z = np.ones(len(x)) 297 | smoothed_x = np.convolve(x, y, 'same') / np.convolve(z, y, 'same') 298 | # data[value_list[i]] = smoothed_x # this can be problematic 299 | if temp_value_name not in data: 300 | data.insert(len(data.columns), temp_value_name, smoothed_x) 301 | else: 302 | data[temp_value_name] = smoothed_x 303 | 304 | sns.set(style="darkgrid", font_scale=font_scale) 305 | # sns.set_palette('bright') 306 | # have the same axis (figure), plot one by one onto it 307 | ax = None 308 | if version.parse(sns.__version__) <= version.parse('0.8.1'): 309 | for i, data_seeds in enumerate(data_list): 310 | data_combined = pd.concat(data_seeds, ignore_index=True) 311 | ax = sns.tsplot(data=data_combined, time=xaxis, value=value_list_smooth_temp[i], unit="Unit", 312 | condition=condition, 313 | legend=(not no_legend), ci='sd', 314 | n_boot=0, color=color_list[i], ax=ax) 315 | else: 316 | print("Error: Seaborn version > 0.8.1 is currently not supported.") 317 | quit() 318 | 319 | if linestyle_list is not None: 320 | for i in range(len(linestyle_list)): 321 | ax.lines[i].set_linestyle(linestyle_list[i]) 322 | if label_list is not None: 323 | for i in range(len(label_list)): 324 | ax.lines[i].set_label(label_list[i]) 325 | xlabel = 'environment interactions' if xlabel is None else xlabel 326 | 327 | if ylabel is None: 328 | ylabel = 'average test return' 329 | elif ylabel == 'auto': 330 | if value_list[0] in y2ylabel_dict: 331 | ylabel = y2ylabel_dict[value_list[0]] 332 | else: 333 | ylabel = value_list[0] 334 | else: 335 | ylabel = ylabel 336 | plt.xlabel(xlabel, fontsize=label_font_size) 337 | plt.ylabel(ylabel, fontsize=label_font_size) 338 | if not no_legend: 339 | plt.legend(loc=legend_loc, fontsize=label_font_size) 340 | 341 | """ 342 | For the version of the legend used in the Spinning Up benchmarking page, 343 | plt.legend(loc='upper center', ncol=6, handlelength=1, 344 | mode="expand", borderaxespad=0., prop={'size': 13}) 345 | """ 346 | xscale = np.max(np.asarray(data[xaxis])) > 5e3 347 | if xscale: 348 | # Just some formatting niceness: x-axis scale in scientific notation if max x is large 349 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 350 | if y_log_scale: 351 | plt.yscale('log') 352 | if xlimit > 0: 353 | plt.xlim(0, xlimit) 354 | if y_limit: 355 | plt.ylim(y_limit[0], y_limit[1]) 356 | if title: 357 | plt.title(title) 358 | plt.tight_layout() 359 | if save_name is not None: 360 | fig = plt.gcf() 361 | if not os.path.isdir(save_path): 362 | os.mkdir(save_path) 363 | fig.savefig(os.path.join(save_path, save_name)) 364 | plt.close(fig) 365 | else: 366 | plt.show() 367 | 368 | 369 | def decide_no_legend(legend_y_types, legend_es, y_type, e): 370 | # return no_legend as a bool (=True means no legend will be plotted) 371 | if (legend_y_types == 'all' or y_type in legend_y_types) and (legend_es == 'all' or e in legend_es): 372 | return False 373 | else: 374 | return True 375 | 376 | def plot_grid(save_path, save_name_prefix, exp2dataset, 
exp_base_to_plot, envs, y_types, smooth, figsize, label_list=None, 377 | legend_y_types='all', legend_es='all', overriding_ylimit_dict=None, legend_loc='best', longxaxis=False, 378 | linestyle_list=None, color_list=None): 379 | """ 380 | this function will help us do generic redq plots very easily and fast. 381 | Args: 382 | save_path: where to save the figures (i.e. 'figures') 383 | envs: specify the envs that will be plotted (i.e. 'ant') 384 | exp_base_to_plot: base name of the exp settings (i.e. 'REDQ-n10') 385 | y_types: y type to be plotting (for example, 'Performance') 386 | smooth: 387 | """ 388 | for y_i, y_type in enumerate(y_types): 389 | y_save_name = y2savename_dict[y_type] 390 | for e in envs: 391 | no_legend = decide_no_legend(legend_y_types, legend_es, y_type, e) 392 | exp_to_plot = [] 393 | for exp_base in exp_base_to_plot: 394 | exp_to_plot.append(exp_base + '-' + e) 395 | xlimit = 125000 if e == 'hopper' else int(3e5) 396 | if longxaxis: 397 | xlimit = int(2e6) 398 | ylimit = get_ylimit_from_env_ytype(e, y_type, overriding_ylimit_dict) 399 | assert len(label_list) == len(exp_base_to_plot) 400 | if save_path: 401 | save_name = save_name_prefix + '-%s-%s' %(e, y_save_name) 402 | else: 403 | save_name = None 404 | if color_list is None: 405 | color_list = get_default_colors(exp_base_to_plot) 406 | plot_from_data(None, exp2dataset, exp_to_plot, figsize, value_list=[y_type,], 407 | ylabel='auto', save_path=save_path, label_list=label_list, 408 | save_name=save_name, smooth=smooth, xlimit=xlimit, y_limit = ylimit, 409 | color_list=color_list, no_legend=no_legend, legend_loc=legend_loc, 410 | linestyle_list=linestyle_list) 411 | 412 | def plot_grid_solid_dashed(save_path, save_name_prefix, exp2dataset, exp_base_to_plot, envs, y_1, y_2, smooth, figsize, label_list=None): 413 | """ 414 | this function will help us do generic redq plots very easily and fast. 415 | Args: 416 | save_path: where to save the figures (i.e. 'figures') 417 | envs: specify the envs that will be plotted (i.e. 'ant') 418 | exp_base_to_plot: base name of the exp settings (i.e. 
'REDQ-n10') 419 | y_types: y type to be plotting (for example, 'Performance') 420 | smooth: 421 | """ 422 | y_save_name = y2savename_dict[y_1] # typically this is Q value, then y2 is bias 423 | for e in envs: 424 | exp_to_plot = [] 425 | for exp_base in exp_base_to_plot: 426 | exp_to_plot.append(exp_base + '-' + e) 427 | value_list = [] 428 | linestyle_list = [] 429 | for i in range(len(exp_to_plot)): 430 | value_list.append(y_1) 431 | linestyle_list.append('solid') 432 | for i in range(len(exp_to_plot)): 433 | value_list.append(y_2) 434 | linestyle_list.append('dashed') 435 | exp_to_plot = exp_to_plot + exp_to_plot 436 | exp_base_to_plot_for_color = exp_base_to_plot + exp_base_to_plot 437 | xlimit = 125000 if e == 'hopper' else int(3e5) 438 | ylimit = get_ylimit_from_env_ytype(e, y_1) 439 | plot_from_data(None, exp2dataset, exp_to_plot, figsize, value_list=value_list, linestyle_list=linestyle_list, 440 | ylabel='auto', save_path=save_path, label_list=label_list, 441 | save_name=save_name_prefix + '-%s-%s' %(e, y_save_name), smooth=smooth, xlimit=xlimit, y_limit = ylimit, 442 | color_list=get_default_colors(exp_base_to_plot_for_color)) 443 | 444 | def get_default_colors(exp_base_names): 445 | # will map experiment base name (the good names) to color 446 | # later ones 447 | colors = [] 448 | for i in range(len(exp_base_names)): 449 | exp_base_name = exp_base_names[i] 450 | colors.append(expbase2color[exp_base_name]) 451 | return colors 452 | 453 | def get_exp2dataset(exp2path, base_path): 454 | """ 455 | Args: 456 | exp2path: a dictionary containing experiment name (you can decide what this is, can make it easy to read, 457 | will be set as "condition") to the actual experiment folder name 458 | base_path: the base path lead to where the experiment folders are 459 | Returns: 460 | a dictionary with keys being the experiment name, value being a pandas dataframe containing 461 | the experiment progress (multiple seeds in one dataframe) 462 | """ 463 | exp2dataset = {} 464 | for key in exp2path.keys(): 465 | complete_path = os.path.join(base_path, exp2path[key]) 466 | data_list = get_datasets(complete_path, key) # a list of different seeds 467 | # combined_df_across_seeds = pd.concat(data_list, ignore_index=True) 468 | # exp2dataset[key] = combined_df_across_seeds 469 | for d in data_list: 470 | try: 471 | add_normalized_values(d) 472 | except: 473 | print("can't add normalized loss in data:", key) 474 | exp2dataset[key] = data_list 475 | return exp2dataset 476 | 477 | def add_normalized_values(d): 478 | normalize_base = d['AverageMCDisRetEnt'].copy().abs() 479 | normalize_base[normalize_base < 10] = 10 480 | # use this to normalize the Q loss, Q bias 481 | 482 | normalized_qloss = d['LossQ1'] / normalize_base 483 | d.insert(len(d.columns), 'NormLossQ1', normalized_qloss) 484 | normalized_q_bias = d['AverageQBias'] / normalize_base 485 | d.insert(len(d.columns), 'AllNormalizedAverageQBias', normalized_q_bias) 486 | normalized_std_q_bias = d['StdQBias'] / normalize_base 487 | d.insert(len(d.columns), 'AllNormalizedStdQBias', normalized_std_q_bias) 488 | normalized_bias_mse = d['AverageQBiasSqr'] / normalize_base 489 | d.insert(len(d.columns), 'AllNormalizedAverageQBiasSqr', normalized_bias_mse) 490 | -------------------------------------------------------------------------------- /redq/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.1' -------------------------------------------------------------------------------- 
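The plotting helpers in plot_utils/redq_plot_helper.py above can also be driven directly, without going through plot_REDQ.py. Below is a minimal sketch (not part of the repository files): it assumes the REDQ ICLR21 experiment data has been downloaded to ../data/REDQ_ICLR21 and that seaborn 0.8.1 is installed, as noted in the module docstring; the two folder names are copied from exp2path_main in plot_REDQ.py.

# minimal plotting sketch (assumes downloaded REDQ ICLR21 data and seaborn 0.8.1)
from redq_plot_helper import get_exp2dataset, plot_from_data

base_path = '../data/REDQ_ICLR21'
exp2path = {
    'REDQ-n10-hopper': 'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_hopper-v2',
    'SAC-1-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_hopper-v2',
}
# experiment name -> list of per-seed progress dataframes
exp2dataset = get_exp2dataset(exp2path, base_path)
# smoothed average test return for both runs, saved to figures/hopper-score
plot_from_data(None, exp2dataset, exp_to_plot=['REDQ-n10-hopper', 'SAC-1-hopper'],
               figsize=(10, 7), value_list=['Performance'],
               label_list=['REDQ', 'SAC'], color_list=['tab:red', 'skyblue'],
               smooth=10, xlimit=125000,
               save_path='figures', save_name='hopper-score')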
/redq/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/watchernyu/REDQ/7b5d1bff39291a57325a2836bd397a55728960bb/redq/algos/__init__.py -------------------------------------------------------------------------------- /redq/algos/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from torch.distributions import Distribution, Normal 6 | # following SAC authors' and OpenAI implementation 7 | LOG_SIG_MAX = 2 8 | LOG_SIG_MIN = -20 9 | ACTION_BOUND_EPSILON = 1E-6 10 | # these numbers are from the MBPO paper 11 | mbpo_target_entropy_dict = {'Hopper-v2':-1, 'HalfCheetah-v2':-3, 'Walker2d-v2':-3, 'Ant-v2':-4, 'Humanoid-v2':-2, 12 | 'Hopper-v3':-1, 'HalfCheetah-v3':-3, 'Walker2d-v3':-3, 'Ant-v3':-4, 'Humanoid-v3':-2, 13 | 'Hopper-v4':-1, 'HalfCheetah-v4':-3, 'Walker2d-v4':-3, 'Ant-v4':-4, 'Humanoid-v4':-2} 14 | mbpo_epoches = {'Hopper-v2':125, 'Walker2d-v2':300, 'Ant-v2':300, 'HalfCheetah-v2':400, 'Humanoid-v2':300, 15 | 'Hopper-v3':125, 'Walker2d-v3':300, 'Ant-v3':300, 'HalfCheetah-v3':400, 'Humanoid-v3':300, 16 | 'Hopper-v4':125, 'Walker2d-v4':300, 'Ant-v4':300, 'HalfCheetah-v4':400, 'Humanoid-v4':300} 17 | 18 | def weights_init_(m): 19 | # weight init helper function 20 | if isinstance(m, nn.Linear): 21 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 22 | torch.nn.init.constant_(m.bias, 0) 23 | 24 | class ReplayBuffer: 25 | """ 26 | A simple FIFO experience replay buffer 27 | """ 28 | def __init__(self, obs_dim, act_dim, size): 29 | """ 30 | :param obs_dim: size of observation 31 | :param act_dim: size of the action 32 | :param size: size of the buffer 33 | """ 34 | ## init buffers as numpy arrays 35 | self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32) 36 | self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32) 37 | self.acts_buf = np.zeros([size, act_dim], dtype=np.float32) 38 | self.rews_buf = np.zeros(size, dtype=np.float32) 39 | self.done_buf = np.zeros(size, dtype=np.float32) 40 | self.ptr, self.size, self.max_size = 0, 0, size 41 | 42 | def store(self, obs, act, rew, next_obs, done): 43 | """ 44 | data will get stored in the pointer's location 45 | """ 46 | self.obs1_buf[self.ptr] = obs 47 | self.obs2_buf[self.ptr] = next_obs 48 | self.acts_buf[self.ptr] = act 49 | self.rews_buf[self.ptr] = rew 50 | self.done_buf[self.ptr] = done 51 | ## move the pointer to store in next location in buffer 52 | self.ptr = (self.ptr+1) % self.max_size 53 | ## keep track of the current buffer size 54 | self.size = min(self.size+1, self.max_size) 55 | 56 | def sample_batch(self, batch_size=32, idxs=None): 57 | """ 58 | :param batch_size: size of minibatch 59 | :param idxs: specify indexes if you want specific data points 60 | :return: mini-batch data as a dictionary 61 | """ 62 | if idxs is None: 63 | idxs = np.random.randint(0, self.size, size=batch_size) 64 | return dict(obs1=self.obs1_buf[idxs], 65 | obs2=self.obs2_buf[idxs], 66 | acts=self.acts_buf[idxs], 67 | rews=self.rews_buf[idxs], 68 | done=self.done_buf[idxs], 69 | idxs=idxs) 70 | 71 | 72 | class Mlp(nn.Module): 73 | def __init__( 74 | self, 75 | input_size, 76 | output_size, 77 | hidden_sizes, 78 | hidden_activation=F.relu 79 | ): 80 | super().__init__() 81 | 82 | self.input_size = input_size 83 | self.output_size = output_size 84 | self.hidden_activation = hidden_activation 85 | ## here we use 
ModuleList so that the layers in it can be 86 | ## detected by .parameters() call 87 | self.hidden_layers = nn.ModuleList() 88 | in_size = input_size 89 | 90 | ## initialize each hidden layer 91 | for i, next_size in enumerate(hidden_sizes): 92 | fc_layer = nn.Linear(in_size, next_size) 93 | in_size = next_size 94 | self.hidden_layers.append(fc_layer) 95 | 96 | ## init last fully connected layer with small weight and bias 97 | self.last_fc_layer = nn.Linear(in_size, output_size) 98 | self.apply(weights_init_) 99 | 100 | def forward(self, input): 101 | h = input 102 | for i, fc_layer in enumerate(self.hidden_layers): 103 | h = fc_layer(h) 104 | h = self.hidden_activation(h) 105 | output = self.last_fc_layer(h) 106 | return output 107 | 108 | class TanhNormal(Distribution): 109 | """ 110 | Represent distribution of X where 111 | X ~ tanh(Z) 112 | Z ~ N(mean, std) 113 | Note: this is not very numerically stable. 114 | """ 115 | def __init__(self, normal_mean, normal_std, epsilon=1e-6): 116 | """ 117 | :param normal_mean: Mean of the normal distribution 118 | :param normal_std: Std of the normal distribution 119 | :param epsilon: Numerical stability epsilon when computing log-prob. 120 | """ 121 | self.normal_mean = normal_mean 122 | self.normal_std = normal_std 123 | self.normal = Normal(normal_mean, normal_std) 124 | self.epsilon = epsilon 125 | 126 | def log_prob(self, value, pre_tanh_value=None): 127 | """ 128 | return the log probability of a value 129 | :param value: some value, x 130 | :param pre_tanh_value: arctanh(x) 131 | :return: 132 | """ 133 | # use arctanh formula to compute arctanh(value) 134 | if pre_tanh_value is None: 135 | pre_tanh_value = torch.log( 136 | (1+value) / (1-value) 137 | ) / 2 138 | return self.normal.log_prob(pre_tanh_value) - \ 139 | torch.log(1 - value * value + self.epsilon) 140 | 141 | def sample(self, return_pretanh_value=False): 142 | """ 143 | Gradients will and should *not* pass through this operation. 144 | See https://github.com/pytorch/pytorch/issues/4620 for discussion. 145 | """ 146 | z = self.normal.sample().detach() 147 | 148 | if return_pretanh_value: 149 | return torch.tanh(z), z 150 | else: 151 | return torch.tanh(z) 152 | 153 | def rsample(self, return_pretanh_value=False): 154 | """ 155 | Sampling in the reparameterization case. 
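        (unlike sample() above, the result stays differentiable w.r.t. normal_mean and normal_std)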
156 |         Implement: tanh(mu + sigma * xi)
157 |         with xi ~ N(0,1)
158 |         z here is mu + sigma * xi
159 |         """
160 |         z = (
161 |             self.normal_mean +
162 |             self.normal_std *
163 |             Normal( ## this part is xi ~ N(0,1)
164 |                 torch.zeros(self.normal_mean.size()),
165 |                 torch.ones(self.normal_std.size())
166 |             ).sample()
167 |         )
168 |         if return_pretanh_value:
169 |             return torch.tanh(z), z
170 |         else:
171 |             return torch.tanh(z)
172 | 
173 | class TanhGaussianPolicy(Mlp):
174 |     """
175 |     A Gaussian policy network with Tanh to enforce action limits
176 |     """
177 |     def __init__(
178 |             self,
179 |             obs_dim,
180 |             action_dim,
181 |             hidden_sizes,
182 |             hidden_activation=F.relu,
183 |             action_limit=1.0
184 |     ):
185 |         super().__init__(
186 |             input_size=obs_dim,
187 |             output_size=action_dim,
188 |             hidden_sizes=hidden_sizes,
189 |             hidden_activation=hidden_activation,
190 |         )
191 |         last_hidden_size = obs_dim
192 |         if len(hidden_sizes) > 0:
193 |             last_hidden_size = hidden_sizes[-1]
194 |         ## this is the layer that gives log_std, init this layer with small weight and bias
195 |         self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
196 |         ## action limit: for example, humanoid has an action limit of -0.4 to 0.4
197 |         self.action_limit = action_limit
198 |         self.apply(weights_init_)
199 | 
200 |     def forward(
201 |             self,
202 |             obs,
203 |             deterministic=False,
204 |             return_log_prob=True,
205 |     ):
206 |         """
207 |         :param obs: Observation
208 |         (a stochastic action is sampled with the reparameterization trick unless deterministic=True)
209 |         :param deterministic: If True, take deterministic (test) action
210 |         :param return_log_prob: If True, return a sample and its log probability
211 |         """
212 |         h = obs
213 |         for fc_layer in self.hidden_layers:
214 |             h = self.hidden_activation(fc_layer(h))
215 |         mean = self.last_fc_layer(h)
216 | 
217 |         log_std = self.last_fc_log_std(h)
218 |         log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
219 |         std = torch.exp(log_std)
220 | 
221 |         normal = Normal(mean, std)
222 | 
223 |         if deterministic:
224 |             pre_tanh_value = mean
225 |             action = torch.tanh(mean)
226 |         else:
227 |             pre_tanh_value = normal.rsample()
228 |             action = torch.tanh(pre_tanh_value)
229 | 
230 |         if return_log_prob:
231 |             log_prob = normal.log_prob(pre_tanh_value)
232 |             log_prob -= torch.log(1 - action.pow(2) + ACTION_BOUND_EPSILON)
233 |             log_prob = log_prob.sum(1, keepdim=True)
234 |         else:
235 |             log_prob = None
236 | 
237 |         return (
238 |             action * self.action_limit, mean, log_std, log_prob, std, pre_tanh_value,
239 |         )
240 | 
241 | def soft_update_model1_with_model2(model1, model2, rou):
242 |     """
243 |     used to polyak update a target network
244 |     :param model1: a pytorch model
245 |     :param model2: a pytorch model of the same class
246 |     :param rou: polyak coefficient; the update is model1 <- rou*model1 + (1-rou)*model2
247 |     """
248 |     for model1_param, model2_param in zip(model1.parameters(), model2.parameters()):
249 |         model1_param.data.copy_(rou*model1_param.data + (1-rou)*model2_param.data)
250 | 
251 | def test_agent(agent, test_env, max_ep_len, logger, n_eval=1):
252 |     """
253 |     This will test the agent's performance by running episodes
254 |     During the runs, the agent should only take deterministic actions
255 |     This function assumes the agent has a get_test_action(obs) function
256 |     :param agent: agent instance
257 |     :param test_env: the environment used for testing
258 |     :param max_ep_len: max length of an episode
259 |     :param logger: logger to store info in
260 |     :param n_eval: number of episodes to run the agent
261 |     :return: test return for each episode as a numpy array
262 |     """
263 | ep_return_list = np.zeros(n_eval) 264 | for j in range(n_eval): 265 | o, r, d, ep_ret, ep_len = test_env.reset(), 0, False, 0, 0 266 | while not (d or (ep_len == max_ep_len)): 267 | # Take deterministic actions at test time 268 | a = agent.get_test_action(o) 269 | o, r, d, _ = test_env.step(a) 270 | ep_ret += r 271 | ep_len += 1 272 | ep_return_list[j] = ep_ret 273 | if logger is not None: 274 | logger.store(TestEpRet=ep_ret, TestEpLen=ep_len) 275 | return ep_return_list 276 | -------------------------------------------------------------------------------- /redq/algos/redq_sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from redq.algos.core import TanhGaussianPolicy, Mlp, soft_update_model1_with_model2, ReplayBuffer,\ 7 | mbpo_target_entropy_dict 8 | 9 | def get_probabilistic_num_min(num_mins): 10 | # allows the number of min to be a float 11 | floored_num_mins = np.floor(num_mins) 12 | if num_mins - floored_num_mins > 0.001: 13 | prob_for_higher_value = num_mins - floored_num_mins 14 | if np.random.uniform(0, 1) < prob_for_higher_value: 15 | return int(floored_num_mins+1) 16 | else: 17 | return int(floored_num_mins) 18 | else: 19 | return num_mins 20 | 21 | class REDQSACAgent(object): 22 | """ 23 | Naive SAC: num_Q = 2, num_min = 2 24 | REDQ: num_Q > 2, num_min = 2 25 | MaxMin: num_mins = num_Qs 26 | for above three variants, set q_target_mode to 'min' (default) 27 | Ensemble Average: set q_target_mode to 'ave' 28 | REM: set q_target_mode to 'rem' 29 | """ 30 | def __init__(self, env_name, obs_dim, act_dim, act_limit, device, 31 | hidden_sizes=(256, 256), replay_size=int(1e6), batch_size=256, 32 | lr=3e-4, gamma=0.99, polyak=0.995, 33 | alpha=0.2, auto_alpha=True, target_entropy='mbpo', 34 | start_steps=5000, delay_update_steps='auto', 35 | utd_ratio=20, num_Q=10, num_min=2, q_target_mode='min', 36 | policy_update_delay=20, 37 | ): 38 | # set up networks 39 | self.policy_net = TanhGaussianPolicy(obs_dim, act_dim, hidden_sizes, action_limit=act_limit).to(device) 40 | self.q_net_list, self.q_target_net_list = [], [] 41 | for q_i in range(num_Q): 42 | new_q_net = Mlp(obs_dim + act_dim, 1, hidden_sizes).to(device) 43 | self.q_net_list.append(new_q_net) 44 | new_q_target_net = Mlp(obs_dim + act_dim, 1, hidden_sizes).to(device) 45 | new_q_target_net.load_state_dict(new_q_net.state_dict()) 46 | self.q_target_net_list.append(new_q_target_net) 47 | # set up optimizers 48 | self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr) 49 | self.q_optimizer_list = [] 50 | for q_i in range(num_Q): 51 | self.q_optimizer_list.append(optim.Adam(self.q_net_list[q_i].parameters(), lr=lr)) 52 | # set up adaptive entropy (SAC adaptive) 53 | self.auto_alpha = auto_alpha 54 | if auto_alpha: 55 | if target_entropy == 'auto': 56 | self.target_entropy = - act_dim 57 | if target_entropy == 'mbpo': 58 | self.target_entropy = mbpo_target_entropy_dict[env_name] 59 | self.log_alpha = torch.zeros(1, requires_grad=True, device=device) 60 | self.alpha_optim = optim.Adam([self.log_alpha], lr=lr) 61 | self.alpha = self.log_alpha.cpu().exp().item() 62 | else: 63 | self.alpha = alpha 64 | self.target_entropy, self.log_alpha, self.alpha_optim = None, None, None 65 | # set up replay buffer 66 | self.replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size) 67 | # set up other things 68 | self.mse_criterion = 
nn.MSELoss() 69 | 70 | # store other hyperparameters 71 | self.start_steps = start_steps 72 | self.obs_dim = obs_dim 73 | self.act_dim = act_dim 74 | self.act_limit = act_limit 75 | self.lr = lr 76 | self.hidden_sizes = hidden_sizes 77 | self.gamma = gamma 78 | self.polyak = polyak 79 | self.replay_size = replay_size 80 | self.alpha = alpha 81 | self.batch_size = batch_size 82 | self.num_min = num_min 83 | self.num_Q = num_Q 84 | self.utd_ratio = utd_ratio 85 | self.delay_update_steps = self.start_steps if delay_update_steps == 'auto' else delay_update_steps 86 | self.q_target_mode = q_target_mode 87 | self.policy_update_delay = policy_update_delay 88 | self.device = device 89 | 90 | def __get_current_num_data(self): 91 | # used to determine whether we should get action from policy or take random starting actions 92 | return self.replay_buffer.size 93 | 94 | def get_exploration_action(self, obs, env): 95 | # given an observation, output a sampled action in numpy form 96 | with torch.no_grad(): 97 | if self.__get_current_num_data() > self.start_steps: 98 | obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device) 99 | action_tensor = self.policy_net.forward(obs_tensor, deterministic=False, 100 | return_log_prob=False)[0] 101 | action = action_tensor.cpu().numpy().reshape(-1) 102 | else: 103 | action = env.action_space.sample() 104 | return action 105 | 106 | def get_test_action(self, obs): 107 | # given an observation, output a deterministic action in numpy form 108 | with torch.no_grad(): 109 | obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device) 110 | action_tensor = self.policy_net.forward(obs_tensor, deterministic=True, 111 | return_log_prob=False)[0] 112 | action = action_tensor.cpu().numpy().reshape(-1) 113 | return action 114 | 115 | def get_action_and_logprob_for_bias_evaluation(self, obs): #TODO modify the readme here 116 | # given an observation, output a sampled action in numpy form 117 | with torch.no_grad(): 118 | obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device) 119 | action_tensor, _, _, log_prob_a_tilda, _, _, = self.policy_net.forward(obs_tensor, deterministic=False, 120 | return_log_prob=True) 121 | action = action_tensor.cpu().numpy().reshape(-1) 122 | return action, log_prob_a_tilda 123 | 124 | def get_ave_q_prediction_for_bias_evaluation(self, obs_tensor, acts_tensor): 125 | # given obs_tensor and act_tensor, output Q prediction 126 | q_prediction_list = [] 127 | for q_i in range(self.num_Q): 128 | q_prediction = self.q_net_list[q_i](torch.cat([obs_tensor, acts_tensor], 1)) 129 | q_prediction_list.append(q_prediction) 130 | q_prediction_cat = torch.cat(q_prediction_list, dim=1) 131 | average_q_prediction = torch.mean(q_prediction_cat, dim=1) 132 | return average_q_prediction 133 | 134 | def store_data(self, o, a, r, o2, d): 135 | # store one transition to the buffer 136 | self.replay_buffer.store(o, a, r, o2, d) 137 | 138 | def sample_data(self, batch_size): 139 | # sample data from replay buffer 140 | batch = self.replay_buffer.sample_batch(batch_size) 141 | obs_tensor = Tensor(batch['obs1']).to(self.device) 142 | obs_next_tensor = Tensor(batch['obs2']).to(self.device) 143 | acts_tensor = Tensor(batch['acts']).to(self.device) 144 | rews_tensor = Tensor(batch['rews']).unsqueeze(1).to(self.device) 145 | done_tensor = Tensor(batch['done']).unsqueeze(1).to(self.device) 146 | return obs_tensor, obs_next_tensor, acts_tensor, rews_tensor, done_tensor 147 | 148 | def get_redq_q_target_no_grad(self, obs_next_tensor, rews_tensor, done_tensor): 149 | # compute 
REDQ Q target, depending on the agent's Q target mode 150 | # allow min as a float: 151 | num_mins_to_use = get_probabilistic_num_min(self.num_min) 152 | sample_idxs = np.random.choice(self.num_Q, num_mins_to_use, replace=False) 153 | with torch.no_grad(): 154 | if self.q_target_mode == 'min': 155 | """Q target is min of a subset of Q values""" 156 | a_tilda_next, _, _, log_prob_a_tilda_next, _, _ = self.policy_net.forward(obs_next_tensor) 157 | q_prediction_next_list = [] 158 | for sample_idx in sample_idxs: 159 | q_prediction_next = self.q_target_net_list[sample_idx](torch.cat([obs_next_tensor, a_tilda_next], 1)) 160 | q_prediction_next_list.append(q_prediction_next) 161 | q_prediction_next_cat = torch.cat(q_prediction_next_list, 1) 162 | min_q, min_indices = torch.min(q_prediction_next_cat, dim=1, keepdim=True) 163 | next_q_with_log_prob = min_q - self.alpha * log_prob_a_tilda_next 164 | y_q = rews_tensor + self.gamma * (1 - done_tensor) * next_q_with_log_prob 165 | if self.q_target_mode == 'ave': 166 | """Q target is average of all Q values""" 167 | a_tilda_next, _, _, log_prob_a_tilda_next, _, _ = self.policy_net.forward(obs_next_tensor) 168 | q_prediction_next_list = [] 169 | for q_i in range(self.num_Q): 170 | q_prediction_next = self.q_target_net_list[q_i](torch.cat([obs_next_tensor, a_tilda_next], 1)) 171 | q_prediction_next_list.append(q_prediction_next) 172 | q_prediction_next_ave = torch.cat(q_prediction_next_list, 1).mean(dim=1).reshape(-1, 1) 173 | next_q_with_log_prob = q_prediction_next_ave - self.alpha * log_prob_a_tilda_next 174 | y_q = rews_tensor + self.gamma * (1 - done_tensor) * next_q_with_log_prob 175 | if self.q_target_mode == 'rem': 176 | """Q target is random ensemble mixture of Q values""" 177 | a_tilda_next, _, _, log_prob_a_tilda_next, _, _ = self.policy_net.forward(obs_next_tensor) 178 | q_prediction_next_list = [] 179 | for q_i in range(self.num_Q): 180 | q_prediction_next = self.q_target_net_list[q_i](torch.cat([obs_next_tensor, a_tilda_next], 1)) 181 | q_prediction_next_list.append(q_prediction_next) 182 | # apply rem here 183 | q_prediction_next_cat = torch.cat(q_prediction_next_list, 1) 184 | rem_weight = Tensor(np.random.uniform(0, 1, q_prediction_next_cat.shape)).to(device=self.device) 185 | normalize_sum = rem_weight.sum(1).reshape(-1, 1).expand(-1, self.num_Q) 186 | rem_weight = rem_weight / normalize_sum 187 | q_prediction_next_rem = (q_prediction_next_cat * rem_weight).sum(dim=1).reshape(-1, 1) 188 | next_q_with_log_prob = q_prediction_next_rem - self.alpha * log_prob_a_tilda_next 189 | y_q = rews_tensor + self.gamma * (1 - done_tensor) * next_q_with_log_prob 190 | return y_q, sample_idxs 191 | 192 | def train(self, logger): 193 | # this function is called after each datapoint collected. 
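# note: each call performs (at most) self.utd_ratio gradient updates -- the update-to-data (UTD) ratio;
# for example, a hypothetical utd_ratio of 20 would mean 20 Q-network updates per environment step,
# while the policy and alpha are only updated every self.policy_update_delay updates (and on the last update of the batch)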
194 | # when we only have very limited data, we don't make updates 195 | num_update = 0 if self.__get_current_num_data() <= self.delay_update_steps else self.utd_ratio 196 | for i_update in range(num_update): 197 | obs_tensor, obs_next_tensor, acts_tensor, rews_tensor, done_tensor = self.sample_data(self.batch_size) 198 | 199 | """Q loss""" 200 | y_q, sample_idxs = self.get_redq_q_target_no_grad(obs_next_tensor, rews_tensor, done_tensor) 201 | q_prediction_list = [] 202 | for q_i in range(self.num_Q): 203 | q_prediction = self.q_net_list[q_i](torch.cat([obs_tensor, acts_tensor], 1)) 204 | q_prediction_list.append(q_prediction) 205 | q_prediction_cat = torch.cat(q_prediction_list, dim=1) 206 | y_q = y_q.expand((-1, self.num_Q)) if y_q.shape[1] == 1 else y_q 207 | q_loss_all = self.mse_criterion(q_prediction_cat, y_q) * self.num_Q 208 | 209 | for q_i in range(self.num_Q): 210 | self.q_optimizer_list[q_i].zero_grad() 211 | q_loss_all.backward() 212 | 213 | """policy and alpha loss""" 214 | if ((i_update + 1) % self.policy_update_delay == 0) or i_update == num_update - 1: 215 | # get policy loss 216 | a_tilda, mean_a_tilda, log_std_a_tilda, log_prob_a_tilda, _, pretanh = self.policy_net.forward(obs_tensor) 217 | q_a_tilda_list = [] 218 | for sample_idx in range(self.num_Q): 219 | self.q_net_list[sample_idx].requires_grad_(False) 220 | q_a_tilda = self.q_net_list[sample_idx](torch.cat([obs_tensor, a_tilda], 1)) 221 | q_a_tilda_list.append(q_a_tilda) 222 | q_a_tilda_cat = torch.cat(q_a_tilda_list, 1) 223 | ave_q = torch.mean(q_a_tilda_cat, dim=1, keepdim=True) 224 | policy_loss = (self.alpha * log_prob_a_tilda - ave_q).mean() 225 | self.policy_optimizer.zero_grad() 226 | policy_loss.backward() 227 | for sample_idx in range(self.num_Q): 228 | self.q_net_list[sample_idx].requires_grad_(True) 229 | 230 | # get alpha loss 231 | if self.auto_alpha: 232 | alpha_loss = -(self.log_alpha * (log_prob_a_tilda + self.target_entropy).detach()).mean() 233 | self.alpha_optim.zero_grad() 234 | alpha_loss.backward() 235 | self.alpha_optim.step() 236 | self.alpha = self.log_alpha.cpu().exp().item() 237 | else: 238 | alpha_loss = Tensor([0]) 239 | 240 | """update networks""" 241 | for q_i in range(self.num_Q): 242 | self.q_optimizer_list[q_i].step() 243 | 244 | if ((i_update + 1) % self.policy_update_delay == 0) or i_update == num_update - 1: 245 | self.policy_optimizer.step() 246 | 247 | # polyak averaged Q target networks 248 | for q_i in range(self.num_Q): 249 | soft_update_model1_with_model2(self.q_target_net_list[q_i], self.q_net_list[q_i], self.polyak) 250 | 251 | # by default only log for the last update out of updates 252 | if i_update == num_update - 1: 253 | logger.store(LossPi=policy_loss.cpu().item(), LossQ1=q_loss_all.cpu().item() / self.num_Q, 254 | LossAlpha=alpha_loss.cpu().item(), Q1Vals=q_prediction.detach().cpu().numpy(), 255 | Alpha=self.alpha, LogPi=log_prob_a_tilda.detach().cpu().numpy(), 256 | PreTanh=pretanh.abs().detach().cpu().numpy().reshape(-1)) 257 | 258 | # if there is no update, log 0 to prevent logging problems 259 | if num_update == 0: 260 | logger.store(LossPi=0, LossQ1=0, LossAlpha=0, Q1Vals=0, Alpha=0, LogPi=0, PreTanh=0) 261 | 262 | -------------------------------------------------------------------------------- /redq/user_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from OpenAI spinup code 3 | """ 4 | import os.path as osp 5 | 6 | # Where experiment outputs are saved by default: 7 | DEFAULT_DATA_DIR = 
osp.join(osp.abspath(osp.dirname(osp.dirname(__file__))),'data') 8 | 9 | # Whether to automatically insert a date and time stamp into the names of 10 | # save directories: 11 | FORCE_DATESTAMP = False 12 | 13 | # Whether GridSearch provides automatically-generated default shorthands: 14 | DEFAULT_SHORTHAND = True 15 | 16 | # Tells the GridSearch how many seconds to pause for before launching 17 | # experiments. 18 | WAIT_BEFORE_LAUNCH = 5 -------------------------------------------------------------------------------- /redq/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/watchernyu/REDQ/7b5d1bff39291a57325a2836bd397a55728960bb/redq/utils/__init__.py -------------------------------------------------------------------------------- /redq/utils/bias_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | 5 | def get_mc_return_with_entropy_on_reset(bias_eval_env, agent, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff): 6 | # since we want to also compute bias, so we need to 7 | final_mc_list = np.zeros(0) 8 | final_mc_entropy_list = np.zeros(0) 9 | final_obs_list = [] 10 | final_act_list = [] 11 | while final_mc_list.shape[0] < n_mc_eval: 12 | # we continue if haven't collected enough data 13 | o = bias_eval_env.reset() 14 | # temporary lists 15 | reward_list, log_prob_a_tilda_list, obs_list, act_list = [], [], [], [] 16 | r, d, ep_ret, ep_len = 0, False, 0, 0 17 | discounted_return = 0 18 | discounted_return_with_entropy = 0 19 | for i_step in range(max_ep_len): # run an episode 20 | with torch.no_grad(): 21 | a, log_prob_a_tilda = agent.get_action_and_logprob_for_bias_evaluation(o) 22 | obs_list.append(o) 23 | act_list.append(a) 24 | o, r, d, _ = bias_eval_env.step(a) 25 | ep_ret += r 26 | ep_len += 1 27 | reward_list.append(r) 28 | log_prob_a_tilda_list.append(log_prob_a_tilda.item()) 29 | if d or (ep_len == max_ep_len): 30 | break 31 | discounted_return_list = np.zeros(ep_len) 32 | discounted_return_with_entropy_list = np.zeros(ep_len) 33 | for i_step in range(ep_len - 1, -1, -1): 34 | # backwards compute discounted return and with entropy for all s-a visited 35 | if i_step == ep_len - 1: 36 | discounted_return_list[i_step] = reward_list[i_step] 37 | discounted_return_with_entropy_list[i_step] = reward_list[i_step] 38 | else: 39 | discounted_return_list[i_step] = reward_list[i_step] + gamma * discounted_return_list[i_step + 1] 40 | discounted_return_with_entropy_list[i_step] = reward_list[i_step] + \ 41 | gamma * (discounted_return_with_entropy_list[i_step + 1] - alpha * log_prob_a_tilda_list[i_step + 1]) 42 | # now we take the first few of these. 
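# taking only the first n_mc_cutoff entries per episode caps how many samples any single
# episode contributes, and these early-step returns span (nearly) the whole episode,
# so they are the least affected by truncation at max_ep_len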
43 | final_mc_list = np.concatenate((final_mc_list, discounted_return_list[:n_mc_cutoff])) 44 | final_mc_entropy_list = np.concatenate( 45 | (final_mc_entropy_list, discounted_return_with_entropy_list[:n_mc_cutoff])) 46 | final_obs_list += obs_list[:n_mc_cutoff] 47 | final_act_list += act_list[:n_mc_cutoff] 48 | return final_mc_list, final_mc_entropy_list, np.array(final_obs_list), np.array(final_act_list) 49 | 50 | def log_bias_evaluation(bias_eval_env, agent, logger, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff): 51 | final_mc_list, final_mc_entropy_list, final_obs_list, final_act_list = get_mc_return_with_entropy_on_reset(bias_eval_env, agent, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff) 52 | logger.store(MCDisRet=final_mc_list) 53 | logger.store(MCDisRetEnt=final_mc_entropy_list) 54 | obs_tensor = Tensor(final_obs_list).to(agent.device) 55 | acts_tensor = Tensor(final_act_list).to(agent.device) 56 | with torch.no_grad(): 57 | q_prediction = agent.get_ave_q_prediction_for_bias_evaluation(obs_tensor, acts_tensor).cpu().numpy().reshape(-1) 58 | bias = q_prediction - final_mc_entropy_list 59 | bias_abs = np.abs(bias) 60 | bias_squared = bias ** 2 61 | logger.store(QPred=q_prediction) 62 | logger.store(QBias=bias) 63 | logger.store(QBiasAbs=bias_abs) 64 | logger.store(QBiasSqr=bias_squared) 65 | final_mc_entropy_list_normalize_base = final_mc_entropy_list.copy() 66 | final_mc_entropy_list_normalize_base = np.abs(final_mc_entropy_list_normalize_base) 67 | final_mc_entropy_list_normalize_base[final_mc_entropy_list_normalize_base < 10] = 10 68 | normalized_bias_per_state = bias / final_mc_entropy_list_normalize_base 69 | logger.store(NormQBias=normalized_bias_per_state) 70 | normalized_bias_sqr_per_state = bias_squared / final_mc_entropy_list_normalize_base 71 | logger.store(NormQBiasSqr=normalized_bias_sqr_per_state) 72 | -------------------------------------------------------------------------------- /redq/utils/logx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some simple logging functionality, inspired by rllab's logging. 3 | Modified from OpenAI spinup logx code 4 | Logs to a tab-separated-values file (path/to/output_directory/progress.txt) 5 | """ 6 | import json 7 | import joblib 8 | import numpy as np 9 | import os.path as osp, time, atexit, os 10 | from redq.utils.serialization_utils import convert_json 11 | 12 | color2num = dict( 13 | gray=30, 14 | red=31, 15 | green=32, 16 | yellow=33, 17 | blue=34, 18 | magenta=35, 19 | cyan=36, 20 | white=37, 21 | crimson=38 22 | ) 23 | 24 | def colorize(string, color, bold=False, highlight=False): 25 | """ 26 | Colorize a string. 27 | 28 | This function was originally written by John Schulman. 29 | """ 30 | attr = [] 31 | num = color2num[color] 32 | if highlight: num += 10 33 | attr.append(str(num)) 34 | if bold: attr.append('1') 35 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 36 | 37 | class Logger: 38 | """ 39 | A general-purpose logger. 40 | 41 | Makes it easy to save diagnostics, hyperparameter configurations, the 42 | state of a training run, and the trained model. 43 | """ 44 | 45 | def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None): 46 | """ 47 | Initialize a Logger. 48 | 49 | Args: 50 | output_dir (string): A directory for saving results to. If 51 | ``None``, defaults to a temp directory of the form 52 | ``/tmp/experiments/somerandomnumber``. 
53 | 54 | output_fname (string): Name for the tab-separated-value file 55 | containing metrics logged throughout a training run. 56 | Defaults to ``progress.txt``. 57 | 58 | exp_name (string): Experiment name. If you run multiple training 59 | runs and give them all the same ``exp_name``, the plotter 60 | will know to group them. (Use case: if you run the same 61 | hyperparameter configuration with multiple random seeds, you 62 | should give them all the same ``exp_name``.) 63 | """ 64 | self.output_dir = output_dir or "/tmp/experiments/%i"%int(time.time()) 65 | if osp.exists(self.output_dir): 66 | print("Warning: Log dir %s already exists! Storing info there anyway."%self.output_dir) 67 | else: 68 | os.makedirs(self.output_dir) 69 | self.output_file = open(osp.join(self.output_dir, output_fname), 'w') 70 | atexit.register(self.output_file.close) 71 | print(colorize("Logging data to %s"%self.output_file.name, 'green', bold=True)) 72 | self.first_row=True 73 | self.log_headers = [] 74 | self.log_current_row = {} 75 | self.exp_name = exp_name 76 | 77 | def log(self, msg, color='green'): 78 | """Print a colorized message to stdout.""" 79 | print(colorize(msg, color, bold=True)) 80 | 81 | def log_tabular(self, key, val): 82 | """ 83 | Log a value of some diagnostic. 84 | 85 | Call this only once for each diagnostic quantity, each iteration. 86 | After using ``log_tabular`` to store values for each diagnostic, 87 | make sure to call ``dump_tabular`` to write them out to file and 88 | stdout (otherwise they will not get saved anywhere). 89 | """ 90 | if self.first_row: 91 | self.log_headers.append(key) 92 | else: 93 | assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key 94 | assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()"%key 95 | self.log_current_row[key] = val 96 | 97 | def save_config(self, config): 98 | """ 99 | Log an experiment configuration. 100 | 101 | Call this once at the top of your experiment, passing in all important 102 | config vars as a dict. This will serialize the config to JSON, while 103 | handling anything which can't be serialized in a graceful way (writing 104 | as informative a string as possible). 105 | 106 | Example use: 107 | 108 | .. code-block:: python 109 | 110 | logger = EpochLogger(**logger_kwargs) 111 | logger.save_config(locals()) 112 | """ 113 | config_json = convert_json(config) 114 | if self.exp_name is not None: 115 | config_json['exp_name'] = self.exp_name 116 | output = json.dumps(config_json, separators=(',',':\t'), indent=4, sort_keys=True) 117 | print(colorize('Saving config:\n', color='cyan', bold=True)) 118 | print(output) 119 | with open(osp.join(self.output_dir, "config.json"), 'w') as out: 120 | out.write(output) 121 | 122 | def save_state(self, state_dict, itr=None): 123 | """ 124 | Saves the state of an experiment. 125 | 126 | To be clear: this is about saving *state*, not logging diagnostics. 127 | All diagnostic logging is separate from this function. This function 128 | will save whatever is in ``state_dict``---usually just a copy of the 129 | environment---and the most recent parameters for the model you 130 | previously set up saving for with ``setup_tf_saver``. 131 | 132 | Call with any frequency you prefer. If you only want to maintain a 133 | single state and overwrite it at each call with the most recent 134 | version, leave ``itr=None``. 
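        For example (a minimal sketch; ``env`` stands for whatever object you want snapshotted):

        .. code-block:: python

            logger.save_state({'env': env})  # saved as vars.pkl, overwritten on every call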
If you want to keep all of the states you 135 | save, provide unique (increasing) values for 'itr'. 136 | 137 | Args: 138 | state_dict (dict): Dictionary containing essential elements to 139 | describe the current state of training. 140 | 141 | itr: An int, or None. Current iteration of training. 142 | """ 143 | fname = 'vars.pkl' if itr is None else 'vars%d.pkl'%itr 144 | try: 145 | joblib.dump(state_dict, osp.join(self.output_dir, fname)) 146 | except: 147 | self.log('Warning: could not pickle state_dict.', color='red') 148 | 149 | def dump_tabular(self): 150 | """ 151 | Write all of the diagnostics from the current iteration. 152 | 153 | Writes both to stdout, and to the output file. 154 | """ 155 | vals = [] 156 | key_lens = [len(key) for key in self.log_headers] 157 | max_key_len = max(15,max(key_lens)) 158 | keystr = '%'+'%d'%max_key_len 159 | fmt = "| " + keystr + "s | %15s |" 160 | n_slashes = 22 + max_key_len 161 | print("-"*n_slashes) 162 | for key in self.log_headers: 163 | val = self.log_current_row.get(key, "") 164 | valstr = "%8.3g"%val if hasattr(val, "__float__") else val 165 | print(fmt%(key, valstr)) 166 | vals.append(val) 167 | print("-"*n_slashes) 168 | if self.output_file is not None: 169 | if self.first_row: 170 | self.output_file.write("\t".join(self.log_headers)+"\n") 171 | self.output_file.write("\t".join(map(str,vals))+"\n") 172 | self.output_file.flush() 173 | self.log_current_row.clear() 174 | self.first_row=False 175 | 176 | def get_statistics_scalar(x, with_min_and_max=False): 177 | """ 178 | Get mean/std and optional min/max of x 179 | 180 | Args: 181 | x: An array containing samples of the scalar to produce statistics 182 | for. 183 | 184 | with_min_and_max (bool): If true, return min and max of x in 185 | addition to mean and std. 186 | """ 187 | x = np.array(x, dtype=np.float32) 188 | mean, std = x.mean(), x.std() 189 | if with_min_and_max: 190 | min_v = x.min() 191 | max_v = x.max() 192 | return mean, std, min_v, max_v 193 | return mean, std 194 | 195 | class EpochLogger(Logger): 196 | """ 197 | A variant of Logger tailored for tracking average values over epochs. 198 | 199 | Typical use case: there is some quantity which is calculated many times 200 | throughout an epoch, and at the end of the epoch, you would like to 201 | report the average / std / min / max value of that quantity. 202 | 203 | With an EpochLogger, each time the quantity is calculated, you would 204 | use 205 | 206 | .. code-block:: python 207 | 208 | epoch_logger.store(NameOfQuantity=quantity_value) 209 | 210 | to load it into the EpochLogger's state. Then at the end of the epoch, you 211 | would use 212 | 213 | .. code-block:: python 214 | 215 | epoch_logger.log_tabular(NameOfQuantity, **options) 216 | 217 | to record the desired values. 218 | """ 219 | 220 | def __init__(self, *args, **kwargs): 221 | super().__init__(*args, **kwargs) 222 | self.epoch_dict = dict() 223 | 224 | def store(self, **kwargs): 225 | """ 226 | Save something into the epoch_logger's current state. 227 | 228 | Provide an arbitrary number of keyword arguments with numerical 229 | values. 
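        Example (key names follow the REDQ training loop; the values are made up):

        .. code-block:: python

            epoch_logger.store(LossQ1=0.73)                       # a single scalar
            epoch_logger.store(Q1Vals=np.array([1.2, 0.8, 1.5]))  # or a numpy array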
230 | 231 | To prevent problems, let value be either a numpy array, or a single scalar value 232 | """ 233 | for k,v in kwargs.items(): 234 | if not k in self.epoch_dict: 235 | self.epoch_dict[k] = [] 236 | if isinstance(v, np.ndarray): # used to prevent problems due to shape issues 237 | v = v.reshape(-1) 238 | self.epoch_dict[k].append(v) 239 | 240 | def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False): 241 | """ 242 | Log a value or possibly the mean/std/min/max values of a diagnostic. 243 | 244 | Args: 245 | key (string): The name of the diagnostic. If you are logging a 246 | diagnostic whose state has previously been saved with 247 | ``store``, the key here has to match the key you used there. 248 | 249 | val: A value for the diagnostic. If you have previously saved 250 | values for this key via ``store``, do *not* provide a ``val`` 251 | here. 252 | 253 | with_min_and_max (bool): If true, log min and max values of the 254 | diagnostic over the epoch. 255 | 256 | average_only (bool): If true, do not log the standard deviation 257 | of the diagnostic over the epoch. 258 | """ 259 | if val is not None: 260 | super().log_tabular(key,val) 261 | else: 262 | v = self.epoch_dict[key] 263 | vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape)>0 else v 264 | stats = get_statistics_scalar(vals, with_min_and_max=with_min_and_max) 265 | super().log_tabular(key if average_only else 'Average' + key, stats[0]) 266 | if not(average_only): 267 | super().log_tabular('Std'+key, stats[1]) 268 | if with_min_and_max: 269 | super().log_tabular('Max'+key, stats[3]) 270 | super().log_tabular('Min'+key, stats[2]) 271 | self.epoch_dict[key] = [] 272 | 273 | def get_stats(self, key): 274 | """ 275 | Lets an algorithm ask the logger for mean/std/min/max of a diagnostic. 276 | """ 277 | v = self.epoch_dict[key] 278 | vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape)>0 else v 279 | return get_statistics_scalar(vals, with_min_and_max=True) -------------------------------------------------------------------------------- /redq/utils/run_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from OpenAI spinup code 3 | """ 4 | from redq.user_config import DEFAULT_DATA_DIR, FORCE_DATESTAMP 5 | import os.path as osp 6 | import time 7 | 8 | DIV_LINE_WIDTH = 80 9 | 10 | def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False): 11 | """ 12 | Sets up the output_dir for a logger and returns a dict for logger kwargs. 13 | 14 | If no seed is given and datestamp is false, 15 | 16 | :: 17 | 18 | output_dir = data_dir/exp_name 19 | 20 | If a seed is given and datestamp is false, 21 | 22 | :: 23 | 24 | output_dir = data_dir/exp_name/exp_name_s[seed] 25 | 26 | If datestamp is true, amend to 27 | 28 | :: 29 | 30 | output_dir = data_dir/YY-MM-DD_exp_name/YY-MM-DD_HH-MM-SS_exp_name_s[seed] 31 | 32 | You can force datestamp=True by setting ``FORCE_DATESTAMP=True`` in 33 | ``spinup/user_config.py``. 34 | 35 | Args: 36 | 37 | exp_name (string): Name for experiment. 38 | 39 | seed (int): Seed for random number generators used by experiment. 40 | 41 | data_dir (string): Path to folder where results should be saved. 42 | Default is the ``DEFAULT_DATA_DIR`` in ``spinup/user_config.py``. 43 | 44 | datestamp (bool): Whether to include a date and timestamp in the 45 | name of the save directory. 46 | 47 | Returns: 48 | 49 | logger_kwargs, a dict containing output_dir and exp_name. 
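    Example (with a hypothetical experiment name, ``datestamp=False`` and the default data dir):

    ::

        setup_logger_kwargs('redq_hopper', seed=0)
        # -> {'output_dir': DEFAULT_DATA_DIR + '/redq_hopper/redq_hopper_s0',
        #     'exp_name': 'redq_hopper'}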
50 | """ 51 | 52 | # Datestamp forcing 53 | datestamp = datestamp or FORCE_DATESTAMP 54 | 55 | # Make base path 56 | ymd_time = time.strftime("%Y-%m-%d_") if datestamp else '' 57 | relpath = ''.join([ymd_time, exp_name]) 58 | 59 | if seed is not None: 60 | # Make a seed-specific subfolder in the experiment directory. 61 | if datestamp: 62 | hms_time = time.strftime("%Y-%m-%d_%H-%M-%S") 63 | subfolder = ''.join([hms_time, '-', exp_name, '_s', str(seed)]) 64 | else: 65 | subfolder = ''.join([exp_name, '_s', str(seed)]) 66 | relpath = osp.join(relpath, subfolder) 67 | 68 | data_dir = data_dir or DEFAULT_DATA_DIR 69 | logger_kwargs = dict(output_dir=osp.join(data_dir, relpath), 70 | exp_name=exp_name) 71 | return logger_kwargs 72 | -------------------------------------------------------------------------------- /redq/utils/serialization_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from OpenAI spinup code 3 | """ 4 | import json 5 | 6 | def convert_json(obj): 7 | """ Convert obj to a version which can be serialized with JSON. """ 8 | if is_json_serializable(obj): 9 | return obj 10 | else: 11 | if isinstance(obj, dict): 12 | return {convert_json(k): convert_json(v) 13 | for k,v in obj.items()} 14 | 15 | elif isinstance(obj, tuple): 16 | return (convert_json(x) for x in obj) 17 | 18 | elif isinstance(obj, list): 19 | return [convert_json(x) for x in obj] 20 | 21 | elif hasattr(obj,'__name__') and not('lambda' in obj.__name__): 22 | return convert_json(obj.__name__) 23 | 24 | elif hasattr(obj,'__dict__') and obj.__dict__: 25 | obj_dict = {convert_json(k): convert_json(v) 26 | for k,v in obj.__dict__.items()} 27 | return {str(obj): obj_dict} 28 | 29 | return str(obj) 30 | 31 | def is_json_serializable(v): 32 | try: 33 | json.dumps(v) 34 | return True 35 | except: 36 | return False -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os.path import join, dirname, realpath 2 | from setuptools import setup 3 | import sys 4 | 5 | assert sys.version_info.major == 3 and sys.version_info.minor >= 6, \ 6 | "Require Python 3.6 or greater." 7 | 8 | setup( 9 | name='redq', 10 | py_modules=['redq'], 11 | version='0.0.1', 12 | install_requires=[ 13 | 'numpy', 14 | 'joblib', 15 | 'gym>=0.17.2' 16 | ], 17 | description="REDQ algorithm PyTorch implementation", 18 | author="Xinyue Chen, Che Wang, Zijian Zhou, Keith Ross", 19 | ) 20 | --------------------------------------------------------------------------------