├── .gitignore
├── LICENSE
├── README.md
├── conda_env.yml
├── docker-gym-mujocov2
│   ├── Dockerfile
│   └── README.md
├── docker-gym-mujocov4
│   ├── Dockerfile
│   └── README.md
├── experiments
│   ├── grid_sample_hpc_script.sh
│   ├── grid_train_redq_sac.py
│   ├── grid_utils.py
│   ├── sample_hpc_script.sh
│   └── train_redq_sac.py
├── mujoco_download.sh
├── plot_utils
│   ├── plot_REDQ.py
│   └── redq_plot_helper.py
├── redq
│   ├── __init__.py
│   ├── algos
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── redq_sac.py
│   ├── user_config.py
│   └── utils
│       ├── __init__.py
│       ├── bias_utils.py
│       ├── logx.py
│       ├── run_utils.py
│       └── serialization_utils.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore data and extra python libraries
2 | data/
3 | extra-python-libs/
4 |
5 | # ignore experiment files that are not useful to track
6 | hpc_experiments/
7 |
8 | .idea/
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 watchernyu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # REDQ source code
2 | Author's PyTorch implementation of Randomized Ensembled Double Q-Learning (REDQ) algorithm. Paper link: https://arxiv.org/abs/2101.05982
3 |
4 |
5 |
6 | ## Table of Contents
7 | - [Table of contents](#table-of-contents)
8 | - [Code structure explained](#code-structure)
9 | - [Implementation video tutorial](#video-tutorial)
10 | - [Data and reproducing figures in REDQ](#reproduce-figures)
11 | - [Train a REDQ agent](#train-redq)
12 | - [Implement REDQ](#implement-redq)
13 | - [Reproduce the results](#reproduce-results)
14 | - [Newest Docker + Singularity setup for Gym + MuJoCo v2 and v4](#setup-dockersing)
15 | - [(outdated) Environment setup MuJoCo 2.1, v4 tasks, NYU HPC 18.04](#setup-nyuhpc-new)
16 | - [(outdated) Environment setup MuJoCo 2.1, Ubuntu 18.04](#setup-ubuntu)
17 | - [(outdated) Environment setup MuJoCo 2.1, NYU Shanghai HPC](#setup-nyuhpc)
18 | - [(outdated) Environment setup](#setup-old)
19 | - [Acknowledgement](#acknowledgement)
20 | - [Previous updates](#prevupdates)
21 |
22 | Feb 20, 2023: An updated [docker + singularity setup](#setup-dockersing) is now available. This is probably the easiest setup ever and allows you to use Docker to start running your DRL experiments with just 3 commands. We have also released new Dockerfiles for gym + mujoco v2 and v4 environments (in the newest version of the repo, you will see two folders, `docker-gym-mujocov2` and `docker-gym-mujocov4`, each containing a Dockerfile).
23 |
24 |
25 |
26 | ## Code structure explained
27 | The code structure is pretty simple and should be easy to follow.
28 |
29 | In `experiments/train_redq_sac.py` you will find the main training loop. Here we set up the environment, initialize an instance of the `REDQSACAgent` class with all the hyperparameters, and train the agent. You can run this file to train a REDQ agent.
30 |
31 | In `redq/algos/redq_sac.py` we provide code for the `REDQSACAgent` class. If you are trying to take a look at how the core components of REDQ are implemented, the most important function is the `train()` function.
32 |
33 | In `redq/algos/core.py` we provide code for some basic classes (Q network, policy network, replay buffer) and some helper functions. These classes and functions are used by the REDQ agent class.
34 |
35 | In `redq/utils` there are some utility classes (such as a logger) and helper functions that largely have nothing to do with REDQ's core components. In `redq/utils/bias_utils.py` you can find utility functions for bias estimation (the bias estimate is computed roughly as: Monte Carlo return - current Q estimate). In `experiments/train_redq_sac.py` you can decide whether you want bias evaluation when running the experiment by setting the `evaluate_bias` flag (this adds some minor computational overhead).
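
For illustration, here is a minimal sketch of the rough bias estimate described above (this is not the repo's `bias_utils` implementation; the function name and sign convention are illustrative assumptions):

```
import numpy as np

def mc_bias_estimate(q_value, rewards, gamma=0.99):
    # q_value: the critic's estimate Q(s0, a0) for the starting state-action pair
    # rewards: rewards collected from a rollout that starts at (s0, a0)
    mc_return = sum(gamma ** i * r for i, r in enumerate(rewards))
    return mc_return - q_value  # the rough "Monte Carlo return - current Q estimate" quantity

# toy usage: constant reward of 1 for 50 steps, against a Q estimate of 10
print(mc_bias_estimate(q_value=10.0, rewards=np.ones(50)))
```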
36 |
37 | In `plot_utils` there are some utility functions to reproduce the figures we presented in the paper. (See the section on "Data and reproducing figures in REDQ")
38 |
39 |
40 |
41 |
42 | ## Implementation video tutorial
43 | Here is the link to a video tutorial we created that explains the REDQ implementation in detail:
44 |
45 | [REDQ code explained video tutorial (Google Drive Link)](https://drive.google.com/file/d/1QabNOz8VDyKhI0K7tPbIkmx4LKw6N6Te/view?usp=sharing)
46 |
47 |
48 |
49 | ## Data and reproducing figures in REDQ
50 | The data used to produce the figures in the REDQ paper can be downloaded here:
51 | [REDQ DATA download link](https://drive.google.com/file/d/11mjDYCzp3T1MaICGruySrWd-akr-SVJp/view?usp=sharing) (Google Drive Link, ~80 MB)
52 |
53 | To reproduce the figures, first download the data, and then extract the zip file to `REDQ/data`. So now a folder called `REDQ_ICLR21` should be at this path: `REDQ/data/REDQ_ICLR21`.
54 |
55 | Then you can go into the `plot_utils` folder, and run the `plot_REDQ.py` program there. You will need `seaborn==0.8.1` to run it correctly. We might update the code later so that it works for newer versions but currently seaborn newer than 0.8.1 is not supported. If you don't want to mess up existing conda or python virtual environments, you can create a new environment and simply install seaborn 0.8.1 there and use it to run the program.
56 |
57 | If you encounter any problem or cannot access the data (can't use google or can't download), please open an issue to let us know! Thanks!
58 |
59 |
60 |
61 | ## Environment setup (old guide, for the newest guide, see end of this page)
62 |
63 | **VERY IMPORTANT**: because MuJoCo is now free, the setup guide here is slightly outdated (this is the setup we used when we ran our experiments for the REDQ paper). We now provide a newer, updated setup guide that uses the newest MuJoCo; please see the end of this page.
64 |
65 | Note: you don't need to follow the tutorial here exactly if you are already familiar with installing Python packages.
66 |
67 | First create a conda environment and activate it:
68 | ```
69 | conda create -n redq python=3.6
70 | conda activate redq
71 | ```
72 |
73 | Install PyTorch (or you can follow the tutorial on PyTorch official website).
74 | On Ubuntu (might also work on Windows but is not fully tested):
75 | ```
76 | conda install pytorch==1.3.1 torchvision==0.4.2 cudatoolkit=10.1 -c pytorch
77 | ```
78 | On OSX:
79 | ```
80 | conda install pytorch==1.3.1 torchvision==0.4.2 -c pytorch
81 | ```
82 |
83 | Install gym (0.17.2):
84 | ```
85 | git clone https://github.com/openai/gym.git
86 | cd gym
87 | git checkout b2727d6
88 | pip install -e .
89 | cd ..
90 | ```
91 |
92 | Install mujoco_py (2.0.2.1):
93 | ```
94 | git clone https://github.com/openai/mujoco-py
95 | cd mujoco-py
96 | git checkout 379bb19
97 | pip install -e . --no-cache
98 | cd ..
99 | ```
100 |
101 | For gym and mujoco_py, depending on your system, you might need to install some other packages; if you run into such problems, please refer to their official sites for guidance.
102 | If you want to test on MuJoCo environments, you will also need to get the MuJoCo files and a license from the MuJoCo website. Please refer to the MuJoCo website for how to do this correctly.
103 |
104 | Clone and install this repository (although even if you don't install it, you might still be able to use the code):
105 | ```
106 | git clone https://github.com/watchernyu/REDQ.git
107 | cd REDQ
108 | pip install -e .
109 | ```
110 |
111 |
112 |
113 | ## Train a REDQ agent
114 | To train a REDQ agent, run:
115 | ```
116 | python experiments/train_redq_sac.py
117 | ```
118 | On a 2080Ti GPU, running Hopper to 125K steps will take approximately 10-12 hours. Running Humanoid to 300K steps will take approximately 26 hours.
119 |
120 |
121 |
122 | ## Implement REDQ
123 | As discussed in the paper, we obtain REDQ by making minimal changes to a Soft Actor-Critic (SAC) baseline. You can easily modify your SAC code to get REDQ: (a) use an update-to-data (UTD) ratio > 1, (b) use more than 2 Q networks, (c) when computing the Q target, randomly select a subset of the Q target networks and take their min. A minimal sketch of these changes is shown below.
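
Below is a minimal PyTorch sketch of how those three changes fit together (this is not the repo's exact code; the networks, the fake batch tensors, and the hyperparameter values are illustrative assumptions):

```
import torch
import torch.nn as nn

N, M, G = 10, 2, 20                       # ensemble size, target-subset size, UTD ratio
obs_dim, act_dim, batch = 11, 3, 256
gamma, alpha = 0.99, 0.2

def make_q():
    return nn.Sequential(nn.Linear(obs_dim + act_dim, 256), nn.ReLU(), nn.Linear(256, 1))

q_nets = [make_q() for _ in range(N)]
q_targs = [make_q() for _ in range(N)]    # in practice: copies of q_nets, updated by Polyak averaging
optims = [torch.optim.Adam(q.parameters(), lr=3e-4) for q in q_nets]

# fake batch standing in for a replay-buffer sample (the real code samples a fresh batch per update)
o, a = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)
r, d = torch.randn(batch, 1), torch.zeros(batch, 1)
o2 = torch.randn(batch, obs_dim)
a2, logp_a2 = torch.randn(batch, act_dim), torch.randn(batch, 1)  # next action + log-prob from the policy

for _ in range(G):                                     # (a) UTD ratio: G critic updates per environment step
    idx = torch.randperm(N)[:M].tolist()               # (c) randomly pick M of the N target networks
    with torch.no_grad():
        q_t = torch.cat([q_targs[i](torch.cat([o2, a2], dim=1)) for i in idx], dim=1)
        min_q = q_t.min(dim=1, keepdim=True).values    # (c) take their min for the shared target
        backup = r + gamma * (1 - d) * (min_q - alpha * logp_a2)
    for q_net, opt in zip(q_nets, optims):             # (b) every critic in the ensemble regresses to the target
        loss = ((q_net(torch.cat([o, a], dim=1)) - backup) ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
```

In the actual repo, these pieces are implemented inside the agent's `train()` function in `redq/algos/redq_sac.py`.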
124 |
125 | If you intend to implement REDQ in your own codebase, please refer to the paper and the [video tutorial](#video-tutorial) for guidance. In particular, Appendix B of the paper discusses hyperparameters and some additional implementation details. One important detail: at the beginning of training, for the first 5000 data points, we sample random actions from the action space and do not perform any updates. Performing a large number of updates with a very small amount of data can lead to severe bias accumulation and can negatively affect performance.
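
As a small illustration (the function names and constants here are hypothetical, not the repo's API), the warm-up rule amounts to something like:

```
import random

START_STEPS = 5000   # initial transitions collected with random actions and no gradient updates

def exploration_action(t, action_low, action_high, policy_action=None):
    # before START_STEPS transitions are collected, ignore the policy and sample uniformly
    if t < START_STEPS or policy_action is None:
        return [random.uniform(lo, hi) for lo, hi in zip(action_low, action_high)]
    return policy_action

def num_updates(t, utd_ratio=20):
    # no updates until enough data is collected, then utd_ratio updates per environment step
    return 0 if t < START_STEPS else utd_ratio

print(exploration_action(10, [-1.0], [1.0]), num_updates(10))   # random action, 0 updates
print(num_updates(10000))                                       # 20 updates per step after warm-up
```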
126 |
127 | For REDQ-OFE, as mentioned in the paper, for some reason adding PyTorch batch norm to OFENet leads to divergence, so in the end we did not use batch norm in our code.
128 |
129 |
130 |
131 | ## Reproduce the results
132 | If you use a different PyTorch version, it might still work; however, it is better if your version is close to the ones we used. We have found, for example, that on the Ant environment, PyTorch 1.3 and 1.2 give quite different results. The reason is not entirely clear.
133 |
134 | Other factors such as the versions of other packages (for example, numpy), the environment versions (mujoco/gym), or even the type of hardware (CPU/GPU) can also affect the final results, so reproducing exactly the same results can be difficult. However, if the package versions are the same, when averaged over a large number of random seeds, the overall performance should be similar to that reported in the paper.
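
If you want to record exactly what you are running against, a quick snippet (hypothetical helper, not part of the repo) to log the relevant package versions is:

```
import torch, numpy, gym

print("torch:", torch.__version__)
print("numpy:", numpy.__version__)
print("gym:", gym.__version__)
print("cuda available:", torch.cuda.is_available())
```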
135 |
136 | As of Mar. 29, 2021, we have used the installation guide on this page to set up a fresh conda environment and run the code hosted in this repo, and the reproduced results are similar to those in the paper (though not exactly the same: in some environments performance is a bit stronger, in others a bit weaker).
137 |
138 | Please open an issue if you find any problems in the code, thanks!
139 |
140 |
141 |
142 | ## Environment setup with MuJoCo and OpenAI Gym v2/v4 tasks, with Docker or Singularity
143 | This is a new 2023 guide based on Docker and Singularity (currently undergoing more testing).
144 |
145 | Local setup: simply build a docker image with the Dockerfile provided in this repo (either the v2 or the v4 version, depending on your needs). It specifies how to install all dependencies starting from an Ubuntu 18.04 system, and you can easily modify it to your needs.
146 |
147 | To get things running quickly (with the v2 gym-mujoco environments), simply pull this image from my Docker Hub: `docker pull cwatcherw/gym-mujocov2:1.0`
148 |
149 | After you pull the docker image, you can quickly test it:
150 | ```
151 | docker run -it --rm cwatcherw/gym-mujocov2:1.0
152 | ```
153 | Once you are inside the container, run:
154 |
155 | ```
156 | cd /workspace/REDQ/experiments/
157 | python train_redq_sac.py
158 | ```
159 |
160 | (Alternatively, remove the `--rm` flag so the container is kept after shutting down, or add `--gpus all` to use a GPU.)
161 |
162 | If you want to modify the REDQ codebase to test new ideas, you can clone (a fork of) the REDQ repo to a local directory, and then mount it to `/workspace/REDQ`. For example:
163 |
164 | ```
165 | docker run -it --rm --mount type=bind,source=$(pwd)/REDQ,target=/workspace/REDQ cwatcherw/gym-mujocov2:1.0
166 | ```
167 |
168 | Example setup if you want to run on a Slurm HPC with singularity (you might need to make changes, depending on your HPC settings):
169 |
170 | First time setup:
171 | ```
172 | mkdir /scratch/$USER/.sing_cache
173 | export SINGULARITY_CACHEDIR=/scratch/$USER/.sing_cache
174 | echo "export SINGULARITY_CACHEDIR=/scratch/$USER/.sing_cache" >> ~/.bashrc
175 | mkdir /scratch/$USER/sing
176 | cd /scratch/$USER/sing
177 | git clone https://github.com/watchernyu/REDQ.git
178 | ```
179 |
180 | Build singularity and run singularity:
181 | ```
182 | module load singularity
183 | cd /scratch/$USER/sing/
184 | singularity build --sandbox mujoco-sandbox docker://cwatcherw/gym-mujocov2:1.0
185 | singularity exec -B /scratch/$USER/sing/REDQ:/workspace/REDQ -B /scratch/$USER/sing/mujoco-sandbox/opt/conda/lib/python3.8/site-packages/mujoco_py/:/opt/conda/lib/python3.8/site-packages/mujoco_py/ /scratch/$USER/sing/mujoco-sandbox bash
186 |
187 | singularity exec -B REDQ/:/workspace/REDQ/ mujoco-sandbox bash
188 | ```
189 |
190 | Don't forget to use the following command to tell MuJoCo to use EGL headless rendering if you need to do rendering or use visual input (this is required when you run on the HPC or other headless machines):
191 | ```
192 | export MUJOCO_GL=egl
193 | ```
194 |
195 | Sample command to open an interactive session for debugging:
196 | ```
197 | srun -p parallel --pty --mem 12000 -t 0-05:00 bash
198 | ```
199 |
200 |
201 |
202 |
203 | ## Environment setup with newest MuJoCo and OpenAI Gym V4 tasks, on the NYU Shanghai hpc cluster (system is CentOS Linux release 7.4.1708, hpc management is Slurm)
204 | This is the newest guide (Summer 2022); it helps you set up for the Gym v4 MuJoCo tasks, which are much easier to set up than previous versions. If you have a limited CS background, make sure you don't run extra commands in between these steps.
205 |
206 | 1. To avoid storage space issues, we will work under the scratch partition (change `NETID` to your netid). We first clone the REDQ repo and `cd` into the REDQ folder.
207 | ```
208 | cd /scratch/NETID/
209 | git clone https://github.com/watchernyu/REDQ.git
210 | cd REDQ
211 | ```
212 |
213 | 2. Download the MuJoCo files by running the provided script.
214 | ```
215 | bash mujoco_download.sh
216 | ```
217 |
218 | 3. Load the anaconda module (since we are on the HPC, we use `module load` to get certain software instead of installing it ourselves), then create and activate a conda environment using the provided yaml file. (Again, don't forget to change the `NETID` part; the second line just avoids an overly long prompt in the terminal.)
219 | ```
220 | module load anaconda3
221 | conda config --set env_prompt '({name})'
222 | conda env create -f conda_env.yml --prefix /scratch/NETID/redq_env
223 | conda activate /scratch/NETID/redq_env
224 | ```
225 |
226 | 4. Install REDQ:
227 | ```
228 | pip install -e .
229 | ```
230 |
231 | 5. Run a test script, but make sure you have filled in your actual netid in `experiments/sample_hpc_script.sh` so that the script can correctly locate your conda environment.
232 | ```
233 | cd experiments
234 | sbatch sample_hpc_script.sh
235 | ```
236 |
237 |
238 |
239 | ## Environment setup with newest MuJoCo 2.1, on an Ubuntu 18.04 local machine
240 | First download the MuJoCo files; on a Linux machine, we put them under `~/.mujoco`:
241 | ```
242 | cd ~
243 | mkdir ~/.mujoco
244 | cd ~/.mujoco
245 | curl -O https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
246 | tar -xf mujoco210-linux-x86_64.tar.gz
247 | ```
248 |
249 | Now we create a conda environment (you will need Anaconda) and install PyTorch (if you just want mujoco + gym and don't want PyTorch, skip the `conda install pytorch` line):
250 | ```
251 | conda create -y -n redq python=3.8
252 | conda activate redq
253 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
254 | ```
255 |
256 | Now install mujoco_py:
257 | ```
258 | cd ~
259 | mkdir rl_course
260 | cd rl_course
261 | git clone https://github.com/openai/mujoco-py.git
262 | cd mujoco-py/
263 | pip install -e . --no-cache
264 | ```
265 | A list of system packages that need to be installed on Linux is here: https://github.com/openai/mujoco-py/blob/master/Dockerfile
266 |
267 | Now test by running `python` and then `import mujoco_py`. Typically you will run into some error message; check that Dockerfile to see if you are missing any of the required packages (either Python packages or system packages).
268 |
269 | If mujoco works, then install REDQ:
270 |
271 | ```
272 | cd ~
273 | cd rl_course
274 | git clone https://github.com/watchernyu/REDQ.git
275 | cd REDQ/
276 | pip install -e .
277 | ```
278 |
279 | Now test REDQ by running:
280 | ```
281 | python experiments/train_redq_sac.py --debug
282 | ```
283 | If you see training logs, then the environment is set up correctly!
284 |
285 |
286 |
287 |
288 | ## Environment setup with newest MuJoCo 2.1, on the NYU Shanghai hpc cluster (system is Linux, hpc management is Slurm)
289 | This guide helps you set up MuJoCo, then OpenAI Gym, and then REDQ. (You can also follow it if you just want OpenAI Gym + MuJoCo and not REDQ; REDQ is only the last step.) This likely also works for the NYU NY HPC cluster, and might also work for HPC clusters at other schools, assuming your HPC runs Linux and uses Slurm.
290 |
291 | ### conda init
292 | First we need to login to the hpc.
293 | ```
294 | ssh netid@hpc.shanghai.nyu.edu
295 | ```
296 | After this you should be on the login node of the HPC. Note that the login node is different from a compute node: we set up the environment on the login node, but when we submit actual jobs, they actually run on compute nodes.
297 |
298 | Note that on the HPC, students typically don't have admin privileges (which means you cannot install things that require `sudo`), so for some of the required system packages, we will not use the typical `sudo apt install` command; instead, we will use `module avail` to check whether they are available, and then use `module load` to load them. If a system package is not available on your HPC cluster, check with your HPC admin and ask them to help you.
299 |
300 | On the NYU Shanghai HPC (after you ssh, you get to the login node), we first want to set up conda correctly (new accounts typically need to do this):
301 | ```
302 | module load anaconda3
303 | conda init bash
304 | ```
305 | Now use Ctrl + D to logout, and then login again.
306 |
307 | ### set up MuJoCo
308 |
309 | Now that we are again on the HPC login node, simply run this to load all required modules:
310 | ```
311 | module load anaconda3 cuda/11.3.1
312 | ```
313 |
314 | Now download the MuJoCo files; on a Linux machine, we put them under `~/.mujoco`:
315 | ```
316 | cd ~
317 | mkdir ~/.mujoco
318 | cd ~/.mujoco
319 | curl -O https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
320 | tar -xf mujoco210-linux-x86_64.tar.gz
321 | ```
322 |
323 | ### set up conda environment
324 | Then set up a conda environment and activate it (you can give it a different name; the env name does not matter):
325 | ```
326 | conda create -y -n redq python=3.8
327 | conda activate redq
328 | ```
329 |
330 | Install PyTorch (skip this step if you don't need PyTorch):
331 | ```
332 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
333 | ```
334 |
335 | After you have installed PyTorch, check that it works by running `python`, then `import torch` in the Python interpreter.
336 | If PyTorch works, run `quit()` or press Ctrl + D to exit the Python interpreter.
337 |
338 | ### set up `mujoco_py`
339 |
340 | The next step is to install `mujoco_py`; let's first run these commands:
341 | ```
342 | echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc
343 | echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia' >> ~/.bashrc
344 | ```
345 | (it might be easier to do this step here. However, you can also set this up later when you test `import mujoco_py`)
346 |
347 | Now we want the `.bashrc` file to take effect: press Ctrl + D to log out, then log in again. After logging back in, we need to reload all the modules and activate the conda env again (each time you log in, the loaded modules and activated environment are reset, but files on disk persist). After you log in to the HPC again:
348 | ```
349 | module load anaconda3 cuda/11.3.1
350 | conda activate redq
351 | ```
352 |
353 | Now we can install `mujoco_py` (why do we need this? We already have the MuJoCo files, but they do not work directly with Python, so `mujoco_py` is needed to use MuJoCo from Python):
354 | ```
355 | cd ~
356 | mkdir rl_course
357 | cd rl_course
358 | git clone https://github.com/openai/mujoco-py.git
359 | cd mujoco-py/
360 | pip install -e . --no-cache
361 | ```
362 |
363 | Somewhere during installation, you might find that some libraries are missing (an error message might show up saying something is missing, for example `GL/glew.h: No such file or directory`). On the HPC, since we are not superusers, we cannot install system libraries, but we can use `module load` to load them. This command works on some machines: `module load glfw/3.3 gcc/7.3 mesa/19.0.5 llvm/7.0.1`. You can try it and then see if the error goes away.
364 |
365 | Now we want to test whether it works: run `python`, and then in the Python interpreter run `import mujoco_py`. The first time you run it, it will print a ton of logs; if you can run it again and get no logs, then it should be working. (Summary: try running `import mujoco_py` twice; if the second time you get no logs and no error message, it should be working.) After this test is complete, quit the Python interpreter with either `quit()` or Ctrl + D.
366 |
367 | Note: if you got an error when importing `mujoco_py`, sometimes the error will tell you to add some `export` line to a file (typically `~/.bashrc`) on your system. If you see that, it is likely because you are installing on a system configured slightly differently from the NYU Shanghai HPC cluster; in that case, just follow the error message, then log out, log in, and test again. Check https://github.com/openai/mujoco-py for more info.
368 |
369 | ### set up gym
370 |
371 | Now we install OpenAI gym:
372 | `pip install gym`
373 |
374 | After this step, test gym by again running `python`; in the Python interpreter, run:
375 | ```
376 | import gym
377 | e = gym.make('Ant-v2')
378 | e.reset()
379 | ```
380 | If you see a large numpy array (which is the initial state, or initial observation for the Ant-v2 environment), then gym is working.
381 |
382 | ### set up REDQ
383 | After this step, you can install REDQ.
384 | ```
385 | cd ~
386 | cd rl_course
387 | git clone https://github.com/watchernyu/REDQ.git
388 | cd REDQ/
389 | pip install -e .
390 | ```
391 |
392 | ### test on login node (sometimes things work on the login node but not on a compute node, so we test the login node first)
393 | ```
394 | cd ~/rl_course/REDQ
395 | python experiments/train_redq_sac.py --debug
396 | ```
397 |
398 | ### test on compute node (will be updated soon)
399 | Now we test whether REDQ runs. We will first log in to an interactive compute node (note that a login node is not a compute node; don't do intensive computation on the login node):
400 | ```
401 | srun -p aquila --pty --mem 5000 -t 0-05:00 bash
402 | ```
403 | Now, don't forget that we are on a new node and need to load the modules and activate the conda env again:
404 | ```
405 | module load anaconda3 cuda/11.3.1
406 | conda deactivate
407 | conda activate redq
408 | ```
409 | Now test redq algorithm:
410 | ```
411 | cd ~/rl_course/REDQ
412 | python experiments/train_redq_sac.py --debug
413 | ```
414 |
415 |
416 | ## other HPC issues
417 | ### missing patchelf
418 | Simply install patchelf. Make sure your conda environment is activated, then try this command: `conda install -c anaconda patchelf`.
419 |
420 | If you see warnings or a message telling you to update conda, ignore it. If it asks whether you want to install, choose yes and wait for the installation to finish.
421 |
422 | ### quota exceeded
423 | If your home quota is exceeded, you can contact the current HPC admin to extend your quota. Alternatively, you can install all required packages under `/scratch`, which has plenty of space (but your data under scratch will be removed if you don't use your account for too long). You might need more Python skills to do this correctly.
424 |
425 |
426 |
427 | ## Previous updates
428 |
429 | June 23, 2022: added a guide for setting up OpenAI Gym MuJoCo v4 tasks on Slurm HPCs (not fully tested yet). Currently it seems this newer version is much easier to set up compared to previous ones.
430 |
431 | Nov 14, 2021: **MuJoCo** is now free (thanks DeepMind!) and we now have a guide on setting up with MuJoCo 2.1 + OpenAI Gym + REDQ on a linux machine (see end of this page for newest setup guide).
432 |
433 | Aug 18, 2021: **VERY IMPORTANT BUG FIX** in `experiments/train_redq_sac.py`: the done signal was not being used correctly. The done value should be `False` when the episode terminates due to the environment time limit, but in the earlier version of the code,
434 | the agent put the transition in the buffer before this value was corrected. This can affect performance, especially for environments where termination due to a bad action is rare. This is now fixed and we might do some more testing. If you use this file to run experiments, **please check immediately or pull the latest version** of the code.
435 | Sorry for the bug! Please don't hesitate to open an issue if you have any questions.
436 |
437 | July 2021: data and the functions to reproduce all figures in the paper are now available; see the `Data and reproducing figures in REDQ` section for details.
438 |
439 | Mar 23, 2021: We have reorganized the code to make it cleaner and more readable and the first version is now released!
440 |
441 | Mar 29, 2021: We tested the installation process and ran the code, and everything seems to be working correctly. We are now working on the implementation video tutorial, which will be released soon.
442 |
443 | May 3, 2021: We uploaded a video tutorial (shared via Google Drive); please see the link in the [video tutorial](#video-tutorial) section. Hope it helps!
444 |
445 | Code for REDQ-OFE is still being cleaned up and will be released soon (essentially the same code but with additional input from an OFENet).
446 |
447 |
448 |
449 | ## Acknowledgement
450 |
451 | Our code for REDQ-SAC is partly based on the SAC implementation in OpenAI Spinup (https://github.com/openai/spinningup). The current code structure is inspired by the super clean TD3 source code by Scott Fujimoto (https://github.com/sfujim/TD3).
452 |
453 |
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: redq
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.8
6 | - pip=21.2.4
7 | - pytorch::pytorch=1.11.0
8 | - pytorch::torchvision=0.12.0
9 | - nvidia::cudatoolkit=11.3.1
10 | - joblib=1.1.0
11 | - pip:
12 | - gym[mujoco,robotics]==0.24.1
13 |
--------------------------------------------------------------------------------
/docker-gym-mujocov2/Dockerfile:
--------------------------------------------------------------------------------
1 | # cudagl with miniconda and python 3.8, pytorch, mujoco and gym v2 envs, REDQ
2 |
3 | FROM nvidia/cudagl:11.0-base-ubuntu18.04
4 | WORKDIR /workspace
5 | ENV HOME=/workspace
6 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
7 | ENV PATH /opt/conda/bin:$PATH
8 |
9 | # idea: start with a nvidia docker with gl support (guess this one also has cuda?)
10 | # then install miniconda, borrowing docker command from miniconda's Dockerfile (https://hub.docker.com/r/continuumio/anaconda/dockerfile/)
11 | # need to make sure the miniconda python version is what we need (https://docs.conda.io/en/latest/miniconda.html for the right version)
12 | # then install other dependencies we need
13 |
14 | # nvidia GPG key alternative fix (https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772)
15 | # sudo apt-key del 7fa2af80
16 | # wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.0-1_all.deb
17 | # sudo dpkg -i cuda-keyring_1.0-1_all.deb
18 |
19 | RUN \
20 | # Update nvidia GPG key
21 | rm /etc/apt/sources.list.d/cuda.list \
22 | && rm /etc/apt/sources.list.d/nvidia-ml.list \
23 | && apt-key del 7fa2af80 \
24 | && apt-get update && apt-get install -y --no-install-recommends wget \
25 | && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb \
26 | && dpkg -i cuda-keyring_1.0-1_all.deb \
27 | && apt-get update
28 |
29 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
30 | libglib2.0-0 libxext6 libsm6 libxrender1 \
31 | git mercurial subversion
32 |
33 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh -O ~/miniconda.sh \
34 | && /bin/bash ~/miniconda.sh -b -p /opt/conda \
35 | && rm ~/miniconda.sh \
36 | && ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh \
37 | && echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc \
38 | && echo "conda activate base" >> ~/.bashrc
39 |
40 | RUN apt-get install -y curl grep sed dpkg \
41 | # TINI_VERSION=`curl https://github.com/krallin/tini/releases/latest | grep -o "/v.*\"" | sed 's:^..\(.*\).$:\1:'` && \
42 | # curl -L "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \
43 | # dpkg -i tini.deb && \
44 | # rm tini.deb && \
45 | && apt-get clean
46 |
47 | # Install some basic utilities
48 | RUN apt-get update && apt-get install -y software-properties-common \
49 | && add-apt-repository -y ppa:redislabs/redis && apt-get update \
50 | && apt-get install -y sudo ssh libx11-6 gcc iputils-ping \
51 | libxrender-dev graphviz tmux htop build-essential wget cmake libgl1-mesa-glx redis \
52 | && rm -rf /var/lib/apt/lists/*
53 |
54 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \
55 | && apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0
56 | # && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so
57 | # ---------- now we should have all major dependencies ------------
58 |
59 | # --------- now we have cudagl + python38 ---------
60 | RUN pip install --no-cache-dir torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
61 | # pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
62 |
63 | RUN pip install --no-cache-dir scikit-learn pandas imageio
64 |
65 | # adroit env does not work with newer version of mujoco, so we have to use the old version...
66 | # setting MUJOCO_GL to egl is needed if run on headless machine
67 | ENV MUJOCO_GL=egl
68 | ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/workspace/.mujoco/mujoco210/bin
69 | RUN mkdir -p /workspace/.mujoco \
70 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco210.tar.gz \
71 | && tar -xvzf mujoco210.tar.gz -C /workspace/.mujoco \
72 | && rm mujoco210.tar.gz
73 |
74 | # RL env: get mujoco and gym (v2 version environments)
75 | RUN pip install --no-cache-dir patchelf==0.17.2.0 mujoco-py==2.1.2.14 gym==0.21.0
76 |
77 | # first time import mujoco_py will take some time to run, do it here to avoid the wait later
78 | RUN echo "import mujoco_py" >> /workspace/import_mujoco.py \
79 | && python /workspace/import_mujoco.py
80 |
81 | # RL algorithm: install REDQ. If you don't want to use REDQ, simply remove these lines
82 | RUN cd /workspace/ \
83 | && git clone https://github.com/watchernyu/REDQ.git \
84 | && cd REDQ \
85 | && git checkout ac840198f143d10bb22425ed2105a49d01b383fa \
86 | && pip install -e .
87 |
88 | CMD [ "/bin/bash" ]
89 |
90 | # # build docker container:
91 | # docker build . -t name:tag
92 |
93 | # # example docker command to run interactive container, enable gpu, and remove it when shutdown:
94 | # docker run -it --rm --gpus all name:tag
95 |
--------------------------------------------------------------------------------
/docker-gym-mujocov2/README.md:
--------------------------------------------------------------------------------
1 | This Dockerfile contains commands to set up the older gym mujoco v2 environments. The results in the REDQ paper were run with the v2 environments. If you use this Dockerfile, you will be able to run the REDQ code right away.
--------------------------------------------------------------------------------
/docker-gym-mujocov4/Dockerfile:
--------------------------------------------------------------------------------
1 | # cudagl with miniconda and python 3.8, pytorch, mujoco and gym v4 envs
2 |
3 | FROM nvidia/cudagl:11.0-base-ubuntu18.04
4 | WORKDIR /workspace
5 | ENV HOME=/workspace
6 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
7 | ENV PATH /opt/conda/bin:$PATH
8 |
9 | # idea: start with a nvidia docker with gl support (guess this one also has cuda?)
10 | # then install miniconda, borrowing docker command from miniconda's Dockerfile (https://hub.docker.com/r/continuumio/anaconda/dockerfile/)
11 | # need to make sure the miniconda python version is what we need (https://docs.conda.io/en/latest/miniconda.html for the right version)
12 | # then install other dependencies we need
13 |
14 | # nvidia GPG key alternative fix (https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772)
15 | # sudo apt-key del 7fa2af80
16 | # wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.0-1_all.deb
17 | # sudo dpkg -i cuda-keyring_1.0-1_all.deb
18 |
19 | RUN \
20 | # Update nvidia GPG key
21 | rm /etc/apt/sources.list.d/cuda.list && \
22 | rm /etc/apt/sources.list.d/nvidia-ml.list && \
23 | apt-key del 7fa2af80 && \
24 | apt-get update && apt-get install -y --no-install-recommends wget && \
25 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb && \
26 | dpkg -i cuda-keyring_1.0-1_all.deb && \
27 | apt-get update
28 |
29 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
30 | libglib2.0-0 libxext6 libsm6 libxrender1 \
31 | git mercurial subversion
32 |
33 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh -O ~/miniconda.sh && \
34 | /bin/bash ~/miniconda.sh -b -p /opt/conda && \
35 | rm ~/miniconda.sh && \
36 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
37 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
38 | echo "conda activate base" >> ~/.bashrc
39 |
40 | RUN apt-get install -y curl grep sed dpkg && \
41 | # TINI_VERSION=`curl https://github.com/krallin/tini/releases/latest | grep -o "/v.*\"" | sed 's:^..\(.*\).$:\1:'` && \
42 | # curl -L "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \
43 | # dpkg -i tini.deb && \
44 | # rm tini.deb && \
45 | apt-get clean
46 |
47 | # Install some basic utilities
48 | RUN apt-get update && apt-get install -y software-properties-common && \
49 | add-apt-repository -y ppa:redislabs/redis && apt-get update && \
50 | apt-get install -y sudo ssh libx11-6 gcc iputils-ping \
51 | libxrender-dev graphviz tmux htop build-essential wget cmake libgl1-mesa-glx redis && \
52 | rm -rf /var/lib/apt/lists/*
53 |
54 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \
55 | && apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0
56 | # && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so
57 | # ---------- now we should have all major dependencies ------------
58 |
59 | # --------- now we have cudagl + python38 ---------
60 | RUN pip install --no-cache-dir torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
61 | # pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
62 |
63 | RUN pip install --no-cache-dir scikit-learn pandas imageio
64 |
65 | # RL env: get mujoco and gym
66 | RUN pip install --no-cache-dir mujoco==2.2.2 gym==0.26.2
67 | # change to headless rendering (if you want rendering on a machine's monitor see https://github.com/deepmind/dm_control)
68 | ENV MUJOCO_GL=egl
69 | # import mujoco now to avoid long waiting time at runtime
70 | RUN echo "import mujoco_py" >> /workspace/import_mujoco.py \
71 | && python /workspace/import_mujoco.py
72 |
73 | # RL algorithm: install REDQ
74 | RUN cd /workspace/ \
75 | && git clone https://github.com/watchernyu/REDQ.git \
76 | && cd REDQ \
77 | # && git checkout ac840198f143d10bb22425ed2105a49d01b383fa \
78 | && pip install -e .
79 |
80 | CMD [ "/bin/bash" ]
81 |
82 | # example docker command to run interactive container, enable gpu, remove when shutdown:
83 | # docker run -it --rm --gpus all name:tag
84 |
85 |
--------------------------------------------------------------------------------
/docker-gym-mujocov4/README.md:
--------------------------------------------------------------------------------
1 | This Dockerfile contains commands to set up a newer version of mujoco and gym (the gym mujoco v4 environments). However, many things in gym have changed from v2 to v4 (e.g. the environments now return more outputs, and seeding has been significantly modified). The v4 environments are easier to set up and work with, but if you are using a codebase designed for the v2 environments, you will need to modify the code to work with the new v4 environments.
2 |
--------------------------------------------------------------------------------
/experiments/grid_sample_hpc_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #SBATCH --verbose
3 | #SBATCH -p aquila,parallel
4 | #SBATCH --time=168:00:00
5 | #SBATCH --nodes=1
6 | #SBATCH --mem=12GB
7 | #SBATCH --mail-type=ALL # select which email types will be sent
8 | #SBATCH --mail-user=NETID@nyu.edu # NOTE: put your netid here if you want emails
9 |
10 | #SBATCH --array=0-5 # here the number depends on number of tasks in the array, e.g. 0-59 will create 60 tasks
11 | #SBATCH --output=r_%A_%a.out # %A is SLURM_ARRAY_JOB_ID, %a is SLURM_ARRAY_TASK_ID
12 | #SBATCH --error=r_%A_%a.err
13 |
14 | # #####################################################
15 | # #SBATCH --gres=gpu:1 # uncomment this line to request a gpu
16 | #SBATCH --constraint=cpu # specify constraint features ('cpu' means only use nodes that have the 'cpu' feature) check features with the showcluster command
17 |
18 | sleep $(( (RANDOM%10) + 1 )) # to avoid issues when submitting large amounts of jobs
19 |
20 | echo "SLURM_JOBID: " $SLURM_JOBID
21 | echo "SLURM_ARRAY_JOB_ID: " $SLURM_ARRAY_JOB_ID
22 | echo "SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
23 |
24 | module load anaconda3 cuda/9.0 glfw/3.3 gcc/7.3 mesa/19.0.5 llvm/7.0.1 # load modules that we might use
25 | conda init bash # this and the following line are important to avoid hpc issues
26 | source ~/.bashrc
27 | conda activate /scratch/NETID/redq_env # NOTE: remember to change to your actual netid
28 |
29 | echo ${SLURM_ARRAY_TASK_ID}
30 | python grid_train_redq_sac.py --setting ${SLURM_ARRAY_TASK_ID}
--------------------------------------------------------------------------------
/experiments/grid_train_redq_sac.py:
--------------------------------------------------------------------------------
1 | from train_redq_sac import redq_sac as function_to_run ## here make sure you import correct function
2 | import time
3 | from redq.utils.run_utils import setup_logger_kwargs
4 | from grid_utils import get_setting_and_exp_name
5 |
6 | if __name__ == '__main__':
7 | import argparse
8 | start_time = time.time()
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--setting', type=int, default=0)
11 | args = parser.parse_args()
12 | data_dir = '../data'
13 |
14 | exp_prefix = 'trial'
15 | settings = ['env_name','',['Hopper-v4', 'Ant-v4'],
16 | 'seed','',[0, 1, 2],
17 | 'epochs','e',[20],
18 | 'num_Q','q',[2],
19 | 'utd_ratio','uf',[1],
20 | 'policy_update_delay','pd',[20],
21 | ]
22 |
23 | indexes, actual_setting, total, exp_name_full = get_setting_and_exp_name(settings, args.setting, exp_prefix)
24 | print("##### TOTAL NUMBER OF VARIANTS: %d #####" % total)
25 |
26 | logger_kwargs = setup_logger_kwargs(exp_name_full, actual_setting['seed'], data_dir)
27 | function_to_run(logger_kwargs=logger_kwargs, **actual_setting)
28 | print("Total time used: %.3f hours." % ((time.time() - start_time)/3600))
29 |
--------------------------------------------------------------------------------
/experiments/grid_utils.py:
--------------------------------------------------------------------------------
1 | from train_redq_sac import redq_sac as function_to_run ## here make sure you import correct function
2 | import time
3 | import numpy as np
4 | from redq.utils.run_utils import setup_logger_kwargs
5 |
6 | def get_setting_and_exp_name(settings, setting_number, exp_prefix, random_setting_seed=0, random_order=True):
7 | np.random.seed(random_setting_seed)
8 | hypers, lognames, values_list = [], [], []
9 | hyper2logname = {}
10 | n_settings = int(len(settings)/3)
11 | for i in range(n_settings):
12 | hypers.append(settings[i*3])
13 | lognames.append(settings[i*3+1])
14 | values_list.append(settings[i*3+2])
15 | hyper2logname[hypers[-1]] = lognames[-1]
16 |
17 | total = 1
18 | for values in values_list:
19 | total *= len(values)
20 | max_job = total
21 |
22 | new_indexes = np.random.choice(total, total, replace=False) if random_order else np.arange(total)
23 | new_index = new_indexes[setting_number]
24 |
25 | indexes = [] ## this says which hyperparameter we use
26 | remainder = new_index
27 | for values in values_list:
28 | division = int(total / len(values))
29 | index = int(remainder / division)
30 | remainder = remainder % division
31 | indexes.append(index)
32 | total = division
33 | actual_setting = {}
34 | for j in range(len(indexes)):
35 | actual_setting[hypers[j]] = values_list[j][indexes[j]]
36 |
37 | exp_name_full = exp_prefix
38 | for hyper, value in actual_setting.items():
39 | if hyper not in ['env_name', 'seed']:
40 | exp_name_full = exp_name_full + '_%s' % (hyper2logname[hyper] + str(value))
41 | exp_name_full = exp_name_full + '_%s' % actual_setting['env_name']
42 |
43 | return indexes, actual_setting, max_job, exp_name_full
44 |
45 |
--------------------------------------------------------------------------------
/experiments/sample_hpc_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #SBATCH --verbose
3 | #SBATCH -p aquila,parallel
4 | #SBATCH --time=168:00:00
5 | #SBATCH --nodes=1
6 | #SBATCH --mem=12GB
7 | #SBATCH --mail-type=ALL # select which email types will be sent
8 | #SBATCH --mail-user=NETID@nyu.edu # NOTE: put your netid here if you want emails
9 |
10 | #SBATCH --array=0-0 # here the number depends on number of tasks in the array, e.g. 0-59 will create 60 tasks
11 | #SBATCH --output=r_%A_%a.out # %A is SLURM_ARRAY_JOB_ID, %a is SLURM_ARRAY_TASK_ID
12 | #SBATCH --error=r_%A_%a.err
13 |
14 | # #####################################################
15 | # #SBATCH --gres=gpu:1 # uncomment this line to request a gpu
16 | #SBATCH --constraint=cpu # specify constraint features ('cpu' means only use nodes that have the 'cpu' feature) check features with the showcluster command
17 |
18 | sleep $(( (RANDOM%10) + 1 )) # to avoid issues when submitting large amounts of jobs
19 |
20 | echo "SLURM_JOBID: " $SLURM_JOBID
21 | echo "SLURM_ARRAY_JOB_ID: " $SLURM_ARRAY_JOB_ID
22 | echo "SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
23 |
24 | module load anaconda3 cuda/9.0 glfw/3.3 gcc/7.3 mesa/19.0.5 llvm/7.0.1 # load modules that we might use
25 | conda init bash # this and the following line are important to avoid hpc issues
26 | source ~/.bashrc
27 | conda activate /scratch/NETID/redq_env # NOTE: remember to change to your actual netid
28 |
29 | echo ${SLURM_ARRAY_TASK_ID}
30 | python train_redq_sac.py --debug --epochs 20 --env Hopper-v4 # use v4 version tasks if you followed the newest Gym+MuJoCo installation guide
--------------------------------------------------------------------------------
/experiments/train_redq_sac.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import torch
4 | import time
5 | import sys
6 | from redq.algos.redq_sac import REDQSACAgent
7 | from redq.algos.core import mbpo_epoches, test_agent
8 | from redq.utils.run_utils import setup_logger_kwargs
9 | from redq.utils.bias_utils import log_bias_evaluation
10 | from redq.utils.logx import EpochLogger
11 |
12 | def redq_sac(env_name, seed=0, epochs='mbpo', steps_per_epoch=1000,
13 | max_ep_len=1000, n_evals_per_epoch=1,
14 | logger_kwargs=dict(), debug=False,
15 | # following are agent related hyperparameters
16 | hidden_sizes=(256, 256), replay_size=int(1e6), batch_size=256,
17 | lr=3e-4, gamma=0.99, polyak=0.995,
18 | alpha=0.2, auto_alpha=True, target_entropy='mbpo',
19 | start_steps=5000, delay_update_steps='auto',
20 | utd_ratio=20, num_Q=10, num_min=2, q_target_mode='min',
21 | policy_update_delay=20,
22 | # following are bias evaluation related
23 | evaluate_bias=True, n_mc_eval=1000, n_mc_cutoff=350, reseed_each_epoch=True
24 | ):
25 | """
26 | :param env_name: name of the gym environment
27 | :param seed: random seed
28 | :param epochs: number of epochs to run
29 | :param steps_per_epoch: number of timestep (datapoints) for each epoch
30 | :param max_ep_len: max timestep until an episode terminates
31 | :param n_evals_per_epoch: number of evaluation runs for each epoch
32 | :param logger_kwargs: arguments for logger
33 | :param debug: whether to run in debug mode
34 | :param hidden_sizes: hidden layer sizes
35 | :param replay_size: replay buffer size
36 | :param batch_size: mini-batch size
37 | :param lr: learning rate for all networks
38 | :param gamma: discount factor
39 | :param polyak: hyperparameter for polyak averaged target networks
40 | :param alpha: SAC entropy hyperparameter
41 | :param auto_alpha: whether to use adaptive SAC
42 | :param target_entropy: used for adaptive SAC
43 | :param start_steps: the number of random data collected in the beginning of training
44 | :param delay_update_steps: after how many data collected should we start updates
45 | :param utd_ratio: the update-to-data ratio
46 | :param num_Q: number of Q networks in the Q ensemble
47 | :param num_min: number of sampled Q values to take minimal from
48 | :param q_target_mode: 'min' for minimal, 'ave' for average, 'rem' for random ensemble mixture
49 | :param policy_update_delay: how many updates until we update policy network
50 | """
51 | if debug: # use --debug for very quick debugging
52 | hidden_sizes = [2,2]
53 | batch_size = 2
54 | utd_ratio = 2
55 | num_Q = 3
56 | max_ep_len = 100
57 | start_steps = 100
58 | steps_per_epoch = 100
59 |
60 | # use gpu if available
61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
62 | # set number of epoch
63 | if epochs == 'mbpo' or epochs < 0:
64 | epochs = mbpo_epoches[env_name]
65 | total_steps = steps_per_epoch * epochs + 1
66 |
67 | """set up logger"""
68 | logger = EpochLogger(**logger_kwargs)
69 | logger.save_config(locals())
70 |
71 | """set up environment and seeding"""
72 | env_fn = lambda: gym.make(env_name)
73 | env, test_env, bias_eval_env = env_fn(), env_fn(), env_fn()
74 | # seed torch and numpy
75 | torch.manual_seed(seed)
76 | np.random.seed(seed)
77 |
78 | # seed environment along with env action space so that everything is properly seeded for reproducibility
79 | def seed_all(epoch):
80 | seed_shift = epoch * 9999
81 | mod_value = 999999
82 | env_seed = (seed + seed_shift) % mod_value
83 | test_env_seed = (seed + 10000 + seed_shift) % mod_value
84 | bias_eval_env_seed = (seed + 20000 + seed_shift) % mod_value
85 | torch.manual_seed(env_seed)
86 | np.random.seed(env_seed)
87 | env.seed(env_seed)
88 | env.action_space.np_random.seed(env_seed)
89 | test_env.seed(test_env_seed)
90 | test_env.action_space.np_random.seed(test_env_seed)
91 | bias_eval_env.seed(bias_eval_env_seed)
92 | bias_eval_env.action_space.np_random.seed(bias_eval_env_seed)
93 | seed_all(epoch=0)
94 |
95 | """prepare to init agent"""
96 | # get obs and action dimensions
97 | obs_dim = env.observation_space.shape[0]
98 | act_dim = env.action_space.shape[0]
99 | # if environment has a smaller max episode length, then use the environment's max episode length
100 | max_ep_len = env._max_episode_steps if max_ep_len > env._max_episode_steps else max_ep_len
101 | # Action limit for clamping: critically, assumes all dimensions share the same bound!
102 | # we need .item() to convert it from numpy float to python float
103 | act_limit = env.action_space.high[0].item()
104 | # keep track of run time
105 | start_time = time.time()
106 | # flush logger (optional)
107 | sys.stdout.flush()
108 | #################################################################################################
109 |
110 | """init agent and start training"""
111 | agent = REDQSACAgent(env_name, obs_dim, act_dim, act_limit, device,
112 | hidden_sizes, replay_size, batch_size,
113 | lr, gamma, polyak,
114 | alpha, auto_alpha, target_entropy,
115 | start_steps, delay_update_steps,
116 | utd_ratio, num_Q, num_min, q_target_mode,
117 | policy_update_delay)
118 |
119 | o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
120 |
121 | for t in range(total_steps):
122 | # get action from agent
123 | a = agent.get_exploration_action(o, env)
124 | # Step the env, get next observation, reward and done signal
125 | o2, r, d, _ = env.step(a)
126 |
127 | # Very important: before we let agent store this transition,
128 | # Ignore the "done" signal if it comes from hitting the time
129 | # horizon (that is, when it's an artificial terminal signal
130 | # that isn't based on the agent's state)
131 | ep_len += 1
132 | d = False if ep_len == max_ep_len else d
133 |
134 | # give new data to agent
135 | agent.store_data(o, a, r, o2, d)
136 | # let agent update
137 | agent.train(logger)
138 | # set obs to next obs
139 | o = o2
140 | ep_ret += r
141 |
142 |
143 | if d or (ep_len == max_ep_len):
144 | # store episode return and length to logger
145 | logger.store(EpRet=ep_ret, EpLen=ep_len)
146 | # reset environment
147 | o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
148 |
149 | # End of epoch wrap-up
150 | if (t+1) % steps_per_epoch == 0:
151 | epoch = t // steps_per_epoch
152 |
153 | # Test the performance of the deterministic version of the agent.
154 | test_agent(agent, test_env, max_ep_len, logger) # add logging here
155 | if evaluate_bias:
156 | log_bias_evaluation(bias_eval_env, agent, logger, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff)
157 |
158 | # reseed should improve reproducibility (should make results the same whether bias evaluation is on or not)
159 | if reseed_each_epoch:
160 | seed_all(epoch)
161 |
162 | """logging"""
163 | # Log info about epoch
164 | logger.log_tabular('Epoch', epoch)
165 | logger.log_tabular('TotalEnvInteracts', t)
166 | logger.log_tabular('Time', time.time()-start_time)
167 | logger.log_tabular('EpRet', with_min_and_max=True)
168 | logger.log_tabular('EpLen', average_only=True)
169 | logger.log_tabular('TestEpRet', with_min_and_max=True)
170 | logger.log_tabular('TestEpLen', average_only=True)
171 | logger.log_tabular('Q1Vals', with_min_and_max=True)
172 | logger.log_tabular('LossQ1', average_only=True)
173 | logger.log_tabular('LogPi', with_min_and_max=True)
174 | logger.log_tabular('LossPi', average_only=True)
175 | logger.log_tabular('Alpha', with_min_and_max=True)
176 | logger.log_tabular('LossAlpha', average_only=True)
177 | logger.log_tabular('PreTanh', with_min_and_max=True)
178 |
179 | if evaluate_bias:
180 | logger.log_tabular("MCDisRet", with_min_and_max=True)
181 | logger.log_tabular("MCDisRetEnt", with_min_and_max=True)
182 | logger.log_tabular("QPred", with_min_and_max=True)
183 | logger.log_tabular("QBias", with_min_and_max=True)
184 | logger.log_tabular("QBiasAbs", with_min_and_max=True)
185 | logger.log_tabular("NormQBias", with_min_and_max=True)
186 | logger.log_tabular("QBiasSqr", with_min_and_max=True)
187 | logger.log_tabular("NormQBiasSqr", with_min_and_max=True)
188 | logger.dump_tabular()
189 |
190 | # flush logged information to disk
191 | sys.stdout.flush()
192 |
193 | if __name__ == '__main__':
194 | import argparse
195 | parser = argparse.ArgumentParser()
196 | parser.add_argument('--env', type=str, default='Hopper-v2')
197 | parser.add_argument('--seed', '-s', type=int, default=0)
198 | parser.add_argument('--epochs', type=int, default=-1) # -1 means use mbpo epochs
199 | parser.add_argument('--exp_name', type=str, default='redq_sac')
200 | parser.add_argument('--data_dir', type=str, default='../data/')
201 | parser.add_argument('--debug', action='store_true')
202 | args = parser.parse_args()
203 |
204 | # modify the code here if you want to use a different naming scheme
205 | exp_name_full = args.exp_name + '_%s' % args.env
206 |
207 | # specify experiment name, seed and data_dir.
208 | # for example, for seed 0, the progress.txt will be saved under data_dir/exp_name/exp_name_s0
209 | logger_kwargs = setup_logger_kwargs(exp_name_full, args.seed, args.data_dir)
210 |
211 | redq_sac(args.env, seed=args.seed, epochs=args.epochs,
212 | logger_kwargs=logger_kwargs, debug=args.debug)
213 |
--------------------------------------------------------------------------------
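
Once a run finishes, each seed's progress.txt can be inspected directly with pandas before reaching for the plotting utilities. A minimal sketch, assuming the default naming scheme above (exp_name redq_sac, env Hopper-v2, seed 0, data_dir ../data/); the path is illustrative only:

import pandas as pd

# Quick look at one seed of a finished run. The path is hypothetical and follows the
# data_dir/exp_name/exp_name_s<seed>/progress.txt layout described in the comments above.
df = pd.read_table('../data/redq_sac_Hopper-v2/redq_sac_Hopper-v2_s0/progress.txt')
# 'AverageTestEpRet' is the column the plotting helpers later treat as 'Performance'.
print(df[['Epoch', 'TotalEnvInteracts', 'AverageTestEpRet']].tail())
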
/mujoco_download.sh:
--------------------------------------------------------------------------------
1 | cd ~
2 | mkdir -p ~/.mujoco
3 | cd ~/.mujoco
4 | curl -O https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
5 | tar -xf mujoco210-linux-x86_64.tar.gz
--------------------------------------------------------------------------------
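
After the archive is extracted to ~/.mujoco/mujoco210, a quick way to confirm the setup is to build one of the MuJoCo v2 environments used in the experiments and take a random step. A minimal sketch, assuming gym and mujoco-py are installed (and that mujoco-py can find the MuJoCo install, which may also require LD_LIBRARY_PATH to include ~/.mujoco/mujoco210/bin):

import gym

# Sanity check: create a MuJoCo environment and take one random step,
# using the old gym API (reset returns obs, step returns a 4-tuple) as in the training script.
env = gym.make('Hopper-v2')
o = env.reset()
o2, r, d, info = env.step(env.action_space.sample())
print(o.shape, r, d)
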
/plot_utils/plot_REDQ.py:
--------------------------------------------------------------------------------
1 | """
2 | This program generates all the REDQ-related figures for ICLR 2021.
3 | NOTE: currently this only works with seaborn 0.8.1; the tsplot function it relies on is deprecated and removed in newer versions.
4 | The plotting functions are originally based on the plot function in OpenAI Spinning Up.
5 | """
6 | import os
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import pandas as pd
10 | import seaborn as sns
11 | from redq_plot_helper import *
12 |
13 | # the path leading to where the experiment folders are located
14 | base_path = '../data/REDQ_ICLR21'
15 |
16 | # map experiment name to folder's name
17 | exp2path_main = {
18 | 'MBPO-ant': 'MBPO-Ant',
19 | 'MBPO-hopper': 'MBPO-Hopper',
20 | 'MBPO-walker2d': 'MBPO-Walker2d',
21 | 'MBPO-humanoid': 'MBPO-Humanoid',
22 | 'REDQ-n10-ant':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_ant-v2',
23 | 'REDQ-n10-humanoid':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_humanoid-v2',
24 | 'REDQ-n10-hopper':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_hopper-v2',
25 | 'REDQ-n10-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd20_walker2d-v2',
26 | 'SAC-1-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_ant-v2',
27 | 'SAC-1-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_humanoid-v2',
28 | 'SAC-1-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_hopper-v2',
29 | 'SAC-1-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf1_pd20_walker2d-v2',
30 | 'SAC-20-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_ant-v2',
31 | 'SAC-20-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_walker2d-v2',
32 | 'SAC-20-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_hopper-v2',
33 | 'SAC-20-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_humanoid-v2',
34 | }
35 | exp2path_N_ablation = {
36 | 'REDQ-n15-ant':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_ant-v2',
37 | 'REDQ-n15-humanoid':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_humanoid-v2',
38 | 'REDQ-n15-hopper':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_hopper-v2',
39 | 'REDQ-n15-walker2d':'REDQ_embpo_qmin_piave_n15_m2_uf20_pd20_walker2d-v2',
40 | 'REDQ-n5-ant': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_ant-v2',
41 | 'REDQ-n5-humanoid': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_humanoid-v2',
42 | 'REDQ-n5-hopper': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_hopper-v2',
43 | 'REDQ-n5-walker2d': 'REDQ_embpo_qmin_piave_n5_m2_uf20_pd20_walker2d-v2',
44 | 'REDQ-n3-ant': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_ant-v2',
45 | 'REDQ-n3-humanoid': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_humanoid-v2',
46 | 'REDQ-n3-hopper': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_hopper-v2',
47 | 'REDQ-n3-walker2d': 'REDQ_embpo_qmin_piave_n3_m2_uf20_pd20_walker2d-v2',
48 | 'REDQ-n2-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_ant-v2',
49 | 'REDQ-n2-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_humanoid-v2',
50 | 'REDQ-n2-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_hopper-v2',
51 | 'REDQ-n2-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd20_walker2d-v2',
52 | }
53 | exp2path_redq_variants = {
54 | 'REDQ-weighted-hopper': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_hopper-v2',
55 | 'REDQ-minpair-hopper': 'REDQ_embpo_qminpair_piave_n10_m2_uf20_pd20_hopper-v2',
56 | 'REDQ-weighted-walker2d': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_walker2d-v2',
57 | 'REDQ-minpair-walker2d': 'REDQ_embpo_qminpair_piave_n10_m2_uf20_pd20_walker2d-v2',
58 | 'REDQ-weighted-ant': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_ant-v2',
59 | 'REDQ-minpair-ant': 'REDQ_embpo_qminpair_piave_n10_m2_uf20_pd20_ant-v2',
60 | 'REDQ-rem-hopper': 'REDQ_embpo_qrem_piave_n10_m2_uf20_pd20_hopper-v2',
61 | 'REDQ-ave-hopper': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_hopper-v2',
62 | 'REDQ-min-hopper': 'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_hopper-v2',
63 | 'REDQ-rem-walker2d': 'REDQ_embpo_qrem_piave_n10_m2_uf20_pd20_walker2d-v2',
64 | 'REDQ-ave-walker2d': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_walker2d-v2',
65 | 'REDQ-min-walker2d': 'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_walker2d-v2',
66 | 'REDQ-rem-ant': 'REDQ_embpo_qrem_piave_n10_m2_uf20_pd20_ant-v2',
67 | 'REDQ-ave-ant': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_ant-v2',
68 | 'REDQ-min-ant': 'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_ant-v2',
69 | 'REDQ-ave-humanoid': 'REDQ_embpo_qave_piave_n10_m2_uf20_pd20_humanoid-v2',
70 | 'REDQ-weighted-humanoid': 'REDQ_embpo_qweighted_piave_n10_m2_uf20_pd20_humanoid-v2',
71 | 'REDQ-m1-hopper':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_hopper-v2',
72 | 'REDQ-m1-walker2d':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_walker2d-v2',
73 | 'REDQ-m1-ant':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_ant-v2',
74 | 'REDQ-m1-humanoid':'REDQ_embpo_qmin_piave_n10_m1_uf20_pd20_humanoid-v2',
75 | }
76 |
77 | exp2path_M_ablation = {
78 | 'REDQ-n10-m1-5-ant': 'REDQ_embpo_qmin_piave_n10_m1-5_uf20_pd20_ant-v2',
79 | 'REDQ-n10-m2-5-ant': 'REDQ_embpo_qmin_piave_n10_m2-5_uf20_pd20_ant-v2',
80 | 'REDQ-n10-m3-ant': 'REDQ_embpo_qmin_piave_n10_m3_uf20_pd20_ant-v2',
81 | 'REDQ-n10-m5-ant': 'REDQ_embpo_qmin_piave_n10_m5_uf20_pd20_ant-v2',
82 | 'REDQ-n10-m1-5-walker2d': 'REDQ_embpo_qmin_piave_n10_m1-5_uf20_pd20_walker2d-v2',
83 | 'REDQ-n10-m2-5-walker2d': 'REDQ_embpo_qmin_piave_n10_m2-5_uf20_pd20_walker2d-v2',
84 | 'REDQ-n10-m3-walker2d': 'REDQ_embpo_qmin_piave_n10_m3_uf20_pd20_walker2d-v2',
85 | 'REDQ-n10-m5-walker2d': 'REDQ_embpo_qmin_piave_n10_m5_uf20_pd20_walker2d-v2',
86 | 'REDQ-n10-m1-5-hopper': 'REDQ_embpo_qmin_piave_n10_m1-5_uf20_pd20_hopper-v2',
87 | 'REDQ-n10-m2-5-hopper': 'REDQ_embpo_qmin_piave_n10_m2-5_uf20_pd20_hopper-v2',
88 | 'REDQ-n10-m3-hopper': 'REDQ_embpo_qmin_piave_n10_m3_uf20_pd20_hopper-v2',
89 | 'REDQ-n10-m5-hopper': 'REDQ_embpo_qmin_piave_n10_m5_uf20_pd20_hopper-v2',
90 | }
91 |
92 | exp2path_ablations_app = {
93 | 'REDQ-ave-n15-ant': 'REDQ_embpo_qave_piave_n15_m2_uf20_pd20_ant-v2',
94 | 'REDQ-ave-n5-ant': 'REDQ_embpo_qave_piave_n5_m2_uf20_pd20_ant-v2',
95 | 'REDQ-ave-n3-ant': 'REDQ_embpo_qave_piave_n3_m2_uf20_pd20_ant-v2',
96 | 'REDQ-ave-n2-ant': 'REDQ_embpo_qave_piave_n2_m2_uf20_pd20_ant-v2',
97 | 'REDQ-ave-n15-walker2d': 'REDQ_embpo_qave_piave_n15_m2_uf20_pd20_walker2d-v2',
98 | 'REDQ-ave-n5-walker2d': 'REDQ_embpo_qave_piave_n5_m2_uf20_pd20_walker2d-v2',
99 | 'REDQ-ave-n3-walker2d': 'REDQ_embpo_qave_piave_n3_m2_uf20_pd20_walker2d-v2',
100 | 'REDQ-ave-n2-walker2d': 'REDQ_embpo_qave_piave_n2_m2_uf20_pd20_walker2d-v2',
101 | }
102 |
103 | exp2path_utd = {
104 | 'REDQ-n10-utd1-ant':'REDQ_embpo_qmin_piave_n10_m2_uf1_pd20_ant-v2',
105 | 'REDQ-n10-utd5-ant':'REDQ_embpo_qmin_piave_n10_m2_uf5_pd20_ant-v2',
106 | 'REDQ-n10-utd10-ant':'REDQ_embpo_qmin_piave_n10_m2_uf10_pd20_ant-v2',
107 | 'REDQ-n10-utd1-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf1_pd20_walker2d-v2',
108 | 'REDQ-n10-utd5-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf5_pd20_walker2d-v2',
109 | 'REDQ-n10-utd10-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf10_pd20_walker2d-v2',
110 | 'SAC-5-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd20_ant-v2', # SAC with policy delay
111 | 'SAC-5-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd20_walker2d-v2',
112 | 'SAC-10-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd20_ant-v2',
113 | 'SAC-10-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd20_walker2d-v2',
114 | 'REDQ-min-n3-utd1-ant':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_ant-v2',
115 | 'REDQ-min-n3-utd5-ant':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_ant-v2',
116 | 'REDQ-min-n3-utd10-ant':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_ant-v2',
117 | 'REDQ-min-n3-utd20-ant':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_ant-v2',
118 | 'REDQ-min-n3-utd1-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_walker2d-v2',
119 | 'REDQ-min-n3-utd5-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_walker2d-v2',
120 | 'REDQ-min-n3-utd10-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_walker2d-v2',
121 | 'REDQ-min-n3-utd20-walker2d':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_walker2d-v2',
122 | 'REDQ-min-n3-utd1-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_hopper-v2',
123 | 'REDQ-min-n3-utd5-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_hopper-v2',
124 | 'REDQ-min-n3-utd10-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_hopper-v2',
125 | 'REDQ-min-n3-utd20-hopper':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_hopper-v2',
126 | 'REDQ-min-n3-utd1-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf1_pd20_humanoid-v2',
127 | 'REDQ-min-n3-utd5-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf5_pd20_humanoid-v2',
128 | 'REDQ-min-n3-utd10-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf10_pd20_humanoid-v2',
129 | 'REDQ-min-n3-utd20-humanoid':'REDQ_embpo_qmin_piave_n3_msame_uf20_pd20_humanoid-v2',
130 | 'REDQ-ave-utd1-ant':'REDQ_embpo_qave_piave_n10_m2_uf1_pd20_ant-v2',
131 | 'REDQ-ave-utd5-ant':'REDQ_embpo_qave_piave_n10_m2_uf5_pd20_ant-v2',
132 | 'REDQ-ave-utd1-walker2d':'REDQ_embpo_qave_piave_n10_m2_uf1_pd20_walker2d-v2',
133 | 'REDQ-ave-utd5-walker2d':'REDQ_embpo_qave_piave_n10_m2_uf5_pd20_walker2d-v2',
134 | }
135 |
136 | exp2path_pd = {
137 | 'SAC-pd1-5-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd1_ant-v2', # no policy delay SAC
138 | 'SAC-pd1-5-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf5_pd1_walker2d-v2',
139 | 'SAC-pd1-10-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd1_ant-v2',
140 | 'SAC-pd1-10-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf10_pd1_walker2d-v2',
141 | 'SAC-pd1-20-ant': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_ant-v2',
142 | 'SAC-pd1-20-walker2d': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_walker2d-v2',
143 | 'SAC-pd1-20-hopper': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_hopper-v2',
144 | 'SAC-pd1-20-humanoid': 'REDQ_embpo_qmin_piave_n2_m2_uf20_pd1_humanoid-v2',
145 | 'REDQ-n10-pd1-ant':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_ant-v2',
146 | 'REDQ-n10-pd1-walker2d':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_walker2d-v2',
147 | 'REDQ-n10-pd1-hopper':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_hopper-v2',
148 | 'REDQ-n10-pd1-humanoid':'REDQ_embpo_qmin_piave_n10_m2_uf20_pd1_humanoid-v2',
149 | }
150 |
151 | exp2path_ofe = {
152 | 'REDQ-ofe-humanoid': 'REDQ_e20_uf20_pd_ofe20000_p100000_d5_lr0-0003_wd5e-05_w1_humanoid-v2',
153 | 'REDQ-ofe-ant': 'REDQ_e20_uf20_pd_ofe20000_p100000_d5_lr0-0003_wd5e-05_w1_ant-v2',
154 | }
155 |
156 | def generate_exp2path(name_prefix, list_values, list_name_suffix, pre_string, env_names):
157 | exp2path_dict = {}
158 | for e in env_names:
159 | for i in range(len(list_values)):
160 | new_name = '%s-%s-%s' % (name_prefix, list_name_suffix[i], e)
161 | new_string = pre_string % (list_values[i], e)
162 | exp2path_dict[new_name] = new_string
163 | return exp2path_dict
164 |
165 | figsize=(10, 7)
166 | exp2dataset = {}
167 | ############# finalizing plots ####################################
168 | default_smooth = 10
169 | all_4_envs = ['ant', 'humanoid', 'hopper', 'walker2d']
170 | just_2_envs = ['ant', 'walker2d']
171 | y_types_all = ['Performance', 'AverageQBias', 'StdQBias', 'AllNormalizedAverageQBias', 'AllNormalizedStdQBias',
172 | 'LossQ1', 'NormLossQ1', 'AverageQBiasSqr', 'MaxPreTanh', 'AveragePreTanh', 'AverageQ1Vals',
173 | 'MaxAlpha', 'AverageAlpha', 'AverageQPred']
174 | y_types_appendix_6 = ['AverageQBias', 'StdQBias', 'AverageQBiasSqr', 'AllNormalizedAverageQBiasSqr',
175 | 'LossQ1', 'AverageQPred']
176 | y_types_main_paper = ['Performance', 'AllNormalizedAverageQBias', 'AllNormalizedStdQBias']
177 | y_types_appendix_9 = y_types_main_paper + y_types_appendix_6
178 |
179 | # main paper
180 | plot_result_sec = True
181 | plot_analysis_sec = True
182 | plot_ablations = True
183 | plot_variants = True
184 | plot_ablations_revision = True
185 | # appendix ones
186 | plot_analysis_app = True
187 | plot_ablations_app = True
188 | plot_variants_app = True
189 | plot_utd_app = True
190 | plot_pd_app = True
191 | plot_ofe = True
192 |
193 | if plot_ablations_app:
194 | exp2dataset.update(get_exp2dataset(exp2path_ablations_app, base_path))
195 | if plot_utd_app:
196 | exp2dataset.update(get_exp2dataset(exp2path_utd, base_path))
197 | if plot_pd_app:
198 | exp2dataset.update(get_exp2dataset(exp2path_pd, base_path))
199 | if plot_ofe:
200 | exp2dataset.update(get_exp2dataset(exp2path_ofe, base_path))
201 |
202 | exp2dataset.update(get_exp2dataset(exp2path_main, base_path))
203 | exp2dataset.update(get_exp2dataset(exp2path_N_ablation, base_path))
204 | exp2dataset.update(get_exp2dataset(exp2path_M_ablation, base_path))
205 | exp2dataset.update(get_exp2dataset(exp2path_redq_variants, base_path))
206 |
207 | if plot_result_sec:
208 | # 1. redq-mbpo figure in results section, score only
209 | exp_base_to_plot = ['REDQ-n10', 'SAC-1', 'MBPO']
210 | save_path, prefix = 'results', 'redq-mbpo'
211 | label_list = ['REDQ', 'SAC', 'MBPO']
212 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
213 | all_4_envs, ['Performance'], default_smooth, figsize, label_list=label_list, legend_y_types=['Performance'],
214 | legend_es=['humanoid'])
215 |
216 | if plot_analysis_sec:
217 | # 2. analysis sec, REDQ, SAC-20, ave comparison, argue naive methods don't work
218 | exp_base_to_plot = ['REDQ-n10', 'SAC-20', 'REDQ-ave']
219 | save_path, prefix = 'analysis', 'redq-sac'
220 | label_list = ['REDQ', 'SAC-20', 'AVG']
221 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
222 | all_4_envs, y_types_all, default_smooth, figsize, label_list=label_list,
223 | legend_y_types=['AllNormalizedStdQBias'], legend_es='all', legend_loc='upper right')
224 |
225 | if plot_ablations:
226 |     # 3. ablation on N
227 | exp_base_to_plot = ['REDQ-n15', 'REDQ-n10', 'REDQ-n5', 'REDQ-n3', 'REDQ-n2']
228 | save_path, prefix = 'ablations', 'redq-N'
229 | label_list = ['REDQ-N15', 'REDQ-N10', 'REDQ-N5', 'REDQ-N3', 'REDQ-N2']
230 | overriding_ylimit_dict = {
231 | ('ant', 'AllNormalizedAverageQBias'): (-1, 1),
232 | ('ant', 'AllNormalizedStdQBias'): (0, 1),
233 | ('walker2d', 'AllNormalizedAverageQBias'): (-0.5, 0.75),
234 | ('walker2d', 'AllNormalizedStdQBias'):(0, 0.75),
235 | }
236 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
237 | just_2_envs, y_types_main_paper, default_smooth, figsize,
238 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es=['ant'], legend_loc='upper right',
239 | overriding_ylimit_dict=overriding_ylimit_dict)
240 |
241 |     # 4. ablation on M
242 | exp_base_to_plot = ['REDQ-n10-m1-5', 'REDQ-n10', 'REDQ-n10-m2-5', 'REDQ-n10-m3', 'REDQ-n10-m5']
243 | save_path, prefix = 'ablations', 'redq-M'
244 | label_list = ['REDQ-M1.5', 'REDQ-M2', 'REDQ-M2.5', 'REDQ-M3', 'REDQ-M5']
245 | overriding_ylimit_dict = {
246 | ('ant', 'AllNormalizedAverageQBias'): (-4, 4),
247 | ('walker2d', 'AllNormalizedAverageQBias'): (-1, 2.5),
248 | }
249 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
250 | just_2_envs, y_types_main_paper, default_smooth, figsize,
251 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es=['ant'], legend_loc='upper right',
252 | overriding_ylimit_dict=overriding_ylimit_dict)
253 |
254 | if plot_ablations_revision:
255 |     # 5. revision experiments: compare REDQ against the weighted Q target variant
256 | exp_base_to_plot = ['REDQ-n10', 'REDQ-weighted']
257 | save_path, prefix = 'revision', 'redq-weighted'
258 | label_list = ['REDQ', 'Weighted']
259 | overriding_ylimit_dict = {
260 | ('ant', 'AllNormalizedAverageQBias'): (-1, 1),
261 | # ('hopper', 'AllNormalizedAverageQBias'): (-0.5, 1.5),
262 | # ('humanoid', 'AllNormalizedAverageQBias'): (-0.5, 0),
263 | # ('walker2d', 'AllNormalizedAverageQBias'): (-0.5, 0.5),
264 | }
265 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
266 | all_4_envs, y_types_main_paper, default_smooth, figsize,
267 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es='all', legend_loc='upper right',
268 | overriding_ylimit_dict=overriding_ylimit_dict)
269 |
270 | if plot_variants:
271 |     # 6. comparison of the different Q target computation variants
272 | exp_base_to_plot = ['REDQ-n10', 'REDQ-rem', 'REDQ-min', 'REDQ-weighted', 'REDQ-minpair']
273 | save_path, prefix = 'variants', 'redq-var'
274 | label_list = ['REDQ', 'REM', 'Maxmin', 'Weighted', 'MinPair']
275 | overriding_ylimit_dict = {
276 | ('ant', 'AllNormalizedAverageQBias'): (-2, 4),
277 | }
278 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
279 | just_2_envs, y_types_main_paper, default_smooth, figsize,
280 | label_list=label_list, legend_y_types=['AllNormalizedStdQBias'], legend_es=['ant'], legend_loc='upper right',
281 | overriding_ylimit_dict=overriding_ylimit_dict)
282 |
283 | def plot_grid(save_path, save_name_prefix, exp2dataset, exp_base_to_plot, envs, y_types, smooth, figsize, label_list=None,
284 | legend_y_types='all', legend_es='all', overriding_ylimit_dict=None, legend_loc='best', longxaxis=False,
285 | linestyle_list=None, color_list=None, override_xlimit=None):
286 | for y_i, y_type in enumerate(y_types):
287 | y_save_name = y2savename_dict[y_type]
288 | for e in envs:
289 | no_legend = decide_no_legend(legend_y_types, legend_es, y_type, e)
290 | exp_to_plot = []
291 | for exp_base in exp_base_to_plot:
292 | exp_to_plot.append(exp_base + '-' + e)
293 | if override_xlimit is None:
294 | xlimit = 125000 if e == 'hopper' else int(3e5)
295 | else:
296 | xlimit = override_xlimit
297 | if xlimit is None and longxaxis:
298 | xlimit = int(2e6)
299 | ylimit = get_ylimit_from_env_ytype(e, y_type, overriding_ylimit_dict)
300 | assert len(label_list) == len(exp_base_to_plot)
301 | if save_path:
302 | save_name = save_name_prefix + '-%s-%s' %(e, y_save_name)
303 | else:
304 | save_name = None
305 | if color_list is None:
306 | color_list = get_default_colors(exp_base_to_plot)
307 | plot_from_data(None, exp2dataset, exp_to_plot, figsize, value_list=[y_type,],
308 | ylabel='auto', save_path=save_path, label_list=label_list,
309 | save_name=save_name, smooth=smooth, xlimit=xlimit, y_limit = ylimit,
310 | color_list=color_list, no_legend=no_legend, legend_loc=legend_loc,
311 | linestyle_list=linestyle_list)
312 |
313 | if plot_ofe:
314 | exp_base_to_plot = ['REDQ-n10', 'SAC-1', 'MBPO', 'REDQ-ofe']
315 | save_path, prefix = 'extra', 'redq-ofe'
316 | label_list = ['REDQ', 'SAC', 'MBPO', 'REDQ-OFE']
317 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
318 | ['humanoid', 'ant'], ['Performance'], default_smooth, figsize,
319 | label_list=label_list, legend_y_types='all', legend_es=['humanoid'], legend_loc='lower right')
320 |
321 | """#################################### REST IS APPENDIX #################################### """
322 | if plot_analysis_app:
323 | # appendix: extra figures for 2. analysis
324 | exp_base_to_plot = ['REDQ-n10', 'SAC-20', 'REDQ-ave']
325 | save_path, prefix = 'analysis_app', 'redq-sac'
326 | label_list = ['REDQ', 'SAC-20', 'AVG']
327 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
328 | just_2_envs, y_types_all, default_smooth, figsize, label_list=label_list,
329 | legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left')
330 |
331 | if plot_ablations_app:
332 | # appendix: extra figures for N, M comparisons
333 |     # appendix: ablation on N
334 | exp_base_to_plot = ['REDQ-n15', 'REDQ-n10', 'REDQ-n5', 'REDQ-n3', 'REDQ-n2']
335 | save_path, prefix = 'ablations_app', 'redq-N'
336 | label_list = ['REDQ-N15', 'REDQ-N10', 'REDQ-N5', 'REDQ-N3', 'REDQ-N2']
337 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
338 | just_2_envs, y_types_appendix_6, default_smooth, figsize,
339 | label_list=label_list, legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left',)
340 |
341 |     # appendix: ablation on M
342 | exp_base_to_plot = ['REDQ-n10-m1-5', 'REDQ-n10', 'REDQ-n10-m2-5', 'REDQ-n10-m3', 'REDQ-n10-m5']
343 | save_path, prefix = 'ablations_app', 'redq-M'
344 | label_list = ['REDQ-M1.5', 'REDQ-M2', 'REDQ-M2.5', 'REDQ-M3', 'REDQ-M5']
345 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
346 | just_2_envs, y_types_appendix_6, default_smooth, figsize,
347 | label_list=label_list, legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left',)
348 |
349 | if plot_variants_app:
350 | # appendix: extra figures for variants
351 | exp_base_to_plot = ['REDQ-n10', 'REDQ-rem', 'REDQ-min', 'REDQ-weighted', 'REDQ-minpair']
352 | save_path, prefix = 'variants_app', 'redq-var'
353 | label_list = ['REDQ', 'REM', 'MIN', 'Weighted', 'MinPair']
354 | overriding_ylimit_dict = {
355 | ('ant', 'AllNormalizedAverageQBias'): (-2, 4),
356 | }
357 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
358 | just_2_envs, y_types_appendix_6, default_smooth, figsize,
359 | label_list=label_list, legend_y_types=['AverageQPred'], legend_es='all', legend_loc='upper left',
360 | overriding_ylimit_dict=overriding_ylimit_dict)
361 |
362 | if plot_utd_app:
363 | # REDQ different utd
364 | exp_base_to_plot = [ 'REDQ-n10', 'REDQ-n10-utd10', 'REDQ-n10-utd5', 'REDQ-n10-utd1', ]
365 | save_path, prefix = 'utd_app', 'utd-redq'
366 | label_list = [ 'REDQ-UTD20','REDQ-UTD10','REDQ-UTD5' , 'REDQ-UTD1',]
367 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
368 | just_2_envs, y_types_all, default_smooth, figsize,
369 | label_list=label_list, legend_y_types=['Performance'], legend_es='all', legend_loc='upper left',)
370 | # SAC different utd
371 | exp_base_to_plot = ['SAC-20', 'SAC-10', 'SAC-5', 'SAC-1']
372 | save_path, prefix = 'utd_app', 'utd-sac'
373 | label_list = ['SAC-UTD20', 'SAC-UTD10', 'SAC-UTD5', 'SAC-UTD1',]
374 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
375 | just_2_envs, y_types_all, default_smooth, figsize,
376 | label_list=label_list, legend_y_types=['Performance'], legend_es='all', legend_loc='upper left',)
377 |
378 | if plot_pd_app:
379 | exp_base_to_plot = ['REDQ-n10', 'REDQ-n10-pd1', 'SAC-20', 'SAC-pd1-20']
380 | save_path, prefix = 'pd_app', 'pd-redq'
381 | label_list = ['REDQ','REDQ-NPD','SAC-20', 'SAC-20-NPD',]
382 | linestyle_list = ['solid', 'dashed', 'solid', 'dashed']
383 | plot_grid(save_path, prefix, exp2dataset, exp_base_to_plot,
384 | all_4_envs, y_types_all, default_smooth, figsize,
385 | label_list=label_list, legend_y_types=['Performance'], legend_es='all', legend_loc='upper left',
386 | linestyle_list=linestyle_list)
387 |
--------------------------------------------------------------------------------
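
The experiment folder names above all follow one pattern, which is why plot_REDQ.py defines generate_exp2path even though most dictionaries are spelled out by hand. A small sketch of how that helper could rebuild part of exp2path_utd (to be run inside plot_REDQ.py, where generate_exp2path is in scope; the argument values are chosen here purely to mirror entries already listed above):

# Hypothetical use of generate_exp2path: rebuild two of the REDQ UTD entries
# of exp2path_utd instead of writing them out by hand.
exp2path_utd_partial = generate_exp2path(
    name_prefix='REDQ-n10',
    list_values=[1, 5],
    list_name_suffix=['utd1', 'utd5'],
    pre_string='REDQ_embpo_qmin_piave_n10_m2_uf%s_pd20_%s-v2',
    env_names=['ant', 'walker2d'])
# -> {'REDQ-n10-utd1-ant': 'REDQ_embpo_qmin_piave_n10_m2_uf1_pd20_ant-v2',
#     'REDQ-n10-utd5-ant': 'REDQ_embpo_qmin_piave_n10_m2_uf5_pd20_ant-v2', ...}
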
/plot_utils/redq_plot_helper.py:
--------------------------------------------------------------------------------
1 | """
2 | NOTE: currently this only works with seaborn 0.8.1; the tsplot function it relies on is deprecated and removed in newer versions.
3 | The plotting functions are originally based on the plot function in OpenAI Spinning Up.
4 | """
5 | import os
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pandas as pd
9 | import seaborn as sns
10 | import json
11 | from packaging import version
12 |
13 | DIV_LINE_WIDTH = 50
14 |
15 | # Global vars for tracking and labeling data at load time.
16 | exp_idx = 0
17 | units = dict()
18 |
19 | def get_datasets(logdir, condition=None):
20 | """
21 | Recursively look through logdir for output files produced by
22 | spinup.logx.Logger.
23 |
24 | Assumes that any file "progress.txt" is a valid hit.
25 |     The "condition" argument can be a string; when plotting, it is used as the label in the legend.
26 | """
27 | global exp_idx
28 | global units
29 | datasets = []
30 | for root, _, files in os.walk(logdir):
31 | if 'progress.txt' in files:
32 | try:
33 | exp_data = pd.read_table(os.path.join(root, 'progress.txt'))
34 | exp_name = None
35 | try:
36 | config_path = open(os.path.join(root, 'config.json'))
37 | config = json.load(config_path)
38 | if 'exp_name' in config:
39 | exp_name = config['exp_name']
40 | except:
41 | print('No file named config.json')
42 | condition1 = condition or exp_name or 'exp'
43 | condition2 = condition1 + '-' + str(exp_idx)
44 | exp_idx += 1
45 | if condition1 not in units:
46 | units[condition1] = 0
47 | unit = units[condition1]
48 | units[condition1] += 1
49 |
50 | print(os.path.join(root, 'progress.txt'))
51 |
52 | # exp_data = pd.read_table(os.path.join(root, 'progress.txt'))
53 | performance = 'AverageTestEpRet' if 'AverageTestEpRet' in exp_data else 'AverageEpRet'
54 | exp_data.insert(len(exp_data.columns), 'Unit', unit)
55 | exp_data.insert(len(exp_data.columns), 'Condition1', condition1)
56 | exp_data.insert(len(exp_data.columns), 'Condition2', condition2)
57 | if performance in exp_data:
58 | exp_data.insert(len(exp_data.columns), 'Performance', exp_data[performance])
59 | datasets.append(exp_data)
60 | except Exception as e:
61 | print(e)
62 |
63 | return datasets
64 |
65 |
66 | y2savename_dict = {
67 | 'Performance':'score',
68 | 'AverageNormQBias':'bias-ave-n',
69 | 'StdNormQBias':'bias-std-n',
70 | 'LossQ1':'qloss',
71 | 'NormLossQ1':'qloss-n',
72 | 'MaxPreTanh':'pretanh-max',
73 | 'AveragePreTanh':'pretanh-ave',
74 | 'AverageQBias':'bias-ave',
75 | 'StdQBias':'bias-std',
76 | 'AverageQPred':'qpred',
77 | 'AverageQBiasSqr':'biassqr-ave',
78 | 'AverageNormQBiasSqr':'biassqr-ave-n',
79 | 'AverageAlpha':'alpha-ave',
80 | 'MaxAlpha':'alpha',
81 | 'AverageQ1Vals':'q1',
82 | 'AverageLogPi':'average-logpi',
83 | 'MaxLogPi': 'max-logpi',
84 | 'AllNormalizedAverageQBias':'bias-ave-alln',
85 | 'AllNormalizedStdQBias':'bias-std-alln',
86 | 'AllNormalizedAverageQBiasSqr':'biassqr-ave-alln',
87 | 'Time':'time'
88 | }
89 |
90 | y2ylabel_dict = {
91 | 'Performance':'Average return',
92 | 'AverageNormQBias':'Average normalized bias',
93 | 'StdNormQBias':'Std of normalized bias',
94 | 'LossQ1':'Q loss',
95 | 'NormLossQ1':'Normalized Q loss',
96 | 'MaxPreTanh':'Max pretanh',
97 | 'AveragePreTanh':'Average pretanh',
98 | 'AverageQBias':'Average bias',
99 | 'StdQBias':'Std of bias',
100 | 'AverageQPred':'Average Q value',
101 | 'AverageQBiasSqr':'Average MSE',
102 | 'AverageNormQBiasSqr':'Average normalized MSE',
103 | 'AverageAlpha':'Average alpha',
104 | 'MaxAlpha':'Max alpha',
105 | 'AverageQ1Vals':'Q value',
106 | 'AverageLogPi': 'Average logPi',
107 | 'MaxLogPi': 'Max logPi',
108 | 'AllNormalizedAverageQBias': 'Average normalized bias',
109 | 'AllNormalizedStdQBias': 'Std of normalized bias',
110 | 'AllNormalizedAverageQBiasSqr': 'Average normalized MSE',
111 | 'Time':'Time'
112 | }
113 |
114 | # we use a fixed mapping from exp base name to color
115 | expbase2color = {
116 | 'SAC-20': 'grey',
117 | 'SAC-10': 'slateblue',
118 | 'SAC-5': 'blue',
119 | 'SAC-1': 'skyblue', # blue-black for SAC, MBPO
120 | 'SAC-hs1':'black',
121 | 'SAC-hs2':'brown',
122 | 'SAC-hs3':'purple',
123 | 'SAC-pd1-5': 'black',
124 | 'SAC-pd1-10': 'brown',
125 | 'SAC-pd1-20': 'grey',
126 | 'MBPO':'tab:blue',
127 | 'REDQ-n15':'tab:orange', # dark red to light purple
128 | 'REDQ-n10':'tab:red',
129 | 'REDQ-n5': 'tab:cyan',
130 | 'REDQ-n3': 'tab:grey',
131 | 'REDQ-n2': 'black',
132 | 'REDQ-n10-m1-5':'tab:orange', # red-brown for M variants?
133 | 'REDQ-n10-m2-5':'tab:cyan',
134 | 'REDQ-n10-m3':'tab:grey',
135 | 'REDQ-n10-m5': 'black',
136 | 'REDQ-n10-hs1': 'tab:orange',
137 | 'REDQ-n10-hs2': 'violet',
138 | 'REDQ-n10-hs3': 'lightblue',
139 | 'REDQ-weighted': 'indigo', # then for redq q target variants, some random stuff
140 | 'REDQ-minpair': 'royalblue',
141 | 'REDQ-ave': 'peru',
142 | 'REDQ-rem': 'yellow',
143 | 'REDQ-min': 'slategrey',
144 | 'REDQ-n10-utd10': 'tab:cyan',
145 | 'REDQ-n10-utd5': 'tab:grey',
146 | 'REDQ-n10-utd1': 'black',
147 | 'REDQ-min-n3-utd1':'blue',
148 | 'REDQ-min-n3-utd5':'lightblue',
149 | 'REDQ-min-n3-utd10':'grey',
150 | 'REDQ-min-n3-utd20':'black',
151 | 'REDQ-ave-utd1':'blue',
152 | 'REDQ-ave-utd5':'grey',
153 | 'REDQ-ofe':'purple',
154 | 'REDQ-ofe-long':'purple',
155 | 'REDQ-dense': 'deeppink',
156 | 'SAC-dense': 'sandybrown',
157 | 'REDQ-ave-n15':'black',
158 | 'REDQ-ave-n5':'tab:orange',
159 | 'REDQ-ave-n3':'violet',
160 | 'REDQ-ave-n2':'lightblue',
161 | 'SAC-long':'skyblue',
162 | 'REDQ-n10-pd1':'tab:red',
163 | 'REDQ-fine':'tab:pink',
164 | 'REDQ-ss10k':'tab:pink',
165 | 'REDQ-ss15k': 'tab:orange',
166 | 'REDQ-ss20k': 'yellow',
167 | 'REDQ-weighted-ss10k':'lightblue',
168 | 'REDQ-weighted-ss15k': 'royalblue',
169 | 'REDQ-weighted-ss20k': 'black',
170 | 'REDQ-m1':'grey',
171 | 'REDQ-more':'tab:red',
172 | 'REDQ-weighted-more':'indigo',
173 | 'REDQ-anneal25k': 'tab:pink',
174 | 'REDQ-anneal50k': 'tab:orange',
175 | 'REDQ-anneal100k': 'yellow',
176 | 'REDQ-weighted-anneal25k': 'lightblue',
177 | 'REDQ-weighted-anneal50k': 'royalblue',
178 | 'REDQ-weighted-anneal100k': 'black',
179 | 'REDQ-init0-4': 'tab:pink',
180 | 'REDQ-init0-3': 'tab:orange',
181 | 'REDQ-init0-2': 'yellow',
182 | 'REDQ-weighted-init0-4': 'lightblue',
183 | 'REDQ-weighted-init0-3': 'royalblue',
184 | 'REDQ-weighted-init0-2': 'black',
185 | 'REDQ-weighted-a0-1':'lightblue',
186 | 'REDQ-weighted-a0-15': 'royalblue',
187 | 'REDQ-weighted-a0-2':'black',
188 | 'REDQ-weighted-a0-25':'pink',
189 | 'REDQ-weighted-a0-3':'yellow',
190 | 'REDQ-a0-1': 'lightblue',
191 | 'REDQ-a0-15': 'royalblue',
192 | 'REDQ-a0-2': 'black',
193 | 'REDQ-a0-25': 'pink',
194 | 'REDQ-a0-3': 'yellow',
195 | 'REDQ-weighted-wns0-1': 'lightblue',
196 | 'REDQ-weighted-wns0-3': 'royalblue',
197 | 'REDQ-weighted-wns0-5': 'black',
198 | 'REDQ-weighted-wns0-8': 'pink',
199 | 'REDQ-weighted-wns1-2': 'yellow',
200 | }
201 |
202 | def get_ylimit_from_env_ytype(e, ytype, overriding_dict=None):
203 |     # map env and ytype to a ylimit
204 |     # so that the plot looks better:
205 |     # each plot gets a reasonable y range so that outliers won't dominate
206 |     # an overriding dictionary can be passed in to override the default values here
207 | default_dict = {
208 | ('ant', 'AverageQBias'): (-200, 300),
209 | ('ant', 'AverageNormQBias'): (-1, 4),
210 | ('ant', 'AllNormalizedAverageQBias'): (-1, 4),
211 | ('ant', 'StdQBias'): (0, 150),
212 | ('ant', 'StdNormQBias'): (0, 5),
213 | ('ant', 'AllNormalizedStdQBias'): (0, 2),
214 | ('ant', 'LossQ1'): (0, 60),
215 | ('ant', 'NormLossQ1'): (0, 0.4),
216 | ('ant', 'AverageQBiasSqr'): (0, 50000),
217 | ('ant', 'AllNormalizedAverageQBiasSqr'): (0, 400),
218 | ('ant', 'AverageNormQBiasSqr'): (0, 1200),
219 | ('ant', 'AverageQPred'): (-200, 1000),
220 | ('walker2d', 'AverageQBias'): (-150, 250),
221 | ('walker2d', 'AverageNormQBias'): (-0.5, 4),
222 | ('walker2d', 'AllNormalizedAverageQBias'): (-0.5, 2.5),
223 | ('walker2d', 'StdQBias'): (0, 150),
224 | ('walker2d', 'StdNormQBias'): (0, 4),
225 | ('walker2d', 'AllNormalizedStdQBias'):(0, 1.5),
226 | ('walker2d', 'LossQ1'): (0, 60),
227 | ('walker2d', 'NormLossQ1'): (0, 0.5),
228 | ('walker2d', 'AverageQBiasSqr'): (0, 50000),
229 | ('walker2d', 'AverageNormQBiasSqr'): (0, 1200),
230 | ('walker2d', 'AllNormalizedAverageQBiasSqr'): (0, 1200),
231 | ('walker2d', 'AverageQPred'): (-100, 600),
232 | ('humanoid', 'AllNormalizedAverageQBias'): (-1, 2.5),
233 | ('humanoid', 'AllNormalizedStdQBias'): (0, 1.5),
234 | }
235 | key = (e, ytype)
236 | if overriding_dict and key in overriding_dict:
237 | return overriding_dict[key]
238 | else:
239 | if key in default_dict:
240 | return default_dict[key]
241 | else:
242 | return None
243 |
244 | def select_data_list_for_plot(exp2dataset, exp_to_plot):
245 | data_list = []
246 | for exp in exp_to_plot:
247 |         # use append here; exp2dataset[exp] is a list of per-seed dataframes (see get_exp2dataset)
248 |         data_list.append(exp2dataset[exp])
249 |     return data_list # a list of lists: one list of per-seed dataframes per experiment
250 |
251 | def plot_from_data(data_list, exp2dataset=None, exp_to_plot=None, figsize=None, xaxis='TotalEnvInteracts',
252 | value_list=['Performance',], color_list=None, linestyle_list=None, label_list=None,
253 | count=False,
254 | font_scale=1.5, smooth=1, estimator='mean', no_legend=False,
255 | legend_loc='best', title=None, save_name=None, save_path = None,
256 | xlimit=-1, y_limit=None, label_font_size=24, xlabel=None, ylabel=None,
257 | y_log_scale=False):
258 | """
259 |     Either give a data_list directly,
260 |     or give an exp2dataset dictionary and then specify exp_to_plot.
261 |     Each experiment setting in data_list is plotted onto the same figure.
262 |     By default the label comes from the 'Condition1' column.
263 |     (Unlike causal_plot, this keeps the curve order fixed; causal_plot ignores order, so its curve order can differ each time.)
264 |     If value_list contains only one value, that value is used for all experiments.
265 | """
266 | if data_list is not None:
267 | data_list = data_list
268 | else:
269 | data_list = select_data_list_for_plot(exp2dataset, exp_to_plot)
270 |
271 | n_curves = len(data_list)
272 | if len(value_list) == 1:
273 | value_list = [value_list[0] for _ in range(n_curves)]
274 | default_colors = ['tab:blue','tab:orange','tab:green','tab:red',
275 | 'tab:purple','tab:brown','tab:pink','tab:grey','tab:olive','tab:cyan',]
276 | color_list = default_colors if color_list is None else color_list
277 | condition = 'Condition2' if count else 'Condition1'
278 | estimator = getattr(np, estimator) # choose what to show on main curve: mean? max? min?
279 | plt.figure(figsize=figsize) if figsize else plt.figure()
280 | ##########################
281 | value_list_smooth_temp = []
282 | """
283 |     Smooth data with a moving window average.
284 |     That is,
285 |         smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
286 |     where the "smooth" param is the width of that window (2k+1).
287 |     Note: modifying the data in place can cause a plotting bug where the same column
288 |     is smoothed multiple times, so the smoothed values go into a temporary column instead.
289 | """
290 | y = np.ones(smooth)
291 | for i, data_seeds in enumerate(data_list):
292 | temp_value_name = value_list[i] + '__smooth_temp'
293 | value_list_smooth_temp.append(temp_value_name)
294 | for data in data_seeds:
295 | x = np.asarray(data[value_list[i]].copy())
296 | z = np.ones(len(x))
297 | smoothed_x = np.convolve(x, y, 'same') / np.convolve(z, y, 'same')
298 | # data[value_list[i]] = smoothed_x # this can be problematic
299 | if temp_value_name not in data:
300 | data.insert(len(data.columns), temp_value_name, smoothed_x)
301 | else:
302 | data[temp_value_name] = smoothed_x
303 |
304 | sns.set(style="darkgrid", font_scale=font_scale)
305 | # sns.set_palette('bright')
306 | # have the same axis (figure), plot one by one onto it
307 | ax = None
308 | if version.parse(sns.__version__) <= version.parse('0.8.1'):
309 | for i, data_seeds in enumerate(data_list):
310 | data_combined = pd.concat(data_seeds, ignore_index=True)
311 | ax = sns.tsplot(data=data_combined, time=xaxis, value=value_list_smooth_temp[i], unit="Unit",
312 | condition=condition,
313 | legend=(not no_legend), ci='sd',
314 | n_boot=0, color=color_list[i], ax=ax)
315 | else:
316 | print("Error: Seaborn version > 0.8.1 is currently not supported.")
317 | quit()
318 |
319 | if linestyle_list is not None:
320 | for i in range(len(linestyle_list)):
321 | ax.lines[i].set_linestyle(linestyle_list[i])
322 | if label_list is not None:
323 | for i in range(len(label_list)):
324 | ax.lines[i].set_label(label_list[i])
325 | xlabel = 'environment interactions' if xlabel is None else xlabel
326 |
327 | if ylabel is None:
328 | ylabel = 'average test return'
329 | elif ylabel == 'auto':
330 | if value_list[0] in y2ylabel_dict:
331 | ylabel = y2ylabel_dict[value_list[0]]
332 | else:
333 | ylabel = value_list[0]
334 | else:
335 | ylabel = ylabel
336 | plt.xlabel(xlabel, fontsize=label_font_size)
337 | plt.ylabel(ylabel, fontsize=label_font_size)
338 | if not no_legend:
339 | plt.legend(loc=legend_loc, fontsize=label_font_size)
340 |
341 | """
342 | For the version of the legend used in the Spinning Up benchmarking page,
343 | plt.legend(loc='upper center', ncol=6, handlelength=1,
344 | mode="expand", borderaxespad=0., prop={'size': 13})
345 | """
346 | xscale = np.max(np.asarray(data[xaxis])) > 5e3
347 | if xscale:
348 | # Just some formatting niceness: x-axis scale in scientific notation if max x is large
349 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
350 | if y_log_scale:
351 | plt.yscale('log')
352 | if xlimit > 0:
353 | plt.xlim(0, xlimit)
354 | if y_limit:
355 | plt.ylim(y_limit[0], y_limit[1])
356 | if title:
357 | plt.title(title)
358 | plt.tight_layout()
359 | if save_name is not None:
360 | fig = plt.gcf()
361 | if not os.path.isdir(save_path):
362 | os.mkdir(save_path)
363 | fig.savefig(os.path.join(save_path, save_name))
364 | plt.close(fig)
365 | else:
366 | plt.show()
367 |
368 |
369 | def decide_no_legend(legend_y_types, legend_es, y_type, e):
370 | # return no_legend as a bool (=True means no legend will be plotted)
371 | if (legend_y_types == 'all' or y_type in legend_y_types) and (legend_es == 'all' or e in legend_es):
372 | return False
373 | else:
374 | return True
375 |
376 | def plot_grid(save_path, save_name_prefix, exp2dataset, exp_base_to_plot, envs, y_types, smooth, figsize, label_list=None,
377 | legend_y_types='all', legend_es='all', overriding_ylimit_dict=None, legend_loc='best', longxaxis=False,
378 | linestyle_list=None, color_list=None):
379 | """
380 | this function will help us do generic redq plots very easily and fast.
381 | Args:
382 |         save_path: where to save the figures (e.g. 'figures')
383 |         envs: the envs that will be plotted (e.g. ['ant'])
384 |         exp_base_to_plot: base names of the exp settings (e.g. 'REDQ-n10')
385 |         y_types: y types to be plotted (for example, 'Performance')
386 |         smooth: width of the moving-average smoothing window
387 | """
388 | for y_i, y_type in enumerate(y_types):
389 | y_save_name = y2savename_dict[y_type]
390 | for e in envs:
391 | no_legend = decide_no_legend(legend_y_types, legend_es, y_type, e)
392 | exp_to_plot = []
393 | for exp_base in exp_base_to_plot:
394 | exp_to_plot.append(exp_base + '-' + e)
395 | xlimit = 125000 if e == 'hopper' else int(3e5)
396 | if longxaxis:
397 | xlimit = int(2e6)
398 | ylimit = get_ylimit_from_env_ytype(e, y_type, overriding_ylimit_dict)
399 | assert len(label_list) == len(exp_base_to_plot)
400 | if save_path:
401 | save_name = save_name_prefix + '-%s-%s' %(e, y_save_name)
402 | else:
403 | save_name = None
404 | if color_list is None:
405 | color_list = get_default_colors(exp_base_to_plot)
406 | plot_from_data(None, exp2dataset, exp_to_plot, figsize, value_list=[y_type,],
407 | ylabel='auto', save_path=save_path, label_list=label_list,
408 | save_name=save_name, smooth=smooth, xlimit=xlimit, y_limit = ylimit,
409 | color_list=color_list, no_legend=no_legend, legend_loc=legend_loc,
410 | linestyle_list=linestyle_list)
411 |
412 | def plot_grid_solid_dashed(save_path, save_name_prefix, exp2dataset, exp_base_to_plot, envs, y_1, y_2, smooth, figsize, label_list=None):
413 | """
414 | this function will help us do generic redq plots very easily and fast.
415 | Args:
416 |         save_path: where to save the figures (e.g. 'figures')
417 |         envs: the envs that will be plotted (e.g. ['ant'])
418 |         exp_base_to_plot: base names of the exp settings (e.g. 'REDQ-n10')
419 |         y_1, y_2: the two y types to plot, drawn as solid (y_1) and dashed (y_2) curves
420 |         smooth: width of the moving-average smoothing window
421 | """
422 | y_save_name = y2savename_dict[y_1] # typically this is Q value, then y2 is bias
423 | for e in envs:
424 | exp_to_plot = []
425 | for exp_base in exp_base_to_plot:
426 | exp_to_plot.append(exp_base + '-' + e)
427 | value_list = []
428 | linestyle_list = []
429 | for i in range(len(exp_to_plot)):
430 | value_list.append(y_1)
431 | linestyle_list.append('solid')
432 | for i in range(len(exp_to_plot)):
433 | value_list.append(y_2)
434 | linestyle_list.append('dashed')
435 | exp_to_plot = exp_to_plot + exp_to_plot
436 | exp_base_to_plot_for_color = exp_base_to_plot + exp_base_to_plot
437 | xlimit = 125000 if e == 'hopper' else int(3e5)
438 | ylimit = get_ylimit_from_env_ytype(e, y_1)
439 | plot_from_data(None, exp2dataset, exp_to_plot, figsize, value_list=value_list, linestyle_list=linestyle_list,
440 | ylabel='auto', save_path=save_path, label_list=label_list,
441 | save_name=save_name_prefix + '-%s-%s' %(e, y_save_name), smooth=smooth, xlimit=xlimit, y_limit = ylimit,
442 | color_list=get_default_colors(exp_base_to_plot_for_color))
443 |
444 | def get_default_colors(exp_base_names):
445 |     # map experiment base names (the readable names) to colors
446 |     # using the expbase2color dictionary above
447 | colors = []
448 | for i in range(len(exp_base_names)):
449 | exp_base_name = exp_base_names[i]
450 | colors.append(expbase2color[exp_base_name])
451 | return colors
452 |
453 | def get_exp2dataset(exp2path, base_path):
454 | """
455 | Args:
456 |         exp2path: a dictionary mapping an experiment name (you decide what this is; make it easy to read,
457 |             it will be set as the "condition") to the actual experiment folder name
458 |         base_path: the base path leading to where the experiment folders are
459 |     Returns:
460 |         a dictionary with keys being the experiment names and values being lists of pandas dataframes,
461 |             one dataframe per seed, each containing that run's progress
462 | """
463 | exp2dataset = {}
464 | for key in exp2path.keys():
465 | complete_path = os.path.join(base_path, exp2path[key])
466 | data_list = get_datasets(complete_path, key) # a list of different seeds
467 | # combined_df_across_seeds = pd.concat(data_list, ignore_index=True)
468 | # exp2dataset[key] = combined_df_across_seeds
469 | for d in data_list:
470 | try:
471 | add_normalized_values(d)
472 | except:
473 | print("can't add normalized loss in data:", key)
474 | exp2dataset[key] = data_list
475 | return exp2dataset
476 |
477 | def add_normalized_values(d):
478 | normalize_base = d['AverageMCDisRetEnt'].copy().abs()
479 | normalize_base[normalize_base < 10] = 10
480 | # use this to normalize the Q loss, Q bias
481 |
482 | normalized_qloss = d['LossQ1'] / normalize_base
483 | d.insert(len(d.columns), 'NormLossQ1', normalized_qloss)
484 | normalized_q_bias = d['AverageQBias'] / normalize_base
485 | d.insert(len(d.columns), 'AllNormalizedAverageQBias', normalized_q_bias)
486 | normalized_std_q_bias = d['StdQBias'] / normalize_base
487 | d.insert(len(d.columns), 'AllNormalizedStdQBias', normalized_std_q_bias)
488 | normalized_bias_mse = d['AverageQBiasSqr'] / normalize_base
489 | d.insert(len(d.columns), 'AllNormalizedAverageQBiasSqr', normalized_bias_mse)
490 |
--------------------------------------------------------------------------------
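
The moving-window smoothing inside plot_from_data is easier to see pulled out on its own. A standalone sketch with made-up values, only to show the edge handling of the convolution trick used above:

import numpy as np

# Standalone version of the smoothing in plot_from_data:
# smoothed_y[t] = average(y[t-k], ..., y[t+k]) with window width smooth = 2k+1.
# Dividing by the convolution of a ones-vector with the window corrects the average
# near the edges, where fewer than `smooth` points fall inside the window.
smooth = 3
window = np.ones(smooth)
x = np.array([0., 1., 2., 3., 4., 5.])  # toy stand-in for one logged column
z = np.ones(len(x))
smoothed_x = np.convolve(x, window, 'same') / np.convolve(z, window, 'same')
print(smoothed_x)  # [0.5 1.  2.  3.  4.  4.5]
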
/redq/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.1'
--------------------------------------------------------------------------------
/redq/algos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watchernyu/REDQ/7b5d1bff39291a57325a2836bd397a55728960bb/redq/algos/__init__.py
--------------------------------------------------------------------------------
/redq/algos/core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from torch.distributions import Distribution, Normal
6 | # following the SAC authors' and OpenAI's implementations
7 | LOG_SIG_MAX = 2
8 | LOG_SIG_MIN = -20
9 | ACTION_BOUND_EPSILON = 1E-6
10 | # these numbers are from the MBPO paper
11 | mbpo_target_entropy_dict = {'Hopper-v2':-1, 'HalfCheetah-v2':-3, 'Walker2d-v2':-3, 'Ant-v2':-4, 'Humanoid-v2':-2,
12 | 'Hopper-v3':-1, 'HalfCheetah-v3':-3, 'Walker2d-v3':-3, 'Ant-v3':-4, 'Humanoid-v3':-2,
13 | 'Hopper-v4':-1, 'HalfCheetah-v4':-3, 'Walker2d-v4':-3, 'Ant-v4':-4, 'Humanoid-v4':-2}
14 | mbpo_epoches = {'Hopper-v2':125, 'Walker2d-v2':300, 'Ant-v2':300, 'HalfCheetah-v2':400, 'Humanoid-v2':300,
15 | 'Hopper-v3':125, 'Walker2d-v3':300, 'Ant-v3':300, 'HalfCheetah-v3':400, 'Humanoid-v3':300,
16 | 'Hopper-v4':125, 'Walker2d-v4':300, 'Ant-v4':300, 'HalfCheetah-v4':400, 'Humanoid-v4':300}
17 |
18 | def weights_init_(m):
19 | # weight init helper function
20 | if isinstance(m, nn.Linear):
21 | torch.nn.init.xavier_uniform_(m.weight, gain=1)
22 | torch.nn.init.constant_(m.bias, 0)
23 |
24 | class ReplayBuffer:
25 | """
26 | A simple FIFO experience replay buffer
27 | """
28 | def __init__(self, obs_dim, act_dim, size):
29 | """
30 | :param obs_dim: size of observation
31 | :param act_dim: size of the action
32 | :param size: size of the buffer
33 | """
34 | ## init buffers as numpy arrays
35 | self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
36 | self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
37 | self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
38 | self.rews_buf = np.zeros(size, dtype=np.float32)
39 | self.done_buf = np.zeros(size, dtype=np.float32)
40 | self.ptr, self.size, self.max_size = 0, 0, size
41 |
42 | def store(self, obs, act, rew, next_obs, done):
43 | """
44 | data will get stored in the pointer's location
45 | """
46 | self.obs1_buf[self.ptr] = obs
47 | self.obs2_buf[self.ptr] = next_obs
48 | self.acts_buf[self.ptr] = act
49 | self.rews_buf[self.ptr] = rew
50 | self.done_buf[self.ptr] = done
51 | ## move the pointer to store in next location in buffer
52 | self.ptr = (self.ptr+1) % self.max_size
53 | ## keep track of the current buffer size
54 | self.size = min(self.size+1, self.max_size)
55 |
56 | def sample_batch(self, batch_size=32, idxs=None):
57 | """
58 | :param batch_size: size of minibatch
59 | :param idxs: specify indexes if you want specific data points
60 | :return: mini-batch data as a dictionary
61 | """
62 | if idxs is None:
63 | idxs = np.random.randint(0, self.size, size=batch_size)
64 | return dict(obs1=self.obs1_buf[idxs],
65 | obs2=self.obs2_buf[idxs],
66 | acts=self.acts_buf[idxs],
67 | rews=self.rews_buf[idxs],
68 | done=self.done_buf[idxs],
69 | idxs=idxs)
70 |
71 |
72 | class Mlp(nn.Module):
73 | def __init__(
74 | self,
75 | input_size,
76 | output_size,
77 | hidden_sizes,
78 | hidden_activation=F.relu
79 | ):
80 | super().__init__()
81 |
82 | self.input_size = input_size
83 | self.output_size = output_size
84 | self.hidden_activation = hidden_activation
85 | ## here we use ModuleList so that the layers in it can be
86 | ## detected by .parameters() call
87 | self.hidden_layers = nn.ModuleList()
88 | in_size = input_size
89 |
90 | ## initialize each hidden layer
91 | for i, next_size in enumerate(hidden_sizes):
92 | fc_layer = nn.Linear(in_size, next_size)
93 | in_size = next_size
94 | self.hidden_layers.append(fc_layer)
95 |
96 |         ## the last fully connected output layer (all layers are initialized below via weights_init_)
97 | self.last_fc_layer = nn.Linear(in_size, output_size)
98 | self.apply(weights_init_)
99 |
100 | def forward(self, input):
101 | h = input
102 | for i, fc_layer in enumerate(self.hidden_layers):
103 | h = fc_layer(h)
104 | h = self.hidden_activation(h)
105 | output = self.last_fc_layer(h)
106 | return output
107 |
108 | class TanhNormal(Distribution):
109 | """
110 | Represent distribution of X where
111 | X ~ tanh(Z)
112 | Z ~ N(mean, std)
113 | Note: this is not very numerically stable.
114 | """
115 | def __init__(self, normal_mean, normal_std, epsilon=1e-6):
116 | """
117 | :param normal_mean: Mean of the normal distribution
118 | :param normal_std: Std of the normal distribution
119 | :param epsilon: Numerical stability epsilon when computing log-prob.
120 | """
121 | self.normal_mean = normal_mean
122 | self.normal_std = normal_std
123 | self.normal = Normal(normal_mean, normal_std)
124 | self.epsilon = epsilon
125 |
126 | def log_prob(self, value, pre_tanh_value=None):
127 | """
128 | return the log probability of a value
129 | :param value: some value, x
130 | :param pre_tanh_value: arctanh(x)
131 |         :return: the log probability of the value under this tanh-squashed Gaussian
132 | """
133 | # use arctanh formula to compute arctanh(value)
134 | if pre_tanh_value is None:
135 | pre_tanh_value = torch.log(
136 | (1+value) / (1-value)
137 | ) / 2
138 | return self.normal.log_prob(pre_tanh_value) - \
139 | torch.log(1 - value * value + self.epsilon)
140 |
141 | def sample(self, return_pretanh_value=False):
142 | """
143 | Gradients will and should *not* pass through this operation.
144 | See https://github.com/pytorch/pytorch/issues/4620 for discussion.
145 | """
146 | z = self.normal.sample().detach()
147 |
148 | if return_pretanh_value:
149 | return torch.tanh(z), z
150 | else:
151 | return torch.tanh(z)
152 |
153 | def rsample(self, return_pretanh_value=False):
154 | """
155 |         Sampling with the reparameterization trick.
156 |         Implements: tanh(mu + sigma * xi)
157 |         with xi ~ N(0,1),
158 |         so z here is mu + sigma * xi
159 | """
160 | z = (
161 | self.normal_mean +
162 | self.normal_std *
163 |             Normal( ## this part is xi ~ N(0,1)
164 | torch.zeros(self.normal_mean.size()),
165 | torch.ones(self.normal_std.size())
166 | ).sample()
167 | )
168 | if return_pretanh_value:
169 | return torch.tanh(z), z
170 | else:
171 | return torch.tanh(z)
172 |
173 | class TanhGaussianPolicy(Mlp):
174 | """
175 | A Gaussian policy network with Tanh to enforce action limits
176 | """
177 | def __init__(
178 | self,
179 | obs_dim,
180 | action_dim,
181 | hidden_sizes,
182 | hidden_activation=F.relu,
183 | action_limit=1.0
184 | ):
185 | super().__init__(
186 | input_size=obs_dim,
187 | output_size=action_dim,
188 | hidden_sizes=hidden_sizes,
189 | hidden_activation=hidden_activation,
190 | )
191 | last_hidden_size = obs_dim
192 | if len(hidden_sizes) > 0:
193 | last_hidden_size = hidden_sizes[-1]
194 |         ## this is the layer that gives log_std (initialized below via weights_init_ like the other layers)
195 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
196 | ## action limit: for example, humanoid has an action limit of -0.4 to 0.4
197 | self.action_limit = action_limit
198 | self.apply(weights_init_)
199 |
200 | def forward(
201 | self,
202 | obs,
203 | deterministic=False,
204 | return_log_prob=True,
205 | ):
206 | """
207 | :param obs: Observation
208 |         :param deterministic: If True, take the deterministic (test) action
209 |             (otherwise a stochastic action is sampled via the reparameterization trick)
210 | :param return_log_prob: If True, return a sample and its log probability
211 | """
212 | h = obs
213 | for fc_layer in self.hidden_layers:
214 | h = self.hidden_activation(fc_layer(h))
215 | mean = self.last_fc_layer(h)
216 |
217 | log_std = self.last_fc_log_std(h)
218 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
219 | std = torch.exp(log_std)
220 |
221 | normal = Normal(mean, std)
222 |
223 | if deterministic:
224 | pre_tanh_value = mean
225 | action = torch.tanh(mean)
226 | else:
227 | pre_tanh_value = normal.rsample()
228 | action = torch.tanh(pre_tanh_value)
229 |
230 | if return_log_prob:
231 | log_prob = normal.log_prob(pre_tanh_value)
232 | log_prob -= torch.log(1 - action.pow(2) + ACTION_BOUND_EPSILON)
233 | log_prob = log_prob.sum(1, keepdim=True)
234 | else:
235 | log_prob = None
236 |
237 | return (
238 | action * self.action_limit, mean, log_std, log_prob, std, pre_tanh_value,
239 | )
240 |
241 | def soft_update_model1_with_model2(model1, model2, rou):
242 | """
243 | used to polyak update a target network
244 | :param model1: a pytorch model
245 | :param model2: a pytorch model of the same class
246 |     :param rou: the polyak coefficient; the update is model1 <- rou*model1 + (1-rou)*model2
247 | """
248 | for model1_param, model2_param in zip(model1.parameters(), model2.parameters()):
249 | model1_param.data.copy_(rou*model1_param.data + (1-rou)*model2_param.data)
250 |
251 | def test_agent(agent, test_env, max_ep_len, logger, n_eval=1):
252 | """
253 | This will test the agent's performance by running episodes
254 | During the runs, the agent should only take deterministic action
255 |     This function assumes the agent has a get_test_action(obs) method
256 | :param agent: agent instance
257 | :param test_env: the environment used for testing
258 | :param max_ep_len: max length of an episode
259 | :param logger: logger to store info in
260 | :param n_eval: number of episodes to run the agent
261 | :return: test return for each episode as a numpy array
262 | """
263 | ep_return_list = np.zeros(n_eval)
264 | for j in range(n_eval):
265 | o, r, d, ep_ret, ep_len = test_env.reset(), 0, False, 0, 0
266 | while not (d or (ep_len == max_ep_len)):
267 | # Take deterministic actions at test time
268 | a = agent.get_test_action(o)
269 | o, r, d, _ = test_env.step(a)
270 | ep_ret += r
271 | ep_len += 1
272 | ep_return_list[j] = ep_ret
273 | if logger is not None:
274 | logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
275 | return ep_return_list
276 |
--------------------------------------------------------------------------------
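
TanhGaussianPolicy.forward and TanhNormal.log_prob above apply the same change-of-variables correction for the tanh squashing. A minimal sketch of that correction in isolation, with made-up mean and std; the 1e-6 term plays the role of ACTION_BOUND_EPSILON from core.py:

import torch
from torch.distributions import Normal

# For a = tanh(z) with z ~ N(mu, sigma), the change of variables gives
#     log p(a) = log N(z; mu, sigma) - log(1 - tanh(z)^2)
# The small epsilon only guards against log(0) when |a| approaches 1.
mu, sigma = torch.tensor(0.3), torch.tensor(0.5)
normal = Normal(mu, sigma)
z = normal.rsample()   # reparameterized sample, as in the policy's forward pass
a = torch.tanh(z)
log_prob_a = normal.log_prob(z) - torch.log(1 - a.pow(2) + 1e-6)
print(float(a), float(log_prob_a))
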
/redq/algos/redq_sac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import Tensor
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from redq.algos.core import TanhGaussianPolicy, Mlp, soft_update_model1_with_model2, ReplayBuffer,\
7 | mbpo_target_entropy_dict
8 |
9 | def get_probabilistic_num_min(num_mins):
10 |     # allow the number of min Q values (num_min) to be a non-integer by sampling probabilistically
11 | floored_num_mins = np.floor(num_mins)
12 | if num_mins - floored_num_mins > 0.001:
13 | prob_for_higher_value = num_mins - floored_num_mins
14 | if np.random.uniform(0, 1) < prob_for_higher_value:
15 | return int(floored_num_mins+1)
16 | else:
17 | return int(floored_num_mins)
18 | else:
19 | return num_mins
20 |
21 | class REDQSACAgent(object):
22 | """
23 | Naive SAC: num_Q = 2, num_min = 2
24 | REDQ: num_Q > 2, num_min = 2
25 |     MaxMin: num_min = num_Q
26 |     For the above three variants, set q_target_mode to 'min' (default)
27 | Ensemble Average: set q_target_mode to 'ave'
28 | REM: set q_target_mode to 'rem'
29 | """
30 | def __init__(self, env_name, obs_dim, act_dim, act_limit, device,
31 | hidden_sizes=(256, 256), replay_size=int(1e6), batch_size=256,
32 | lr=3e-4, gamma=0.99, polyak=0.995,
33 | alpha=0.2, auto_alpha=True, target_entropy='mbpo',
34 | start_steps=5000, delay_update_steps='auto',
35 | utd_ratio=20, num_Q=10, num_min=2, q_target_mode='min',
36 | policy_update_delay=20,
37 | ):
38 | # set up networks
39 | self.policy_net = TanhGaussianPolicy(obs_dim, act_dim, hidden_sizes, action_limit=act_limit).to(device)
40 | self.q_net_list, self.q_target_net_list = [], []
41 | for q_i in range(num_Q):
42 | new_q_net = Mlp(obs_dim + act_dim, 1, hidden_sizes).to(device)
43 | self.q_net_list.append(new_q_net)
44 | new_q_target_net = Mlp(obs_dim + act_dim, 1, hidden_sizes).to(device)
45 | new_q_target_net.load_state_dict(new_q_net.state_dict())
46 | self.q_target_net_list.append(new_q_target_net)
47 | # set up optimizers
48 | self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
49 | self.q_optimizer_list = []
50 | for q_i in range(num_Q):
51 | self.q_optimizer_list.append(optim.Adam(self.q_net_list[q_i].parameters(), lr=lr))
52 | # set up adaptive entropy (SAC adaptive)
53 | self.auto_alpha = auto_alpha
54 | if auto_alpha:
55 | if target_entropy == 'auto':
56 | self.target_entropy = - act_dim
57 | if target_entropy == 'mbpo':
58 | self.target_entropy = mbpo_target_entropy_dict[env_name]
59 | self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
60 | self.alpha_optim = optim.Adam([self.log_alpha], lr=lr)
61 | self.alpha = self.log_alpha.cpu().exp().item()
62 | else:
63 | self.alpha = alpha
64 | self.target_entropy, self.log_alpha, self.alpha_optim = None, None, None
65 | # set up replay buffer
66 | self.replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
67 | # set up other things
68 | self.mse_criterion = nn.MSELoss()
69 |
70 | # store other hyperparameters
71 | self.start_steps = start_steps
72 | self.obs_dim = obs_dim
73 | self.act_dim = act_dim
74 | self.act_limit = act_limit
75 | self.lr = lr
76 | self.hidden_sizes = hidden_sizes
77 | self.gamma = gamma
78 | self.polyak = polyak
79 | self.replay_size = replay_size
80 | self.alpha = alpha
81 | self.batch_size = batch_size
82 | self.num_min = num_min
83 | self.num_Q = num_Q
84 | self.utd_ratio = utd_ratio
85 | self.delay_update_steps = self.start_steps if delay_update_steps == 'auto' else delay_update_steps
86 | self.q_target_mode = q_target_mode
87 | self.policy_update_delay = policy_update_delay
88 | self.device = device
89 |
90 | def __get_current_num_data(self):
91 | # used to determine whether we should get action from policy or take random starting actions
92 | return self.replay_buffer.size
93 |
94 | def get_exploration_action(self, obs, env):
95 | # given an observation, output a sampled action in numpy form
96 | with torch.no_grad():
97 | if self.__get_current_num_data() > self.start_steps:
98 | obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device)
99 | action_tensor = self.policy_net.forward(obs_tensor, deterministic=False,
100 | return_log_prob=False)[0]
101 | action = action_tensor.cpu().numpy().reshape(-1)
102 | else:
103 | action = env.action_space.sample()
104 | return action
105 |
106 | def get_test_action(self, obs):
107 | # given an observation, output a deterministic action in numpy form
108 | with torch.no_grad():
109 | obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device)
110 | action_tensor = self.policy_net.forward(obs_tensor, deterministic=True,
111 | return_log_prob=False)[0]
112 | action = action_tensor.cpu().numpy().reshape(-1)
113 | return action
114 |
115 |     def get_action_and_logprob_for_bias_evaluation(self, obs):
116 |         # given an observation, output a sampled action (in numpy form) and its log-probability
117 | with torch.no_grad():
118 | obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device)
119 | action_tensor, _, _, log_prob_a_tilda, _, _, = self.policy_net.forward(obs_tensor, deterministic=False,
120 | return_log_prob=True)
121 | action = action_tensor.cpu().numpy().reshape(-1)
122 | return action, log_prob_a_tilda
123 |
124 | def get_ave_q_prediction_for_bias_evaluation(self, obs_tensor, acts_tensor):
125 |         # given obs_tensor and acts_tensor, output the average Q prediction over the ensemble
126 | q_prediction_list = []
127 | for q_i in range(self.num_Q):
128 | q_prediction = self.q_net_list[q_i](torch.cat([obs_tensor, acts_tensor], 1))
129 | q_prediction_list.append(q_prediction)
130 | q_prediction_cat = torch.cat(q_prediction_list, dim=1)
131 | average_q_prediction = torch.mean(q_prediction_cat, dim=1)
132 | return average_q_prediction
133 |
134 | def store_data(self, o, a, r, o2, d):
135 | # store one transition to the buffer
136 | self.replay_buffer.store(o, a, r, o2, d)
137 |
138 | def sample_data(self, batch_size):
139 | # sample data from replay buffer
140 | batch = self.replay_buffer.sample_batch(batch_size)
141 | obs_tensor = Tensor(batch['obs1']).to(self.device)
142 | obs_next_tensor = Tensor(batch['obs2']).to(self.device)
143 | acts_tensor = Tensor(batch['acts']).to(self.device)
144 | rews_tensor = Tensor(batch['rews']).unsqueeze(1).to(self.device)
145 | done_tensor = Tensor(batch['done']).unsqueeze(1).to(self.device)
146 | return obs_tensor, obs_next_tensor, acts_tensor, rews_tensor, done_tensor
147 |
148 | def get_redq_q_target_no_grad(self, obs_next_tensor, rews_tensor, done_tensor):
149 | # compute REDQ Q target, depending on the agent's Q target mode
150 |         # allow num_min to be a float:
151 | num_mins_to_use = get_probabilistic_num_min(self.num_min)
152 | sample_idxs = np.random.choice(self.num_Q, num_mins_to_use, replace=False)
153 | with torch.no_grad():
154 | if self.q_target_mode == 'min':
155 | """Q target is min of a subset of Q values"""
156 | a_tilda_next, _, _, log_prob_a_tilda_next, _, _ = self.policy_net.forward(obs_next_tensor)
157 | q_prediction_next_list = []
158 | for sample_idx in sample_idxs:
159 | q_prediction_next = self.q_target_net_list[sample_idx](torch.cat([obs_next_tensor, a_tilda_next], 1))
160 | q_prediction_next_list.append(q_prediction_next)
161 | q_prediction_next_cat = torch.cat(q_prediction_next_list, 1)
162 | min_q, min_indices = torch.min(q_prediction_next_cat, dim=1, keepdim=True)
163 | next_q_with_log_prob = min_q - self.alpha * log_prob_a_tilda_next
164 | y_q = rews_tensor + self.gamma * (1 - done_tensor) * next_q_with_log_prob
165 | if self.q_target_mode == 'ave':
166 | """Q target is average of all Q values"""
167 | a_tilda_next, _, _, log_prob_a_tilda_next, _, _ = self.policy_net.forward(obs_next_tensor)
168 | q_prediction_next_list = []
169 | for q_i in range(self.num_Q):
170 | q_prediction_next = self.q_target_net_list[q_i](torch.cat([obs_next_tensor, a_tilda_next], 1))
171 | q_prediction_next_list.append(q_prediction_next)
172 | q_prediction_next_ave = torch.cat(q_prediction_next_list, 1).mean(dim=1).reshape(-1, 1)
173 | next_q_with_log_prob = q_prediction_next_ave - self.alpha * log_prob_a_tilda_next
174 | y_q = rews_tensor + self.gamma * (1 - done_tensor) * next_q_with_log_prob
175 | if self.q_target_mode == 'rem':
176 | """Q target is random ensemble mixture of Q values"""
177 | a_tilda_next, _, _, log_prob_a_tilda_next, _, _ = self.policy_net.forward(obs_next_tensor)
178 | q_prediction_next_list = []
179 | for q_i in range(self.num_Q):
180 | q_prediction_next = self.q_target_net_list[q_i](torch.cat([obs_next_tensor, a_tilda_next], 1))
181 | q_prediction_next_list.append(q_prediction_next)
182 | # apply rem here
183 | q_prediction_next_cat = torch.cat(q_prediction_next_list, 1)
184 | rem_weight = Tensor(np.random.uniform(0, 1, q_prediction_next_cat.shape)).to(device=self.device)
185 | normalize_sum = rem_weight.sum(1).reshape(-1, 1).expand(-1, self.num_Q)
186 | rem_weight = rem_weight / normalize_sum
187 | q_prediction_next_rem = (q_prediction_next_cat * rem_weight).sum(dim=1).reshape(-1, 1)
188 | next_q_with_log_prob = q_prediction_next_rem - self.alpha * log_prob_a_tilda_next
189 | y_q = rews_tensor + self.gamma * (1 - done_tensor) * next_q_with_log_prob
190 | return y_q, sample_idxs
191 |
192 | def train(self, logger):
193 |         # this function is called after each datapoint is collected.
194 |         # while we only have very limited data, we don't make any updates
195 | num_update = 0 if self.__get_current_num_data() <= self.delay_update_steps else self.utd_ratio
196 | for i_update in range(num_update):
197 | obs_tensor, obs_next_tensor, acts_tensor, rews_tensor, done_tensor = self.sample_data(self.batch_size)
198 |
199 | """Q loss"""
200 | y_q, sample_idxs = self.get_redq_q_target_no_grad(obs_next_tensor, rews_tensor, done_tensor)
201 | q_prediction_list = []
202 | for q_i in range(self.num_Q):
203 | q_prediction = self.q_net_list[q_i](torch.cat([obs_tensor, acts_tensor], 1))
204 | q_prediction_list.append(q_prediction)
205 | q_prediction_cat = torch.cat(q_prediction_list, dim=1)
206 | y_q = y_q.expand((-1, self.num_Q)) if y_q.shape[1] == 1 else y_q
207 | q_loss_all = self.mse_criterion(q_prediction_cat, y_q) * self.num_Q
208 |
209 | for q_i in range(self.num_Q):
210 | self.q_optimizer_list[q_i].zero_grad()
211 | q_loss_all.backward()
212 |
213 | """policy and alpha loss"""
214 | if ((i_update + 1) % self.policy_update_delay == 0) or i_update == num_update - 1:
215 | # get policy loss
216 | a_tilda, mean_a_tilda, log_std_a_tilda, log_prob_a_tilda, _, pretanh = self.policy_net.forward(obs_tensor)
217 | q_a_tilda_list = []
218 | for sample_idx in range(self.num_Q):
219 | self.q_net_list[sample_idx].requires_grad_(False)
220 | q_a_tilda = self.q_net_list[sample_idx](torch.cat([obs_tensor, a_tilda], 1))
221 | q_a_tilda_list.append(q_a_tilda)
222 | q_a_tilda_cat = torch.cat(q_a_tilda_list, 1)
223 | ave_q = torch.mean(q_a_tilda_cat, dim=1, keepdim=True)
224 | policy_loss = (self.alpha * log_prob_a_tilda - ave_q).mean()
225 | self.policy_optimizer.zero_grad()
226 | policy_loss.backward()
227 | for sample_idx in range(self.num_Q):
228 | self.q_net_list[sample_idx].requires_grad_(True)
229 |
230 | # get alpha loss
231 | if self.auto_alpha:
232 | alpha_loss = -(self.log_alpha * (log_prob_a_tilda + self.target_entropy).detach()).mean()
233 | self.alpha_optim.zero_grad()
234 | alpha_loss.backward()
235 | self.alpha_optim.step()
236 | self.alpha = self.log_alpha.cpu().exp().item()
237 | else:
238 | alpha_loss = Tensor([0])
239 |
240 | """update networks"""
241 | for q_i in range(self.num_Q):
242 | self.q_optimizer_list[q_i].step()
243 |
244 | if ((i_update + 1) % self.policy_update_delay == 0) or i_update == num_update - 1:
245 | self.policy_optimizer.step()
246 |
247 | # polyak averaged Q target networks
248 | for q_i in range(self.num_Q):
249 | soft_update_model1_with_model2(self.q_target_net_list[q_i], self.q_net_list[q_i], self.polyak)
250 |
251 |             # by default, only log for the last update of this round of updates
252 | if i_update == num_update - 1:
253 | logger.store(LossPi=policy_loss.cpu().item(), LossQ1=q_loss_all.cpu().item() / self.num_Q,
254 | LossAlpha=alpha_loss.cpu().item(), Q1Vals=q_prediction.detach().cpu().numpy(),
255 | Alpha=self.alpha, LogPi=log_prob_a_tilda.detach().cpu().numpy(),
256 | PreTanh=pretanh.abs().detach().cpu().numpy().reshape(-1))
257 |
258 | # if there is no update, log 0 to prevent logging problems
259 | if num_update == 0:
260 | logger.store(LossPi=0, LossQ1=0, LossAlpha=0, Q1Vals=0, Alpha=0, LogPi=0, PreTanh=0)
261 |
262 |
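The docstring at the top of REDQSACAgent maps each variant (naive SAC, REDQ, MaxMin, ensemble average, REM) to a hyperparameter setting, and train() is meant to be called once per collected transition. Below is a minimal interaction-loop sketch, not the repository's training script (the real entry point is experiments/train_redq_sac.py); the environment name, dimensions, output directory, and loop length are illustrative, and the default target_entropy='mbpo' assumes the env name appears in mbpo_target_entropy_dict.

    # minimal sketch, assuming a MuJoCo gym environment is installed
    import gym
    import torch
    from redq.algos.redq_sac import REDQSACAgent
    from redq.utils.logx import EpochLogger

    env = gym.make('Hopper-v2')
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    act_limit = env.action_space.high[0]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # REDQ defaults: num_Q=10, num_min=2, utd_ratio=20, q_target_mode='min'.
    # Naive SAC would instead use num_Q=2, utd_ratio=1, policy_update_delay=1;
    # MaxMin uses num_min=num_Q; 'ave' / 'rem' switch q_target_mode.
    agent = REDQSACAgent('Hopper-v2', obs_dim, act_dim, act_limit, device)

    logger = EpochLogger(output_dir='/tmp/experiments/redq_demo', exp_name='redq_demo')
    o, ep_len, max_ep_len = env.reset(), 0, 1000
    for t in range(10000):
        a = agent.get_exploration_action(o, env)   # random until start_steps transitions collected
        o2, r, d, _ = env.step(a)
        ep_len += 1
        d = False if ep_len == max_ep_len else d   # don't treat time-outs as true terminals
        agent.store_data(o, a, r, o2, d)
        agent.train(logger)                        # performs utd_ratio updates per environment step
        o = o2
        if d or ep_len == max_ep_len:
            o, ep_len = env.reset(), 0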
--------------------------------------------------------------------------------
/redq/user_config.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from OpenAI spinup code
3 | """
4 | import os.path as osp
5 |
6 | # Where experiment outputs are saved by default:
7 | DEFAULT_DATA_DIR = osp.join(osp.abspath(osp.dirname(osp.dirname(__file__))),'data')
8 |
9 | # Whether to automatically insert a date and time stamp into the names of
10 | # save directories:
11 | FORCE_DATESTAMP = False
12 |
13 | # Whether GridSearch provides automatically-generated default shorthands:
14 | DEFAULT_SHORTHAND = True
15 |
16 | # Tells the GridSearch how many seconds to pause for before launching
17 | # experiments.
18 | WAIT_BEFORE_LAUNCH = 5
--------------------------------------------------------------------------------
/redq/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watchernyu/REDQ/7b5d1bff39291a57325a2836bd397a55728960bb/redq/utils/__init__.py
--------------------------------------------------------------------------------
/redq/utils/bias_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import Tensor
4 |
5 | def get_mc_return_with_entropy_on_reset(bias_eval_env, agent, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff):
6 |     # to evaluate Q bias we also need to record the observations and actions visited
7 | final_mc_list = np.zeros(0)
8 | final_mc_entropy_list = np.zeros(0)
9 | final_obs_list = []
10 | final_act_list = []
11 | while final_mc_list.shape[0] < n_mc_eval:
12 |         # keep collecting episodes until we have enough data points
13 | o = bias_eval_env.reset()
14 | # temporary lists
15 | reward_list, log_prob_a_tilda_list, obs_list, act_list = [], [], [], []
16 | r, d, ep_ret, ep_len = 0, False, 0, 0
17 | discounted_return = 0
18 | discounted_return_with_entropy = 0
19 | for i_step in range(max_ep_len): # run an episode
20 | with torch.no_grad():
21 | a, log_prob_a_tilda = agent.get_action_and_logprob_for_bias_evaluation(o)
22 | obs_list.append(o)
23 | act_list.append(a)
24 | o, r, d, _ = bias_eval_env.step(a)
25 | ep_ret += r
26 | ep_len += 1
27 | reward_list.append(r)
28 | log_prob_a_tilda_list.append(log_prob_a_tilda.item())
29 | if d or (ep_len == max_ep_len):
30 | break
31 | discounted_return_list = np.zeros(ep_len)
32 | discounted_return_with_entropy_list = np.zeros(ep_len)
33 | for i_step in range(ep_len - 1, -1, -1):
34 | # backwards compute discounted return and with entropy for all s-a visited
35 | if i_step == ep_len - 1:
36 | discounted_return_list[i_step] = reward_list[i_step]
37 | discounted_return_with_entropy_list[i_step] = reward_list[i_step]
38 | else:
39 | discounted_return_list[i_step] = reward_list[i_step] + gamma * discounted_return_list[i_step + 1]
40 | discounted_return_with_entropy_list[i_step] = reward_list[i_step] + \
41 | gamma * (discounted_return_with_entropy_list[i_step + 1] - alpha * log_prob_a_tilda_list[i_step + 1])
42 | # now we take the first few of these.
43 | final_mc_list = np.concatenate((final_mc_list, discounted_return_list[:n_mc_cutoff]))
44 | final_mc_entropy_list = np.concatenate(
45 | (final_mc_entropy_list, discounted_return_with_entropy_list[:n_mc_cutoff]))
46 | final_obs_list += obs_list[:n_mc_cutoff]
47 | final_act_list += act_list[:n_mc_cutoff]
48 | return final_mc_list, final_mc_entropy_list, np.array(final_obs_list), np.array(final_act_list)
49 |
50 | def log_bias_evaluation(bias_eval_env, agent, logger, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff):
51 | final_mc_list, final_mc_entropy_list, final_obs_list, final_act_list = get_mc_return_with_entropy_on_reset(bias_eval_env, agent, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff)
52 | logger.store(MCDisRet=final_mc_list)
53 | logger.store(MCDisRetEnt=final_mc_entropy_list)
54 | obs_tensor = Tensor(final_obs_list).to(agent.device)
55 | acts_tensor = Tensor(final_act_list).to(agent.device)
56 | with torch.no_grad():
57 | q_prediction = agent.get_ave_q_prediction_for_bias_evaluation(obs_tensor, acts_tensor).cpu().numpy().reshape(-1)
58 | bias = q_prediction - final_mc_entropy_list
59 | bias_abs = np.abs(bias)
60 | bias_squared = bias ** 2
61 | logger.store(QPred=q_prediction)
62 | logger.store(QBias=bias)
63 | logger.store(QBiasAbs=bias_abs)
64 | logger.store(QBiasSqr=bias_squared)
65 | final_mc_entropy_list_normalize_base = final_mc_entropy_list.copy()
66 | final_mc_entropy_list_normalize_base = np.abs(final_mc_entropy_list_normalize_base)
67 | final_mc_entropy_list_normalize_base[final_mc_entropy_list_normalize_base < 10] = 10
68 | normalized_bias_per_state = bias / final_mc_entropy_list_normalize_base
69 | logger.store(NormQBias=normalized_bias_per_state)
70 | normalized_bias_sqr_per_state = bias_squared / final_mc_entropy_list_normalize_base
71 | logger.store(NormQBiasSqr=normalized_bias_sqr_per_state)
72 |
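The backward pass above implements the recursion G_t = r_t + gamma * G_{t+1} and, for the entropy-augmented return, G'_t = r_t + gamma * (G'_{t+1} - alpha * log pi(a_{t+1} | s_{t+1})). A tiny self-contained sketch on made-up numbers, just to illustrate the bookkeeping:

    import numpy as np

    # toy 4-step episode; rewards and log-probs are invented for illustration
    rewards = [1.0, 0.5, 0.0, 2.0]
    log_probs = [-1.2, -0.9, -1.0, -1.1]
    gamma, alpha = 0.99, 0.2
    ep_len = len(rewards)

    ret = np.zeros(ep_len)      # plain discounted return
    ret_ent = np.zeros(ep_len)  # discounted return with entropy bonus
    for t in range(ep_len - 1, -1, -1):
        if t == ep_len - 1:
            ret[t] = rewards[t]
            ret_ent[t] = rewards[t]
        else:
            ret[t] = rewards[t] + gamma * ret[t + 1]
            # the entropy bonus -alpha * log pi(a_{t+1}|s_{t+1}) is folded into the bootstrap
            ret_ent[t] = rewards[t] + gamma * (ret_ent[t + 1] - alpha * log_probs[t + 1])

    print(ret, ret_ent)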
--------------------------------------------------------------------------------
/redq/utils/logx.py:
--------------------------------------------------------------------------------
1 | """
2 | Some simple logging functionality, inspired by rllab's logging.
3 | Modified from OpenAI spinup logx code
4 | Logs to a tab-separated-values file (path/to/output_directory/progress.txt)
5 | """
6 | import json
7 | import joblib
8 | import numpy as np
9 | import os.path as osp, time, atexit, os
10 | from redq.utils.serialization_utils import convert_json
11 |
12 | color2num = dict(
13 | gray=30,
14 | red=31,
15 | green=32,
16 | yellow=33,
17 | blue=34,
18 | magenta=35,
19 | cyan=36,
20 | white=37,
21 | crimson=38
22 | )
23 |
24 | def colorize(string, color, bold=False, highlight=False):
25 | """
26 | Colorize a string.
27 |
28 | This function was originally written by John Schulman.
29 | """
30 | attr = []
31 | num = color2num[color]
32 | if highlight: num += 10
33 | attr.append(str(num))
34 | if bold: attr.append('1')
35 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
36 |
37 | class Logger:
38 | """
39 | A general-purpose logger.
40 |
41 | Makes it easy to save diagnostics, hyperparameter configurations, the
42 | state of a training run, and the trained model.
43 | """
44 |
45 | def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None):
46 | """
47 | Initialize a Logger.
48 |
49 | Args:
50 | output_dir (string): A directory for saving results to. If
51 | ``None``, defaults to a temp directory of the form
52 | ``/tmp/experiments/somerandomnumber``.
53 |
54 | output_fname (string): Name for the tab-separated-value file
55 | containing metrics logged throughout a training run.
56 | Defaults to ``progress.txt``.
57 |
58 | exp_name (string): Experiment name. If you run multiple training
59 | runs and give them all the same ``exp_name``, the plotter
60 | will know to group them. (Use case: if you run the same
61 | hyperparameter configuration with multiple random seeds, you
62 | should give them all the same ``exp_name``.)
63 | """
64 | self.output_dir = output_dir or "/tmp/experiments/%i"%int(time.time())
65 | if osp.exists(self.output_dir):
66 | print("Warning: Log dir %s already exists! Storing info there anyway."%self.output_dir)
67 | else:
68 | os.makedirs(self.output_dir)
69 | self.output_file = open(osp.join(self.output_dir, output_fname), 'w')
70 | atexit.register(self.output_file.close)
71 | print(colorize("Logging data to %s"%self.output_file.name, 'green', bold=True))
72 | self.first_row=True
73 | self.log_headers = []
74 | self.log_current_row = {}
75 | self.exp_name = exp_name
76 |
77 | def log(self, msg, color='green'):
78 | """Print a colorized message to stdout."""
79 | print(colorize(msg, color, bold=True))
80 |
81 | def log_tabular(self, key, val):
82 | """
83 | Log a value of some diagnostic.
84 |
85 | Call this only once for each diagnostic quantity, each iteration.
86 | After using ``log_tabular`` to store values for each diagnostic,
87 | make sure to call ``dump_tabular`` to write them out to file and
88 | stdout (otherwise they will not get saved anywhere).
89 | """
90 | if self.first_row:
91 | self.log_headers.append(key)
92 | else:
93 | assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key
94 | assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()"%key
95 | self.log_current_row[key] = val
96 |
97 | def save_config(self, config):
98 | """
99 | Log an experiment configuration.
100 |
101 | Call this once at the top of your experiment, passing in all important
102 | config vars as a dict. This will serialize the config to JSON, while
103 | handling anything which can't be serialized in a graceful way (writing
104 | as informative a string as possible).
105 |
106 | Example use:
107 |
108 | .. code-block:: python
109 |
110 | logger = EpochLogger(**logger_kwargs)
111 | logger.save_config(locals())
112 | """
113 | config_json = convert_json(config)
114 | if self.exp_name is not None:
115 | config_json['exp_name'] = self.exp_name
116 | output = json.dumps(config_json, separators=(',',':\t'), indent=4, sort_keys=True)
117 | print(colorize('Saving config:\n', color='cyan', bold=True))
118 | print(output)
119 | with open(osp.join(self.output_dir, "config.json"), 'w') as out:
120 | out.write(output)
121 |
122 | def save_state(self, state_dict, itr=None):
123 | """
124 | Saves the state of an experiment.
125 |
126 | To be clear: this is about saving *state*, not logging diagnostics.
127 | All diagnostic logging is separate from this function. This function
128 | will save whatever is in ``state_dict``---usually just a copy of the
129 |         environment---together with any model parameters you choose to
130 |         include in it.
131 |
132 | Call with any frequency you prefer. If you only want to maintain a
133 | single state and overwrite it at each call with the most recent
134 | version, leave ``itr=None``. If you want to keep all of the states you
135 | save, provide unique (increasing) values for 'itr'.
136 |
137 | Args:
138 | state_dict (dict): Dictionary containing essential elements to
139 | describe the current state of training.
140 |
141 | itr: An int, or None. Current iteration of training.
142 | """
143 | fname = 'vars.pkl' if itr is None else 'vars%d.pkl'%itr
144 | try:
145 | joblib.dump(state_dict, osp.join(self.output_dir, fname))
146 | except:
147 | self.log('Warning: could not pickle state_dict.', color='red')
148 |
149 | def dump_tabular(self):
150 | """
151 | Write all of the diagnostics from the current iteration.
152 |
153 | Writes both to stdout, and to the output file.
154 | """
155 | vals = []
156 | key_lens = [len(key) for key in self.log_headers]
157 | max_key_len = max(15,max(key_lens))
158 | keystr = '%'+'%d'%max_key_len
159 | fmt = "| " + keystr + "s | %15s |"
160 | n_slashes = 22 + max_key_len
161 | print("-"*n_slashes)
162 | for key in self.log_headers:
163 | val = self.log_current_row.get(key, "")
164 | valstr = "%8.3g"%val if hasattr(val, "__float__") else val
165 | print(fmt%(key, valstr))
166 | vals.append(val)
167 | print("-"*n_slashes)
168 | if self.output_file is not None:
169 | if self.first_row:
170 | self.output_file.write("\t".join(self.log_headers)+"\n")
171 | self.output_file.write("\t".join(map(str,vals))+"\n")
172 | self.output_file.flush()
173 | self.log_current_row.clear()
174 | self.first_row=False
175 |
176 | def get_statistics_scalar(x, with_min_and_max=False):
177 | """
178 | Get mean/std and optional min/max of x
179 |
180 | Args:
181 | x: An array containing samples of the scalar to produce statistics
182 | for.
183 |
184 | with_min_and_max (bool): If true, return min and max of x in
185 | addition to mean and std.
186 | """
187 | x = np.array(x, dtype=np.float32)
188 | mean, std = x.mean(), x.std()
189 | if with_min_and_max:
190 | min_v = x.min()
191 | max_v = x.max()
192 | return mean, std, min_v, max_v
193 | return mean, std
194 |
195 | class EpochLogger(Logger):
196 | """
197 | A variant of Logger tailored for tracking average values over epochs.
198 |
199 | Typical use case: there is some quantity which is calculated many times
200 | throughout an epoch, and at the end of the epoch, you would like to
201 | report the average / std / min / max value of that quantity.
202 |
203 | With an EpochLogger, each time the quantity is calculated, you would
204 | use
205 |
206 | .. code-block:: python
207 |
208 | epoch_logger.store(NameOfQuantity=quantity_value)
209 |
210 | to load it into the EpochLogger's state. Then at the end of the epoch, you
211 | would use
212 |
213 | .. code-block:: python
214 |
215 | epoch_logger.log_tabular(NameOfQuantity, **options)
216 |
217 | to record the desired values.
218 | """
219 |
220 | def __init__(self, *args, **kwargs):
221 | super().__init__(*args, **kwargs)
222 | self.epoch_dict = dict()
223 |
224 | def store(self, **kwargs):
225 | """
226 | Save something into the epoch_logger's current state.
227 |
228 | Provide an arbitrary number of keyword arguments with numerical
229 | values.
230 |
231 |         To prevent problems, each value should be either a numpy array or a single scalar value.
232 | """
233 | for k,v in kwargs.items():
234 | if not k in self.epoch_dict:
235 | self.epoch_dict[k] = []
236 | if isinstance(v, np.ndarray): # used to prevent problems due to shape issues
237 | v = v.reshape(-1)
238 | self.epoch_dict[k].append(v)
239 |
240 | def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False):
241 | """
242 | Log a value or possibly the mean/std/min/max values of a diagnostic.
243 |
244 | Args:
245 | key (string): The name of the diagnostic. If you are logging a
246 | diagnostic whose state has previously been saved with
247 | ``store``, the key here has to match the key you used there.
248 |
249 | val: A value for the diagnostic. If you have previously saved
250 | values for this key via ``store``, do *not* provide a ``val``
251 | here.
252 |
253 | with_min_and_max (bool): If true, log min and max values of the
254 | diagnostic over the epoch.
255 |
256 | average_only (bool): If true, do not log the standard deviation
257 | of the diagnostic over the epoch.
258 | """
259 | if val is not None:
260 | super().log_tabular(key,val)
261 | else:
262 | v = self.epoch_dict[key]
263 | vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape)>0 else v
264 | stats = get_statistics_scalar(vals, with_min_and_max=with_min_and_max)
265 | super().log_tabular(key if average_only else 'Average' + key, stats[0])
266 | if not(average_only):
267 | super().log_tabular('Std'+key, stats[1])
268 | if with_min_and_max:
269 | super().log_tabular('Max'+key, stats[3])
270 | super().log_tabular('Min'+key, stats[2])
271 | self.epoch_dict[key] = []
272 |
273 | def get_stats(self, key):
274 | """
275 | Lets an algorithm ask the logger for mean/std/min/max of a diagnostic.
276 | """
277 | v = self.epoch_dict[key]
278 | vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape)>0 else v
279 | return get_statistics_scalar(vals, with_min_and_max=True)
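A minimal sketch of the intended store / log_tabular / dump_tabular cycle; the output directory, key names, and values below are placeholders, not taken from the repository's scripts.

    import numpy as np
    from redq.utils.logx import EpochLogger

    logger = EpochLogger(output_dir='/tmp/experiments/logx_demo', exp_name='demo')
    for epoch in range(2):
        # many calls per epoch accumulate values under each key
        for _ in range(5):
            logger.store(TestEpRet=np.random.randn(3))  # arrays are flattened by store()
        logger.log_tabular('Epoch', epoch)                      # direct value
        logger.log_tabular('TestEpRet', with_min_and_max=True)  # mean/std/min/max over the epoch
        logger.dump_tabular()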
--------------------------------------------------------------------------------
/redq/utils/run_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from OpenAI spinup code
3 | """
4 | from redq.user_config import DEFAULT_DATA_DIR, FORCE_DATESTAMP
5 | import os.path as osp
6 | import time
7 |
8 | DIV_LINE_WIDTH = 80
9 |
10 | def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False):
11 | """
12 | Sets up the output_dir for a logger and returns a dict for logger kwargs.
13 |
14 | If no seed is given and datestamp is false,
15 |
16 | ::
17 |
18 | output_dir = data_dir/exp_name
19 |
20 | If a seed is given and datestamp is false,
21 |
22 | ::
23 |
24 | output_dir = data_dir/exp_name/exp_name_s[seed]
25 |
26 | If datestamp is true, amend to
27 |
28 | ::
29 |
30 | output_dir = data_dir/YY-MM-DD_exp_name/YY-MM-DD_HH-MM-SS_exp_name_s[seed]
31 |
32 | You can force datestamp=True by setting ``FORCE_DATESTAMP=True`` in
33 |     ``redq/user_config.py``.
34 |
35 | Args:
36 |
37 | exp_name (string): Name for experiment.
38 |
39 | seed (int): Seed for random number generators used by experiment.
40 |
41 | data_dir (string): Path to folder where results should be saved.
42 |             Default is the ``DEFAULT_DATA_DIR`` in ``redq/user_config.py``.
43 |
44 | datestamp (bool): Whether to include a date and timestamp in the
45 | name of the save directory.
46 |
47 | Returns:
48 |
49 | logger_kwargs, a dict containing output_dir and exp_name.
50 | """
51 |
52 | # Datestamp forcing
53 | datestamp = datestamp or FORCE_DATESTAMP
54 |
55 | # Make base path
56 | ymd_time = time.strftime("%Y-%m-%d_") if datestamp else ''
57 | relpath = ''.join([ymd_time, exp_name])
58 |
59 | if seed is not None:
60 | # Make a seed-specific subfolder in the experiment directory.
61 | if datestamp:
62 | hms_time = time.strftime("%Y-%m-%d_%H-%M-%S")
63 | subfolder = ''.join([hms_time, '-', exp_name, '_s', str(seed)])
64 | else:
65 | subfolder = ''.join([exp_name, '_s', str(seed)])
66 | relpath = osp.join(relpath, subfolder)
67 |
68 | data_dir = data_dir or DEFAULT_DATA_DIR
69 | logger_kwargs = dict(output_dir=osp.join(data_dir, relpath),
70 | exp_name=exp_name)
71 | return logger_kwargs
72 |
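For instance, with the defaults (datestamp=False) and a hypothetical experiment name, the returned kwargs look like this; 'redq_hopper' is just an example name, not one used by the repository.

    from redq.utils.run_utils import setup_logger_kwargs

    kwargs = setup_logger_kwargs('redq_hopper', seed=0)
    # yields something like
    # {'output_dir': '<DEFAULT_DATA_DIR>/redq_hopper/redq_hopper_s0',
    #  'exp_name': 'redq_hopper'}
    # which can be passed straight to the logger, e.g. EpochLogger(**kwargs)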
--------------------------------------------------------------------------------
/redq/utils/serialization_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from OpenAI spinup code
3 | """
4 | import json
5 |
6 | def convert_json(obj):
7 | """ Convert obj to a version which can be serialized with JSON. """
8 | if is_json_serializable(obj):
9 | return obj
10 | else:
11 | if isinstance(obj, dict):
12 | return {convert_json(k): convert_json(v)
13 | for k,v in obj.items()}
14 |
15 | elif isinstance(obj, tuple):
16 |             return tuple(convert_json(x) for x in obj)
17 |
18 | elif isinstance(obj, list):
19 | return [convert_json(x) for x in obj]
20 |
21 | elif hasattr(obj,'__name__') and not('lambda' in obj.__name__):
22 | return convert_json(obj.__name__)
23 |
24 | elif hasattr(obj,'__dict__') and obj.__dict__:
25 | obj_dict = {convert_json(k): convert_json(v)
26 | for k,v in obj.__dict__.items()}
27 | return {str(obj): obj_dict}
28 |
29 | return str(obj)
30 |
31 | def is_json_serializable(v):
32 | try:
33 | json.dumps(v)
34 | return True
35 | except:
36 | return False
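A quick sketch of the conversion on a config dict with one non-serializable entry; the example objects are arbitrary and only chosen to show the fallback behavior.

    import numpy as np
    from redq.utils.serialization_utils import convert_json

    config = {'lr': 3e-4, 'activation': np.tanh}
    print(convert_json(config))
    # {'lr': 0.0003, 'activation': 'tanh'} -- the float passes through unchanged,
    # while np.tanh is not JSON-serializable and is replaced by its __name__.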
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from os.path import join, dirname, realpath
2 | from setuptools import setup
3 | import sys
4 |
5 | assert sys.version_info.major == 3 and sys.version_info.minor >= 6, \
6 |     "This package requires Python 3.6 or greater."
7 |
8 | setup(
9 | name='redq',
10 | py_modules=['redq'],
11 | version='0.0.1',
12 | install_requires=[
13 | 'numpy',
14 | 'joblib',
15 | 'gym>=0.17.2'
16 | ],
17 | description="REDQ algorithm PyTorch implementation",
18 | author="Xinyue Chen, Che Wang, Zijian Zhou, Keith Ross",
19 | )
20 |
--------------------------------------------------------------------------------