├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── VRL3.png
├── docker
│   └── Dockerfile
├── plot_utils
│   ├── drq_converter.py
│   ├── plot_config.py
│   ├── rrl_result_converter.py
│   ├── vrl3_plot_example.py
│   ├── vrl3_plot_helper.py
│   ├── vrl3_plot_mover.py
│   └── vrl3_plot_runner.py
└── src
    ├── .gitignore
    ├── adroit.py
    ├── cfgs_adroit
    │   ├── config.yaml
    │   └── task
    │       ├── door.yaml
    │       ├── hammer.yaml
    │       ├── pen.yaml
    │       └── relocate.yaml
    ├── dmc.py
    ├── logger.py
    ├── replay_buffer.py
    ├── rrl_local
    │   ├── rrl_encoder.py
    │   ├── rrl_multicam.py
    │   └── rrl_utils.py
    ├── stage1_models.py
    ├── testing
    │   ├── computation_time_test.py
    │   └── zzz.py
    ├── train_adroit.py
    ├── train_stage1.py
    ├── transfer_util.py
    ├── utils.py
    ├── video.py
    └── vrl3_agent.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | .idea/
141 |
142 | # experiment files
143 | exp_local/
144 | data/
145 |
146 | images/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Visual Deep Reinforcement Learning in 3 Stages (VRL3)
2 |
3 | Official code for the paper VRL3: A Data-Driven Framework for Visual Deep Reinforcement Learning. Summary site: https://sites.google.com/nyu.edu/vrl3.
4 |
5 | 
6 |
7 | The code has just been released, and the entire codebase has been rewritten to make it cleaner and more readable. If you run into any problems, please do not hesitate to open an issue.
8 | 
9 | We are also doing further clean-up of the code, so this repo will continue to be updated.
10 |
11 |
12 |
13 | ### Table of Contents
14 |
15 | - [Repo structure](#repo-structure)
16 | - [Environment setup](#environment-setup)
17 | - [Docker setup](#docker)
18 | - [Run experiments](#run-exp)
19 | - [Singularity setup](#singularity)
20 | - [Plotting example](#plotting)
21 | - [Technical details](#hyper)
22 | - [Computation time](#computation)
23 | - [Known issues](#known-issues)
24 | - [Acknowledgement](#acknowledgement)
25 | - [Citation](#citation)
26 | - [Contributing](#contributing)
27 |
28 |
29 | ### Updates:
30 |
31 | 03/30/2023: added example plot function and a quick tutorial.
32 |
33 |
34 |
35 |
36 | ## Repo structure and important files:
37 |
38 | ```
39 | VRL3 # this repo
40 | │ README.md # read this file first!
41 | └───docker # dockerfile with all dependencies
42 | └───plot_utils # code for plotting (still being cleaned up)
43 | └───src
44 | │ train_adroit.py # setup and main training loop
45 | │ vrl3_agent.py # agent class, code for stage 2, 3
46 | │ train_stage1.py # code for stage 1 pretraining on imagenet
47 | │ stage1_models.py # the encoder classes pretrained in stage 1
48 | └───cfgs_adroit # configuration files with all hyperparameters
49 |
50 | # download these folders from the google drive link
51 | vrl3data
52 | └───demonstrations # adroit demos
53 | └───trained_models # pretrained stage 1 models
54 | vrl3examplelogs
55 | └───rrl # rrl training logs
56 | └───vrl3 # vrl3 with default hyperparams logs
57 | ```
58 |
59 | To get started, download this repo, then download the adroit demos, pretrained models, and example logs from the following link:
60 | https://drive.google.com/drive/folders/1j-7BKlYmknVCBfzO3MnLLD62JBngoMkn?usp=sharing
61 |
62 |
63 |
64 | ## Environment setup
65 |
66 | The recommended way is to use the dockerfile I provided and follow the tutorial here. You can also inspect the dockerfile to see the exact dependencies, or modify it to build a new image.
67 |
68 |
69 |
70 | ### Setup with docker
71 |
72 | If you have a local machine with a GPU, or your cluster allows docker (i.e., you have sudo), then you can just pull my docker image and run the code there. (The newest version is 1.5, which fixes the slow mujoco GPU rendering issue.)
73 | ```
74 | docker pull cwatcherw/vrl3:1.5
75 | ```
76 |
77 | Now, `cd` into a directory where you have the `VRL3` folder (this repo), and also the `vrl3data` folder that you downloaded from my google drive link.
78 | Then, mount `VRL3/src` to `/code`, and mount `vrl3data` to `/vrl3data` (you can also mount to other places, but you will need to adjust some commands or paths in the config files):
79 | ```
80 | docker run -it --rm --gpus all -v "$(pwd)"/VRL3/src:/code -v "$(pwd)"/vrl3data:/vrl3data cwatcherw/vrl3:1.5
81 | ```
82 | You should now be inside the docker container. Refer to the "Run experiments" section.
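
Before launching anything heavy, you can optionally verify that the container sees your GPU. This quick check is just a convenience, not part of the repo:

```python
# quick sanity check that `--gpus all` worked and PyTorch can see the GPU
import torch
print(torch.cuda.is_available())      # should print True
print(torch.cuda.get_device_name(0))  # should print your GPU model
```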
83 |
84 |
85 |
86 | ### Run experiments
87 |
88 | Once you get into the container (either docker or singularity), first run the following commands so the paths are set correctly. This is especially important on singularity, since it uses automount, which can mess up the paths. (The newest version of the code sets these via `os.environ`, so you can also skip this step.)
89 | ```
90 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/workspace/.mujoco/mujoco210/bin
91 | export MUJOCO_PY_MUJOCO_PATH=/workspace/.mujoco/mujoco210/
92 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/nvidia/lib
93 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/workspace/.mujoco/mujoco210/bin
94 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
95 | export MUJOCO_GL=egl
96 | ```
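
If you prefer to handle this from Python (the `os.environ` approach mentioned above), a minimal sketch, assuming the same container paths, looks like the following; the variables must be set before `mujoco_py` is first imported:

```python
# minimal sketch of the os.environ approach; run this before mujoco_py is imported
import os

mujoco_path = '/workspace/.mujoco/mujoco210/'
os.environ['MUJOCO_PY_MUJOCO_PATH'] = mujoco_path
os.environ['MUJOCO_GL'] = 'egl'
extra_lib_paths = [mujoco_path + 'bin', '/usr/local/nvidia/lib', '/usr/lib/nvidia']
os.environ['LD_LIBRARY_PATH'] = (os.environ.get('LD_LIBRARY_PATH', '')
                                 + ':' + ':'.join(extra_lib_paths))
```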
97 |
98 | Go to the VRL3 code directory that you mounted.
99 | ```
100 | cd /code
101 | ```
102 |
103 | First, quickly check whether mujoco is using your GPU correctly for rendering. If everything is correct, the program will print out the computation time for 1000 renders (the first time mujoco is imported, there will also be mujoco build messages, which take a few minutes). Rendering 1000 times should take < 0.5 seconds.
104 | ```
105 | python testing/computation_time_test.py
106 | ```
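
For reference, here is a minimal sketch of this kind of timing check, assuming mujoco-py with the Adroit envs registered by `mj_envs`; the actual `computation_time_test.py` may differ:

```python
# minimal sketch: time 1000 offscreen renders (the real test script may differ)
import time
import gym
import mj_envs  # noqa: F401 -- registers the Adroit envs

env = gym.make('door-v0')
env.reset()
start = time.time()
for _ in range(1000):
    env.unwrapped.sim.render(width=84, height=84, mode='offscreen')
print('1000 renders took %.3f seconds' % (time.time() - start))
```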
107 |
108 |
109 | Then you can start running VRL3:
110 | ```
111 | cd /code
112 | python train_adroit.py task=door
113 | ```
114 |
115 | For first-time setup, use `debug=1` to do a quick test run and check that the code is working. This reduces the number of training epochs and changes many other hyperparameters so that you get a full run within a few minutes.
116 | ```
117 | python train_adroit.py task=door debug=1
118 | ```
119 |
120 | You can also run with different hyperparameters; see `config.yaml` for the full list. For example:
121 | ```
122 | python train_adroit.py task=door stage2_n_update=5000 agent.encoder_lr_scale=0.1
123 | ```
124 |
125 |
126 |
127 | ### Setup with singularity
128 |
129 | If your cluster does not allow sudo (for example, NYU's slurm HPC), you can use singularity, which is similar to docker. You might need to modify some of the commands depending on how your cluster is managed. Here is an example setup on the NYU Greene HPC.
130 |
131 | Set up the singularity container (this makes a folder called `sing` in your scratch directory, then builds a singularity sandbox container called `vrl3sing` from the `cwatcherw/vrl3:1.5` docker image on my docker hub):
132 | ```
133 | mkdir /scratch/$USER/sing/
134 | cd /scratch/$USER/sing/
135 | singularity build --sandbox vrl3sing docker://cwatcherw/vrl3:1.5
136 | ```
137 |
138 | For example, on NYU HPC, start an interactive session (if your school has a different HPC system, consult your HPC admin):
139 | ```
140 | srun --pty --gres=gpu:1 --cpus-per-task=4 --mem 12000 -t 0-06:00 bash
141 | ```
142 | By default, VRL3 uses 4 workers for the dataloader, so we request 4 CPUs. Once the job is allocated to you, go to the `sing` folder where you have your container, then run it:
143 | ```
144 | cd /scratch/$USER/sing
145 | singularity exec --nv -B /scratch/$USER/sing/VRL3/src:/code -B /scratch/$USER/sing/vrl3sing/opt/conda/lib/python3.8/site-packages/mujoco_py/:/opt/conda/lib/python3.8/site-packages/mujoco_py/ -B /scratch/$USER/sing/vrl3data:/vrl3data /scratch/$USER/sing/vrl3sing bash
146 | ```
147 | We mount the `mujoco_py` package folder because singularity files are read-only by default, and the older version of mujoco_py wants to modify files at runtime, which can be problematic. (Unfortunately, the Adroit env relies on this older version of mujoco.)
148 |
149 | Once the singularity container is running, refer to the "Run experiments" section.
150 |
151 |
152 |
153 | ## Plotting example
154 |
155 | If you would like to use the plotting functions we used, you will need `matplotlib`, `seaborn`, and a few other basic packages. You can also use your own plotting functions.
156 |
157 | An example is given in `plot_utils/vrl3_plot_example.py`. To use it:
158 | 1. make sure you have downloaded the `vrl3examplelogs` folder from the drive link and unzipped it.
159 | 2. in `plot_utils/vrl3_plot_example.py`, change the `base_dir` path to where the `vrl3examplelogs` folder is on your computer (see the sketch below).
160 | 3. similarly, change the `base_save_dir` path to where you want the figures to be generated.
161 | 4. run `plot_utils/vrl3_plot_example.py`; this will generate a few figures comparing the success rates of RRL and VRL3 and save them to the specified path.
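
For steps 2 and 3, the change is just the two path variables at the top of the script; the paths below are placeholders for wherever you put the folders:

```python
# in plot_utils/vrl3_plot_example.py -- point these at your local folders
base_dir = '/path/to/vrl3examplelogs'   # unzipped example logs from the drive link
base_save_dir = '/path/to/vrl3figures'  # where the generated figures will be saved
```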
162 |
163 | (All ablation experiment logs generated during the VRL3 research are in the `vrl3logs` folder from the drive link. `plot_utils/vrl3_plot_runner.py` was used to generate the figures in the paper; it still needs further clean-up.)
164 |
165 |
166 |
167 | ## Technical details
168 |
169 | - BC loss: in the config files, I now disable all BC losses by default, since our ablations show they do not really help.
170 | - under `src/cfgs_adroit/task/relocate.yaml` you will see that relocate has `encoder_lr_scale: 0.01`; as shown in the paper, relocate requires a smaller encoder learning rate. You can set task-specific default parameters in each task's separate config file.
171 | - in the paper, I used `frame_stack=3` for most experiments; however, I later found it can be reduced to 1 with the same performance. It might be beneficial to set it to 1 so training runs faster and takes less memory. If you set this to 1, convolutional channel expansion will only be applied for the relocate env, where the input is a stack of 3 camera images.
172 | - all values in table 2 in appendix A.2 of the paper are set to be the default values in the config files.
173 | - to apply VRL3 to other environments, please consult the hyperparameter sensitivity study in appendix A, which identifies robust and sensitive hyperparameters.
174 |
175 |
176 |
177 | ### Computation time
178 |
179 | This table shows computation time estimates for the open-source code with default hyperparameters (tested on NYU Greene with an RTX 8000 and 4 CPUs). On your machine it might be slightly faster or slower, but it should not be too different. These results are slightly faster than what we reported in the paper (which was tested on Azure P100 GPU machines); the improvement is mainly because we now use a smaller default `frame_stack` for Adroit.
180 |
181 | | Task | Stage 2 (30K updates) | Stage 3 (4M frames) | Total | Total (paper) |
182 | |------------------|-----------------------|---------------------|---------|------------|
183 | | Door/Pen/Hammer | ~0.5 hrs | ~13 hrs | ~14 hrs | ~16 hrs |
184 | | Relocate | ~0.5 hrs | ~16 hrs | ~17 hrs | ~24 hrs |
185 |
186 | Note that VRL3's performance has largely converged by 1M frames for Door, Hammer and Relocate, so depending on what you want to achieve in your work, you may not need to run a full 4M frames. In the paper we run to 4M to be consistent with prior work and to show that VRL3 can outperform the previous SOTA in both short-term and long-term performance.
187 |
188 |
189 |
190 | ### Known issues:
191 |
192 | - Some might encounter a problem where mujoco can crash at an arbitrary point during training. I have not seen this issue myself, but I was told that re-initializing `self.train_env` between stage 2 and stage 3 can fix it (see the sketch after this list).
193 | - If you are not using the provided docker image and you run into slow rendering, it is possible that mujoco did not find your GPU and created a `CPUExtender` instead of a `GPUExtender`. You can follow the steps in the provided dockerfile, or force it to use the `GPUExtender` (see the code in `mujoco-py/mujoco_py/builder.py`). Thanks to ZheCheng Yuan for identifying the above two issues.
194 | - Newer versions of mujoco are easier to work with; we use an older version only because Adroit relies on it. (You can try a newer mujoco if you want to test on other environments.)
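
Below is a hypothetical sketch of the `self.train_env` re-initialization workaround mentioned in the first issue; `make_env` and the workspace attribute are assumed names, so adapt it to the actual setup in `train_adroit.py`:

```python
# hypothetical sketch of the workaround -- adapt to the actual training loop
def reinit_train_env(workspace, make_env):
    """Rebuild the training env between stage 2 and stage 3 so that stale
    mujoco state from stage 2 cannot crash stage 3."""
    del workspace.train_env           # drop the old env (and its mujoco sim)
    workspace.train_env = make_env()  # create a fresh env for stage 3
```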
195 |
196 |
197 |
198 | ## Acknowledgement
199 |
200 | VRL3 code has been mainly built on top of the DrQv2 codebase (https://github.com/facebookresearch/drqv2). Some utility functions and dockerfile are modified from the REDQ codebase (https://github.com/watchernyu/REDQ). The Adroit demo loading code is modified from the RRL codebase (https://github.com/facebookresearch/RRL).
201 |
202 |
203 |
204 | ## Citation
205 |
206 | If you use VRL3 in your research, please consider citing the paper as:
207 | ```
208 | @inproceedings{wang2022vrl3,
209 | title={VRL3: A Data-Driven Framework for Visual Deep Reinforcement Learning},
210 | author={Wang, Che and Luo, Xufang and Ross, Keith and Li, Dongsheng},
211 | booktitle={Conference on Neural Information Processing Systems},
212 | year={2022},
213 | url={https://openreview.net/forum?id=NjKAm5wMbo2}
214 | }
215 | ```
216 |
217 |
218 |
219 | ## Contributing
220 |
221 |
222 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
223 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
224 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
225 |
226 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
227 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
228 | provided by the bot. You will only need to do this once across all repos using our CLA.
229 |
230 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
231 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
232 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
233 |
234 |
235 |
236 | ## Trademarks
237 |
238 |
239 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
240 | trademarks or logos is subject to and must follow
241 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
242 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
243 | Any use of third-party trademarks or logos is subject to those third parties' policies.
244 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/VRL3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/VRL3/f22e4d197d953663c8e77173b0aa639ea99d17a8/VRL3.png
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # cudagl with miniconda and python 3.8, pytorch, mujoco and gym v2 envs, REDQ
2 | # current corresponding image on dockerhub: docker://cwatcherw/vrl3:1.5
3 |
4 | FROM nvidia/cudagl:11.0-base-ubuntu18.04
5 | WORKDIR /workspace
6 | ENV HOME=/workspace
7 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
8 | ENV PATH /opt/conda/bin:$PATH
9 |
10 | # idea: start with an nvidia docker image with GL support (the cudagl base also includes cuda)
11 | # then install miniconda, borrowing docker command from miniconda's Dockerfile (https://hub.docker.com/r/continuumio/anaconda/dockerfile/)
12 | # need to make sure the miniconda python version is what we need (https://docs.conda.io/en/latest/miniconda.html for the right version)
13 | # then install other dependencies we need
14 |
15 | # nvidia GPG key alternative fix (https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772)
16 | # sudo apt-key del 7fa2af80
17 | # wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.0-1_all.deb
18 | # sudo dpkg -i cuda-keyring_1.0-1_all.deb
19 |
20 | RUN \
21 | # Update nvidia GPG key
22 | rm /etc/apt/sources.list.d/cuda.list \
23 | && rm /etc/apt/sources.list.d/nvidia-ml.list \
24 | && apt-key del 7fa2af80 \
25 | && apt-get update && apt-get install -y --no-install-recommends wget \
26 | && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb \
27 | && dpkg -i cuda-keyring_1.0-1_all.deb \
28 | && apt-get update
29 |
30 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
31 | libglib2.0-0 libxext6 libsm6 libxrender1 \
32 | git mercurial subversion
33 |
34 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh -O ~/miniconda.sh \
35 | && /bin/bash ~/miniconda.sh -b -p /opt/conda \
36 | && rm ~/miniconda.sh \
37 | && ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh \
38 | && echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc \
39 | && echo "conda activate base" >> ~/.bashrc
40 |
41 | RUN apt-get install -y curl grep sed dpkg \
42 | # TINI_VERSION=`curl https://github.com/krallin/tini/releases/latest | grep -o "/v.*\"" | sed 's:^..\(.*\).$:\1:'` && \
43 | # curl -L "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \
44 | # dpkg -i tini.deb && \
45 | # rm tini.deb && \
46 | && apt-get clean
47 |
48 | # Install some basic utilities
49 | RUN apt-get update && apt-get install -y software-properties-common \
50 | && add-apt-repository -y ppa:redislabs/redis && apt-get update \
51 | && apt-get install -y sudo ssh libx11-6 gcc iputils-ping \
52 | libxrender-dev graphviz tmux htop build-essential wget cmake libgl1-mesa-glx redis \
53 | && rm -rf /var/lib/apt/lists/*
54 |
55 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \
56 |     apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0
57 | # && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so
58 | # ---------- now we should have all major dependencies ------------
59 |
60 | # --------- now we have cudagl + python38 ---------
61 | RUN pip install --no-cache-dir torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
62 |
63 | RUN pip install --no-cache-dir scikit-learn pandas imageio
64 |
65 | # the adroit env does not work with newer versions of mujoco, so we have to use the old version...
66 | # setting MUJOCO_GL to egl is needed if run on headless machine
67 | ENV MUJOCO_GL=egl
68 | ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/workspace/.mujoco/mujoco210/bin
69 | RUN mkdir -p /workspace/.mujoco \
70 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco210.tar.gz \
71 | && tar -xvzf mujoco210.tar.gz -C /workspace/.mujoco \
72 | && rm mujoco210.tar.gz
73 |
74 | # RL env: get mujoco and gym (v2 version environments)
75 | RUN pip install --no-cache-dir patchelf==0.17.2.0 mujoco-py==2.1.2.14 gym==0.21.0
76 |
77 | # the first import of mujoco_py takes some time (it compiles the bindings), so do it here to avoid the wait later
78 | RUN echo "import mujoco_py" >> /workspace/import_mujoco.py \
79 | && python /workspace/import_mujoco.py
80 |
81 | # get dmc with specific version
82 | RUN cd /workspace/ \
83 | && git clone https://github.com/deepmind/dm_control.git \
84 | && cd dm_control \
85 | && git checkout 644d9e0 \
86 | && pip install -e .
87 |
88 | # dependencies for DrQv2 and VRL3 (mainly from https://github.com/facebookresearch/drqv2/blob/main/conda_env.yml)
89 | RUN pip install --no-cache-dir absl-py==0.13.0 pyparsing==2.4.7 jupyterlab==3.0.14 scikit-image \
90 | termcolor==1.1.0 imageio-ffmpeg==0.4.4 hydra-core==1.1.0 hydra-submitit-launcher==1.1.5 ipdb==0.13.9 \
91 | yapf==0.31.0 opencv-python==4.5.3.56 psutil tb-nightly
92 |
93 | # get adroit (mainly following https://github.com/facebookresearch/RRL, with small changes to avoid conflict with drqv2 dependencies)
94 | RUN cd /workspace/ \
95 | && git clone https://github.com/watchernyu/rrl-dependencies.git \
96 | && cd rrl-dependencies \
97 | && pip install -e mj_envs/. \
98 | && pip install -e mjrl/.
99 |
100 | # important to have this so mujoco can compile stuff related to gpu rendering
101 | RUN apt-get update && apt-get install -y libglew-dev \
102 | && rm -rf /var/lib/apt/lists/*
103 |
104 | # if we don't have the path `/usr/lib/nvidia`, the old mujoco can render with cpu only without even giving an error message,
105 | # which makes everything very slow and drives people insane
106 | RUN mkdir /usr/lib/nvidia
107 |
108 | CMD [ "/bin/bash" ]
109 |
110 | # # build docker container:
111 | # docker build . -t name:tag
112 |
113 | # # example docker command to run interactive container, enable gpu, and remove it when shutdown:
114 | # docker run -it --rm --gpus all name:tag
115 |
--------------------------------------------------------------------------------
/plot_utils/drq_converter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # NOTE: plotting helper code is currently being cleaned up
5 |
6 | import os
7 | 
8 | import pandas as pd
10 |
11 | source_path = 'C:\\z\\drq_author'
12 | for subdir, dirs, files in os.walk(source_path):
13 | for file in files:
14 | if '_' in file and '.csv' in file:
15 | full_path = os.path.join(subdir, file)
16 | d = pd.read_csv(full_path, index_col=0)
17 |             task = file[4:-4]  # strip the 4-character prefix and the '.csv' suffix to get the task name
18 | for seed in range(1, 11):
19 | print(seed)
20 | d_seed = d[d.seed == seed]
21 | # now basically save them separately
22 | folder_name = 'drqv2_author_%s_s%d' % (task, seed)
23 | folder_path = os.path.join(source_path, folder_name)
24 |                 os.makedirs(folder_path, exist_ok=True)  # tolerate re-runs
25 | save_path = os.path.join(folder_path, 'eval.csv')
26 | d_seed.to_csv(save_path, index=False)
27 | print('saved to', save_path)
28 |
29 |
--------------------------------------------------------------------------------
/plot_utils/plot_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # NOTE: plotting helper code is currently being cleaned up
5 |
6 | DEFAULT_AMLT_PATH = 'D:\\a'
7 | DEFAULT_BASE_DATA_PATH = '/home/watcher/Desktop/projects/vrl3logs'
8 | DEFAULT_BASE_SAVE_PATH = '/home/watcher/Desktop/projects/vrl3figures'
--------------------------------------------------------------------------------
/plot_utils/rrl_result_converter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # NOTE: plotting helper code is currently being cleaned up
5 |
6 | # for each result file in each rrl data folder,
7 | # load it and convert it into a format we can use easily
8 | import os
9 |
10 | import numpy as np
11 | import pandas
12 | from numpy import genfromtxt
13 |
14 | base_path = 'C:\\z\\abl'
15 |
16 | for subdir, dirs, files in os.walk(base_path):
17 | if 'log.csv' in files:
18 |         # the folder name here is typically also the variant name
19 | folder_name = os.path.basename(os.path.normpath(subdir)) # e.g. drq_nstep3_pt5_0.001_hopper_hop_s1
20 | original_log_path = os.path.join(subdir, 'log.csv')
21 |
22 | seed_string_index = folder_name.rfind('s')
23 | name_no_seed = folder_name[:seed_string_index - 1]
24 | seed_value = int(folder_name[seed_string_index + 1:])
25 |
26 | original_data = genfromtxt(original_log_path, dtype=float, delimiter=',', names=True)
27 | iters = original_data['iteration']
28 | sr = original_data['success_rate']
29 | # original_data['frames'] = iters * 40000
30 | new_df = {
31 | 'iter': np.arange(len(iters)),
32 | 'frame': iters * 40000,
33 | 'success_rate': sr/100
34 | }
35 | new_df = pandas.DataFrame(new_df)
36 | save_path = os.path.join(subdir, 'eval.csv')
37 | new_df.to_csv(save_path, index=False)
38 | save_path = os.path.join(subdir, 'train.csv')
39 | new_df.to_csv(save_path, index=False)
40 |
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/plot_utils/vrl3_plot_example.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import os
4 | from vrl3_plot_helper import plot_aggregate_0506
5 |
6 | ##############################################################################
7 | ########## change base_dir to where you downloaded the example logs ##########
8 | ##############################################################################
9 | base_dir = '/home/watcher/Desktop/projects/vrl3examplelogs'
10 | base_save_dir = '/home/watcher/Desktop/projects/vrl3figures'
11 |
12 | FRAMEWORK_NAME = 'VRL3'
13 |
14 | NUM_MILLION_ON_X_AXIS = 4
15 | DEFAULT_Y_MIN = -0.02
16 | DEFAULT_X_MIN = -50000
17 | DEFAULT_Y_MAX = 1.02
18 | DEFAULT_X_MAX = 4050000
19 | X_LONG = 12050000
20 | ADROIT_ALL_ENVS = ["door", "pen", "hammer", "relocate"]
21 | adroit_success_rate_default_min_max_dict = {'ymin':DEFAULT_Y_MIN, 'ymax':DEFAULT_Y_MAX, 'xmin':DEFAULT_X_MIN, 'xmax':DEFAULT_X_MAX}
22 | adroit_other_value_default_min_max_dict = {'ymin':None, 'ymax':None, 'xmin':DEFAULT_X_MIN, 'xmax':DEFAULT_X_MAX}
23 |
24 | def decide_placeholder_prefix(envs, folder_name, plot_name):
25 | if isinstance(envs, list):
26 | if len(envs) > 1:
27 | return 'aggregate-' + folder_name, 'aggregate_' + plot_name
28 | else:
29 |             return folder_name, plot_name + '_' + envs[0]
30 |     return folder_name, plot_name + '_' + envs
31 |
32 | def plot_paper_main_more_aggregate(envs=ADROIT_ALL_ENVS, no_legend=False):
33 | save_folder_name, save_prefix = decide_placeholder_prefix(envs, 'fs', 'adroit_main')
34 | # below are data folders under the base logs folder, the program will try to find training logs under these folders
35 | list_of_data_dir = ['rrl', 'vrl3']
36 | paths = [os.path.join(base_dir, data_dir) for data_dir in list_of_data_dir]
37 | save_folder = os.path.join(base_save_dir, save_folder_name)
38 |
39 | labels = [FRAMEWORK_NAME, 'RRL']
40 | variants = [
41 | ['vrl3_ar2_fs3_pixels_resnet6_32channel_pTrue_augTrue_lr0.0001_esauto_dbs64_pt5_0.001_s2eTrue_bc5000_cql30000_w1_s2std0.1_nr10_s3std0.01_bc0.001_0.5_s2d0_sa0_ra0_demo25_qt200_qf0.5_ef0_True_etTrue'],
42 | ['rrl'],
43 | ]
44 | colors = ['tab:red', 'tab:grey', 'tab:blue', 'tab:orange', 'tab:pink','tab:brown']
45 | dashes = ['solid' for _ in range(6)]
46 |
47 | label2variants, label2colors, label2linestyle = {}, {}, {}
48 | for i, label in enumerate(labels):
49 | label2variants[label] = variants[i]
50 | label2colors[label] = colors[i]
51 | label2linestyle[label] = dashes[i]
52 |
53 | for label, variants in label2variants.items(): # add environment
54 | variants_with_envs = []
55 | for variant in variants:
56 | for env in envs:
57 | variants_with_envs.append(variant + '_' + env)
58 | label2variants[label] = variants_with_envs
59 |
60 | for plot_y_value in ['success_rate']:
61 | save_parent_path = os.path.join(save_folder, plot_y_value)
62 | d = adroit_success_rate_default_min_max_dict if plot_y_value == 'success_rate' else adroit_other_value_default_min_max_dict
63 | save_name = '%s_%s' % (save_prefix, plot_y_value)
64 | print(save_parent_path, save_name)
65 | plot_aggregate_0506(paths, label2variants=label2variants, save_folder=save_parent_path, save_name=save_name,
66 | plot_y_value=plot_y_value, label2colors=label2colors, label2linestyle=label2linestyle,
67 | max_seed=11, no_legend=no_legend, **d)
68 |
69 | def plot_paper_main_more_per_task():
70 | for env in ADROIT_ALL_ENVS:
71 | plot_paper_main_more_aggregate(envs=[env], no_legend=(env!='relocate'))
72 |
73 | plot_paper_main_more_per_task()
74 | plot_paper_main_more_aggregate()
--------------------------------------------------------------------------------
/plot_utils/vrl3_plot_mover.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # NOTE: plotting helper code is currently being cleaned up
5 |
6 | from vrl3_plot_helper import *
7 |
8 | # this maps a data folder name to several amlt exp folder names, for example:
9 | data_folder_to_amlt_dict = {
10 | 'distill': ['driven-bedbug', 'enabled-starfish', 'humorous-octopus', 'singular-treefrog', # distill with cap 1, num_c 32, change distill weight
11 | 'fair-corgi', 'present-honeybee', # low number of channels
12 | 'guiding-earwig', 'bright-piranha' # high cap ul encoder
13 | ],
14 | 'n_compress':['coherent-caribou'], # change number of layer in the compression layers
15 | 'double':['select-hamster', 'modest-guppy', 'accepted-dragon'], # double encoder design, performance not great
16 | 'intermediate':['precious-hagfish', 'pleased-mosquito', 'resolved-cattle', 'selected-jawfish'], # single encoder, but try use intermediate layer output, see which layer gives best result
17 | 'double-add':['sweet-glider'], # double resnet with additional conv output, see if additional output for double resnet can help
18 | 'test-pg':['well-cow', 'optimum-corgi', 'robust-termite'], # double standard enc 8 variants, single res 8 variants. These results will help decide whether our hypothesis on policy gradient is correct or not
19 | 'naivelarge':['liberal-macaque'], # naively do single encoder, rl update, deep encoder
20 | 'newbaseline3s':['peaceful-anchovy', 'poetic-airedale', 'native-fawn', 'light-peacock', 'firm-starfish'],
21 | 'newdouble3s':['precious-koi', 'central-louse',
22 | 'sound-asp', 'welcomed-doe', 'fancy-polecat',
23 | 'oriented-kid' # humanoid
24 | ],
25 | 'utd':['innocent-gar','trusting-warthog'],
26 | 'pt': ['major-snipe', 'oriented-fox', 'brief-barnacle', 'superb-shepherd', 'golden-bobcat',
27 | 'capital-alpaca', 'premium-ringtail', # stage 2
28 | ],
29 | 'drq-base':['enabling-vervet',], # drq baseline with drq original code
30 | 'tl': ['able-starling' ,# tl freeze on easy 6
31 | ],
32 | 'ssl': ['on-gazelle', 'sincere-walrus', # freeze
33 | 'light-sheep', 'settled-mastiff', 'cheerful-muskrat'],
34 | 's2-base':['needed-hedgehog'],
35 | 'ssl-s2':['verified-impala', 'expert-cockatoo'],
36 | 'adroit-firstgroup':['lasting-ferret', 'united-basilisk'],
37 | 'adroit-secondgroup': ['upright-starling','star-opossum','distinct-viper','humorous-dassie',
38 | 'genuine-lynx', # no data aug
39 | 'choice-monkey', 'new-mole', # smaller lr
40 | 'quiet-worm', 'cuddly-cat' # sanity check
41 | ],
42 | 'adroit-stable': ['regular-monkfish', 'unique-monster',
43 | ],
44 | 'adroit-stable-hyper': ['loved-mastodon', 'equipped-owl', 'proven-urchin', 'innocent-pangolin'], # 1-seed hyper search
45 | 'adroit-stable-hyper2': ['proven-urchin', 'innocent-pangolin'],
46 | 'adroit-policy-lr': ['exciting-toad'],
47 | 'adroit-s2cql':['composed-salmon','driven-reptile'],
48 | 'adroit-s2cql2': ['adequate-reptile', 'native-kodiak', 'summary-baboon', 'safe-mammal'],
49 | 'adroit-relocatehyper': ['dashing-grizzly'],
50 | 'adroit-relocatehyper2': ['vast-robin'],
51 | 'adroit-relocatehyper3': ['quality-gull', 'peaceful-elf'],
52 | 'adroit-relocatehyper4': ['giving-krill', 'key-lizard'],
53 | 'adroit-relocatehyper3-best-10seed': ['strong-impala'],
54 | 'res6-pen': ['handy-redbird'],
55 | 'relocate-abl1': ['easy-sloth', 'casual-sloth','engaged-reptile','national-killdeer',
56 | 'sweeping-lion','finer-pup','fine-civet','guiding-mule',
57 | 'sweet-cicada','enormous-malamute','special-louse'], # first set of ablation
58 | 'abl2': ['funny-eel','unified-moose','clear-crappie','blessed-redfish','full-weevil','credible-rat','active-javelin'],
59 | 'main':['dashing-platypus','magical-magpie','adapted-tuna','bursting-clam'],
60 | 'main2':['trusting-condor','cosmic-baboon','composed-foxhound','large-goshawk'],
61 | 'main3':['loved-anemone','e0113_main3.yaml','adapting-polecat','many-cheetah'],
62 | 'main4': ['living-boxer'],
63 | 'abl':[
64 | 's1_pretrain',
65 | 's2_bc_update','s2_cql_random','s2_cql_std','s2_cql_update','cql_weight','s2_enc_update',
66 | 's3_bc_decay','s3_bc_weight','s3_std','s23_aug',
67 | 's23_lr','s23_demo_bs','s23_enc_lr','s23_lr','s23_pretanh_weight','s23_safe_q_factor', # first round of ablations
68 | 'fs_s2_bc', 'fs_s3_freeze_enc', 'fs_s3_buffer_new',
69 | 's1_main_seeds', # 12 more seeds for main results #TODO add after we get that result, mainseeds2 currently not used
70 | 's2_naive_rl', # stage 2 do naive RL updates, and shut down safe Q target
71 | 's3_conservative', # stage 3 keep using conservative loss
72 | 's23_q_threshold', # change Q threshold value #TODO maybe add to hyper sens?
73 | 's23_pretanh_bc_limit', # TODO maybe appendix focus study?
74 | 's23_pretanh_cql_limit', # TODO maybe appendix focus study?
75 | 'special_drq', 'special_drq2', # drq baseline (4 envs) # TODO might need to add to main plots?
76 | 'special_sdrq', 'special_sdrq2', # drq baseline with safe q target technique
77 | 'fs_s3_freeze_extra', # use stage 1, freeze for stage 2,3, give intermediate level features # TODO what about this?
78 | 'special_ferm', 'special_ferm_abl', # ferm baseline (improved baseline)
79 | 'special_vrl3_ferm', # vrl3 + extra ferm updates in stage 2
80 | 'special_ferm_plus', # ferm, but with safe Q target technique
81 | 's23_bn', # bn finetune ablation # TODO focus study?
82 | 's23_mom_enc_elr', 's23_mom_enc_safeq', 's1_model', 'frame_stack', # post sub new results, TODO 'action_repeat',
83 | 'model_elr_new', 'q_min_new', # 'model_elr', 'q_min' # old jobs failed
84 | 'moral-mudfish', 'cuddly-chicken', # these two are stride experiments
85 | 'model_hs_elr3_new', 'model_hs_elr2_new', 'model_hs_elr1_new',
86 | 'fs_enc_s2only', 'fs_enc_s3only', 's2_naive_rl_safe', 'ferm_and_s1',
87 | ],
88 | # neurips 2022 new
89 | 'ablneurips':['f0728_bc_new_ablation_dhp', 'f0731_bc_new_ablation_r',
90 | 'f0728_rand_first_layer_dh', 'f0731_rand_first_layer_p',
91 | 'f0728_channel_mismatch_dh', 'f0731_channel_mismatch_p',
92 | 'f0731_channel_latent_flow_dh', 'f0731_channel_latent_flow_p',
93 | 'f0802_channel_latent_sub_dhp',
94 | # '0728_channel_mismatch_r', 'f0728_rand_first_layer_r', 'f0728_channel_latent_flow_r',
95 | 'f0802_channel_latent_sub_dhp',
96 | # 'f0806_rand_first_layer_fs1_r','f0806_channel_latent_sub_r','f0806_channel_mismatch_fs1_r','f0806_channel_latent_cat_r',
97 | # 'f0808_byol',
98 | 'f0808_bc_new_ablation_dhp', 'f0808_bc_new_ablation_r',
99 | 'f0808_channel_latent_cat_r', 'f0808_channel_latent_sub_r', 'f0808_channel_mismatch_fs1_r', 'f0808_rand_first_layer_fs1_r',
100 | 'f0817_rand_first_layer_dhp_new', 'f0817_channel_latent_cat_dhp', 'f0817_channel_latent_sub_dhp',
101 | 'f0817_channel_mismatch_dhp', 'f0817_bc_only_withs1_r', 'f0817_bc_only_nos1_dhp', 'f0817_bc_only_nos1_r',
102 | 'f0902_highcap', 'f0902_bc_new_ablation_nos1_r', 'f0902_byol'
103 | ],
104 | 'ablph':['ph_s1_pretrain', 'ph_fs_s3_freeze_enc', 'ph_s2_cql_update', 'ph_s2_enc_update', 'ph_s2_naive_rl', 'ph_s3_conservative', 'ph_s23_enc_lr',
105 | 'ph_s23_safe_q_factor', 'ph_s2_cql_random', 'ph_s2_cql_std', 'ph_s2_cql_weight', 'ph_s2_disable',
106 | 'ph_s3_std', 'ph_s23_lr', 'ph_s23_q_threshold', 'ph_s3_bc_decay', 'ph_s3_bc_weight', 'ph_s23_aug', 'ph_s23_demo_bs',
107 | 'ph_s23_pretanh_weight', 'ph_s2_bc_update',
108 | 'ph_s2_bc', 'ph_s23_bn', 'ph_s23_frame_stack', 'ph_s1_model',
109 | ], # these are for pen and hammer
110 | 'dmc':['dmc_vrl3_s3_medium', 'dmc_vrl3_s3_medium_pretrain_new', 'dmc_vrl3_s3_easy','dmc_vrl3_s3_easy_pretrain','dmc_vrl3_s3_hard_new','dmc_vrl3_s3_hard_pretrain_new',
111 | 'dmc_e25k_medium', 'dmc_e25k_easy_hard'],
112 | 'dmc_hyper':['dmc_e25k_medium_h1'], # hyper search # , 'dmc_e25k_hard_h1' dmc_e25k_medium_h2
113 | 'dmc_hyper2':['dmc_e25k_medium_h2'],
114 | 'dmc_hard_hyper1':['dmc_e25k_hard_h1'],
115 | 'dmc_newhyper_3seed': ['dmc_e25k_medium_nh', 'dmc_e25k_easy_hard_nh', 'ferm_dmc_e25k_all', 'dmc_all_drqv2fd']
116 |
117 | # dmc_e25k_easy_hard_nh dmc_e25k_medium_nh # 3 seed full exp using new hyperparameters
118 |
119 | # tsne_relocate_nobc1 tsne_relocate_nobc2 # will just analyze on devnode
120 |
121 | ## this is command to download all
122 | # amlt results -I "*.csv" s1_pretrain s2_bc_update s2_cql_random s2_cql_std s2_cql_update cql_weight s2_enc_update s3_bc_decay s3_bc_weight s3_std s23_aug s23_lr s23_demo_bs s23_enc_lr s23_lr s23_pretanh_weight s23_safe_q_factor fs_s2_bc fs_s3_freeze_enc fs_s3_buffer_new s1_main_seeds
123 |
124 | # amlt results -I "*.csv" s2_naive_rl s3_conservative s23_q_threshold s23_pretanh_bc_limit s23_pretanh_cql_limit special_drq special_drq2 special_sdrq special_sdrq2 fs_s3_freeze_extra special_ferm special_ferm_abl special_vrl3_ferm special_ferm_plus s23_bn
125 |
126 | # amlt results -I "*.csv" s23_safe_q_factor fs_s2_bc fs_s3_freeze_enc fs_s3_buffer_new s1_main_seeds
127 |
128 | # TODO 0507: following is all new experiments we have run
129 | # amlt results -I "*.csv" ph_s1_pretrain ph_fs_s3_freeze_enc ph_s2_cql_update ph_s2_enc_update ph_s2_naive_rl ph_s3_conservative ph_s23_enc_lr
130 | # ph_s23_safe_q_factor ph_s2_cql_random ph_s2_cql_std ph_s2_cql_weight ph_s2_disable
131 | # ph_s3_std ph_s23_lr ph_s23_q_threshold
132 | # ph_s3_bc_decay ph_s3_bc_weight ph_s23_aug ph_s23_demo_bs ph_s23_pretanh_weight
133 |
134 | # ferm_dmc_e25k_all fs_enc_s2only fs_enc_s3only dmc_all_drqv2fd #### s2_naive_rl_safe ferm_and_s1
135 | }
136 |
137 | # 0525 new: ph_s2_bc ph_s23_bn ph_s23_frame_stack ph_s1_model
138 |
139 | move_data = True  # note: only the last `move_keys` assignment below takes effect; earlier lists are kept for reference
140 | move_keys = ['double', 'intermediate', 'double-add', 'test-pg', 'naivelarge']
141 | move_keys = [ 'newbaseline3s', 'newdouble3s']
142 | move_keys = [ 'double', 'intermediate', 'double-add',]
143 | move_keys = [ 'pt']
144 | move_keys = [ 'newbaseline3s', 'newdouble3s', 'pt']
145 | move_keys = [ 'newbaseline3s', 'newdouble3s', 'pt', 'utd', 'drq-base', 'ssl', 'tl', 's2-base', 'ssl-s2']
146 | move_keys = ['adroit-firstgroup', 'adroit-secondgroup', 'adroit-stable', 'adroit-stable-hyper', 'adroit-stable-hyper2', 'adroit-policy-lr']
147 | move_keys = ['adroit-policy-lr']
148 | move_keys = ['adroit-s2cql']
149 | move_keys = ['adroit-relocatehyper2', 'adroit-relocatehyper3', 'res6-pen'] # hyper 3 is best
150 | move_keys = ['adroit-relocatehyper4']
151 | move_keys = ['adroit-relocatehyper3-best-10seed', 'relocate-abl1', 'abl2', 'main', 'main2','main3', 'main4']
152 | move_keys = ['abl', 'dmc', 'dmc_hyper', 'dmc_hyper2', 'dmc_hard_hyper1', 'dmc_newhyper_3seed', 'ablph']
153 | move_keys = ['ablneurips']
154 | if move_data:
155 | for data_folder in move_keys:
156 | list_of_amlt_folder = data_folder_to_amlt_dict[data_folder]
157 | for amlt_folder in list_of_amlt_folder:
158 | move_data_from_amlt(amlt_folder, data_folder)
159 |
160 |
--------------------------------------------------------------------------------
/src/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints/
3 | exp_local
4 | exp
5 | exp_fixed
6 | exp_hard
7 | nbs
8 | code_snapshots
9 | exp_drqv2_*.py
10 | dmc_benchmarks.py
11 | check_sweep.py
12 | cancel_sweep.py
13 | data/
14 |
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | pip-wheel-metadata/
39 | share/python-wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | *.py,cover
66 | .hypothesis/
67 | .pytest_cache/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | target/
91 |
92 | # Jupyter Notebook
93 | .ipynb_checkpoints
94 |
95 | # IPython
96 | profile_default/
97 | ipython_config.py
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
--------------------------------------------------------------------------------
/src/adroit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # NOTE: adroit env code is currently being cleaned up
5 |
6 |
7 | from collections import deque
8 | from typing import Any, NamedTuple
9 | import warnings
10 |
11 | import dm_env
12 | import numpy as np
13 | from dm_env import StepType, specs
14 | from collections import OrderedDict
15 | import mj_envs
16 | # import adept_envs # TODO worry about this later
17 | import gym
18 |
19 | from mjrl.utils.gym_env import GymEnv
20 | # from rrl_local.rrl_utils import make_basic_env, make_dir
21 | from rrl_local.rrl_multicam import BasicAdroitEnv, BasicFrankaEnv
22 |
23 | # similar to dmc.py, we have environment wrappers here
24 |
25 | class ExtendedTimeStep(NamedTuple):
26 | step_type: Any
27 | reward: Any
28 | discount: Any
29 | observation: Any
30 | action: Any
31 |
32 | def first(self):
33 | return self.step_type == StepType.FIRST
34 |
35 | def mid(self):
36 | return self.step_type == StepType.MID
37 |
38 | def last(self):
39 | return self.step_type == StepType.LAST
40 |
41 | def __getitem__(self, attr):
42 | return getattr(self, attr)
43 |
44 | class ExtendedTimeStepAdroit(NamedTuple):
45 | step_type: Any
46 | reward: Any
47 | discount: Any
48 | observation: Any
49 | observation_sensor: Any
50 | action: Any
51 | n_goal_achieved: Any
52 | time_limit_reached: Any
53 |
54 | def first(self):
55 | return self.step_type == StepType.FIRST
56 |
57 | def mid(self):
58 | return self.step_type == StepType.MID
59 |
60 | def last(self):
61 | return self.step_type == StepType.LAST
62 |
63 | def __getitem__(self, attr):
64 | return getattr(self, attr)
65 |
66 |
67 | class ActionRepeatWrapper(dm_env.Environment):
68 | def __init__(self, env, num_repeats):
69 | self._env = env
70 | self._num_repeats = num_repeats
71 |
72 | def step(self, action):
73 | reward = 0.0
74 | discount = 1.0
75 | for i in range(self._num_repeats):
76 | time_step = self._env.step(action)
77 | reward += (time_step.reward or 0.0) * discount
78 | discount *= time_step.discount
79 | if time_step.last():
80 | break
81 |
82 | return time_step._replace(reward=reward, discount=discount)
83 |
84 | def observation_spec(self):
85 | return self._env.observation_spec()
86 |
87 | def action_spec(self):
88 | return self._env.action_spec()
89 |
90 | def reset(self):
91 | return self._env.reset()
92 |
93 | def __getattr__(self, name):
94 | return getattr(self._env, name)
95 |
96 |
97 | class FrameStackWrapper(dm_env.Environment):
98 | def __init__(self, env, num_frames, pixels_key='pixels'):
99 | self._env = env
100 | self._num_frames = num_frames
101 | self._frames = deque([], maxlen=num_frames)
102 | self._pixels_key = pixels_key
103 |
104 | wrapped_obs_spec = env.observation_spec()
105 | assert pixels_key in wrapped_obs_spec
106 |
107 | pixels_shape = wrapped_obs_spec[pixels_key].shape
108 | # remove batch dim
109 | if len(pixels_shape) == 4:
110 | pixels_shape = pixels_shape[1:]
111 | self._obs_spec = specs.BoundedArray(shape=np.concatenate(
112 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
113 | dtype=np.uint8,
114 | minimum=0,
115 | maximum=255,
116 | name='observation')
117 |
118 | def _transform_observation(self, time_step):
119 | assert len(self._frames) == self._num_frames
120 | obs = np.concatenate(list(self._frames), axis=0)
121 | return time_step._replace(observation=obs)
122 |
123 | def _extract_pixels(self, time_step):
124 | pixels = time_step.observation[self._pixels_key]
125 | # remove batch dim
126 | if len(pixels.shape) == 4:
127 | pixels = pixels[0]
128 | return pixels.transpose(2, 0, 1).copy()
129 |
130 | def reset(self):
131 | time_step = self._env.reset()
132 | pixels = self._extract_pixels(time_step)
133 | for _ in range(self._num_frames):
134 | self._frames.append(pixels)
135 | return self._transform_observation(time_step)
136 |
137 | def step(self, action):
138 | time_step = self._env.step(action)
139 | pixels = self._extract_pixels(time_step)
140 | self._frames.append(pixels)
141 | return self._transform_observation(time_step)
142 |
143 | def observation_spec(self):
144 | return self._obs_spec
145 |
146 | def action_spec(self):
147 | return self._env.action_spec()
148 |
149 | def __getattr__(self, name):
150 | return getattr(self._env, name)
151 |
152 |
153 | class ActionDTypeWrapper(dm_env.Environment):
154 | def __init__(self, env, dtype):
155 | self._env = env
156 | wrapped_action_spec = env.action_spec()
157 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
158 | dtype,
159 | wrapped_action_spec.minimum,
160 | wrapped_action_spec.maximum,
161 | 'action')
162 |
163 | def step(self, action):
164 | action = action.astype(self._env.action_spec().dtype)
165 | return self._env.step(action)
166 |
167 | def observation_spec(self):
168 | return self._env.observation_spec()
169 |
170 | def action_spec(self):
171 | return self._action_spec
172 |
173 | def reset(self):
174 | return self._env.reset()
175 |
176 | def __getattr__(self, name):
177 | return getattr(self._env, name)
178 |
179 |
180 | class ExtendedTimeStepWrapper(dm_env.Environment):
181 | def __init__(self, env):
182 | self._env = env
183 |
184 | def reset(self):
185 | time_step = self._env.reset()
186 | return self._augment_time_step(time_step)
187 |
188 | def step(self, action):
189 | time_step = self._env.step(action)
190 | return self._augment_time_step(time_step, action)
191 |
192 | def _augment_time_step(self, time_step, action=None):
193 | if action is None:
194 | action_spec = self.action_spec()
195 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
196 | return ExtendedTimeStep(observation=time_step.observation,
197 | step_type=time_step.step_type,
198 | action=action,
199 | reward=time_step.reward or 0.0,
200 | discount=time_step.discount or 1.0)
201 |
202 | def observation_spec(self):
203 | return self._env.observation_spec()
204 |
205 | def action_spec(self):
206 | return self._env.action_spec()
207 |
208 | def __getattr__(self, name):
209 | return getattr(self._env, name)
210 |
211 |
212 | def make_basic_env(env, cam_list=[], from_pixels=False, hybrid_state=None, test_image=False, channels_first=False,
213 | num_repeats=1, num_frames=1):
214 | e = GymEnv(env)
215 | env_kwargs = None
216 | if from_pixels : # TODO might want to improve this part
217 | height = 84
218 | width = 84
219 | latent_dim = height*width*len(cam_list)*3
220 | # RRL class instance is environment wrapper...
221 | e = BasicAdroitEnv(e, cameras=cam_list,
222 | height=height, width=width, latent_dim=latent_dim, hybrid_state=hybrid_state,
223 | test_image=test_image, channels_first=channels_first, num_repeats=num_repeats, num_frames=num_frames)
224 | env_kwargs = {'rrl_kwargs' : e.env_kwargs}
225 | # if not from pixels... then it's simpler
226 | return e, env_kwargs
227 |
228 | class AdroitEnv:
229 |     # a wrapper class that makes an Adroit env look like a dmc env
230 | def __init__(self, env_name, test_image=False, cam_list=None,
231 | num_repeats=2, num_frames=3, env_feature_type='pixels', device=None, reward_rescale=False):
232 | default_env_to_cam_list = {
233 | 'hammer-v0': ['top'],
234 | 'door-v0': ['top'],
235 | 'pen-v0': ['vil_camera'],
236 | 'relocate-v0': ['cam1', 'cam2', 'cam3',],
237 | }
238 | if cam_list is None:
239 | cam_list = default_env_to_cam_list[env_name]
240 | self.env_name = env_name
241 | reward_rescale_dict = {
242 | 'hammer-v0': 1/100,
243 | 'door-v0': 1/20,
244 | 'pen-v0': 1/50,
245 | 'relocate-v0': 1/30,
246 | }
247 | if reward_rescale:
248 | self.reward_rescale_factor = reward_rescale_dict[env_name]
249 | else:
250 | self.reward_rescale_factor = 1
251 |
252 | # env, _ = make_basic_env(env_name, cam_list=cam_list, from_pixels=from_pixels, hybrid_state=True,
253 | # test_image=test_image, channels_first=True, num_repeats=num_repeats, num_frames=num_frames)
254 | env = GymEnv(env_name)
255 | if env_feature_type == 'state':
256 | raise NotImplementedError("state env not ready")
257 | elif env_feature_type == 'resnet18' or env_feature_type == 'resnet34' :
258 | # TODO maybe we will just throw everything into it..
259 | height = 256
260 | width = 256
261 | latent_dim = 512
262 | env = BasicAdroitEnv(env, cameras=cam_list,
263 | height=height, width=width, latent_dim=latent_dim, hybrid_state=True,
264 | test_image=test_image, channels_first=False, num_repeats=num_repeats, num_frames=num_frames, encoder_type=env_feature_type,
265 | device=device
266 | )
267 | elif env_feature_type == 'pixels':
268 | height = 84
269 | width = 84
270 | latent_dim = height*width*len(cam_list)*num_frames
271 | # RRL class instance is environment wrapper...
272 | env = BasicAdroitEnv(env, cameras=cam_list,
273 | height=height, width=width, latent_dim=latent_dim, hybrid_state=True,
274 | test_image=test_image, channels_first=True, num_repeats=num_repeats, num_frames=num_frames, device=device)
275 | else:
276 | raise ValueError("env feature not supported")
277 |
278 | self._env = env
279 | self.obs_dim = env.spec.observation_dim
280 | self.obs_sensor_dim = 24
281 | self.act_dim = env.spec.action_dim
282 | self.horizon = env.spec.horizon
283 | number_channel = len(cam_list) * 3 * num_frames
284 |
285 | if env_feature_type == 'pixels':
286 | self._obs_spec = specs.BoundedArray(shape=(number_channel, 84, 84), dtype='uint8', name='observation', minimum=0, maximum=255)
287 | self._obs_sensor_spec = specs.Array(shape=(self.obs_sensor_dim,), dtype='float32', name='observation_sensor')
288 | elif env_feature_type == 'resnet18' or env_feature_type == 'resnet34' :
289 | self._obs_spec = specs.Array(shape=(512 * num_frames *len(cam_list) ,), dtype='float32', name='observation') # TODO fix magic number
290 | self._obs_sensor_spec = specs.Array(shape=(self.obs_sensor_dim,), dtype='float32', name='observation_sensor')
291 | self._action_spec = specs.BoundedArray(shape=(self.act_dim,), dtype='float32', name='action', minimum=-1.0, maximum=1.0)
292 |
293 | def reset(self):
294 | # pixels and sensor values
295 | obs_pixels, obs_sensor = self._env.reset()
296 | obs_sensor = obs_sensor.astype(np.float32)
297 | action_spec = self.action_spec()
298 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
299 |
300 | time_step = ExtendedTimeStepAdroit(observation=obs_pixels,
301 | observation_sensor=obs_sensor,
302 | step_type=StepType.FIRST,
303 | action=action,
304 | reward=0.0,
305 | discount=1.0,
306 | n_goal_achieved=0,
307 | time_limit_reached=False)
308 | return time_step
309 |
310 | def get_current_obs_without_reset(self):
311 | # use this to obtain the first state in a demo
312 | obs_pixels, obs_sensor = self._env.get_obs_for_first_state_but_without_reset()
313 | obs_sensor = obs_sensor.astype(np.float32)
314 | action_spec = self.action_spec()
315 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
316 |
317 | time_step = ExtendedTimeStepAdroit(observation=obs_pixels,
318 | observation_sensor=obs_sensor,
319 | step_type=StepType.FIRST,
320 | action=action,
321 | reward=0.0,
322 | discount=1.0,
323 | n_goal_achieved=0,
324 | time_limit_reached=False)
325 | return time_step
326 |
327 | def get_pixels_with_width_height(self, w, h):
328 | return self._env.get_pixels_with_width_height(w, h)
329 |
330 | def step(self, action, force_step_type=None, debug=False):
331 | obs_all, reward, done, env_info = self._env.step(action)
332 | obs_pixels, obs_sensor = obs_all
333 | obs_sensor = obs_sensor.astype(np.float32)
334 |
335 | discount = 1.0
336 | n_goal_achieved = env_info['n_goal_achieved']
337 | time_limit_reached = env_info['TimeLimit.truncated'] if 'TimeLimit.truncated' in env_info else False
338 | if done:
339 | steptype = StepType.LAST
340 | else:
341 | steptype = StepType.MID
342 |
343 | if done and not time_limit_reached:
344 | discount = 0.0
345 |
346 | if force_step_type is not None:
347 | if force_step_type == 'mid':
348 | steptype = StepType.MID
349 | elif force_step_type == 'last':
350 | steptype = StepType.LAST
351 | else:
352 | steptype = StepType.FIRST
353 |
354 | reward = reward * self.reward_rescale_factor
355 |
356 | time_step = ExtendedTimeStepAdroit(observation=obs_pixels,
357 | observation_sensor=obs_sensor,
358 | step_type=steptype,
359 | action=action,
360 | reward=reward,
361 | discount=discount,
362 | n_goal_achieved=n_goal_achieved,
363 | time_limit_reached=time_limit_reached)
364 |
365 | if debug:
366 | return obs_all, reward, done, env_info
367 | return time_step
368 |
369 | def observation_spec(self):
370 | return self._obs_spec
371 |
372 | def observation_sensor_spec(self):
373 | return self._obs_sensor_spec
374 |
375 | def action_spec(self):
376 | return self._action_spec
377 |
378 | def set_env_state(self, state):
379 | self._env.set_env_state(state)
380 | # def __getattr__(self, name):
381 | # return getattr(self, name)
382 |
383 |
384 |
--------------------------------------------------------------------------------
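A minimal sketch of how the AdroitEnv above is meant to be driven, with random actions standing in for an agent (illustrative only, not part of the repository; assumes mj_envs is installed so that 'door-v0' is registered):

# sketch: roll one episode through AdroitEnv, dm_env-style
import numpy as np
from dm_env import StepType

from adroit import AdroitEnv

env = AdroitEnv('door-v0', num_repeats=2, num_frames=3,
                env_feature_type='pixels', device='cuda', reward_rescale=True)
spec = env.action_spec()
time_step = env.reset()                       # ExtendedTimeStepAdroit, StepType.FIRST
while time_step.step_type != StepType.LAST:
    action = np.random.uniform(spec.minimum, spec.maximum,
                               size=spec.shape).astype(spec.dtype)
    time_step = env.step(action)              # carries reward, discount, n_goal_achieved
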
/src/cfgs_adroit/config.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # modified from DrQv2 config file
5 |
6 | defaults:
7 | - _self_
8 | - task@_global_: door
9 | - override hydra/launcher: submitit_local
10 |
11 | # task settings
12 | debug: 0
13 | # frame_stack: the paper actually used 3 as the default for Adroit; however,
14 | # it can take a lot more memory (especially for relocate), and a later ablation shows
15 | # frame_stack does not significantly affect performance on Adroit, so here we set it to 1.
16 | frame_stack: 1
17 | action_repeat: 2
18 | discount: 0.99
19 | # eval
20 | eval_every_frames: 50000
21 | num_eval_episodes: 25
22 | stage2_eval_every_frames: 5000
23 | stage2_num_eval_episodes: 5
24 | # snapshot
25 | save_snapshot: false
26 | # replay buffer
27 | replay_buffer_size: 1000000
28 | replay_buffer_num_workers: 4
29 | nstep: 3
30 | batch_size: 256
31 | # misc
32 | seed: 1
33 | device: cuda
34 | save_video: false
35 | save_train_video: false
36 | use_tb: false
37 | save_models: false
38 | local_data_dir: '/vrl3data'
39 | show_computation_time_est: true # provide estimates on computation times
40 | show_time_est_interval: 2500
41 | # experiment
42 | experiment: exp
43 | # environment
44 | env_feature_type: 'pixels'
45 | use_sensor: true
46 | reward_rescale: true # TODO think about this...
47 | # agent
48 | lr: 1e-4
49 | feature_dim: 50
50 |
51 | # ====== stage 1 ======
52 | stage1_model_name: 'resnet6_32channel' # model name example: resnet6_32channel, resnet6_64channel, resnet10_32channel
53 | stage1_use_pretrain: true # if false then we start from scratch
54 |
55 | # ====== stage 2 ======
56 | load_demo: true
57 | num_demo: 25
58 | stage2_n_update: 30000
59 |
60 | # ====== stage 3 ======
61 | num_seed_frames: 0
62 |
63 | agent:
64 | _target_: vrl3_agent.VRL3Agent
65 | obs_shape: ??? # to be specified later
66 | action_shape: ??? # to be specified later
67 | device: ${device}
68 | critic_target_tau: 0.01
69 | update_every_steps: 2
70 | use_tb: ${use_tb}
71 | lr: 1e-4
72 | hidden_dim: 1024
73 | feature_dim: 50
74 |
75 | # environment related
76 | use_sensor: ${use_sensor}
77 |
78 | # ====== stage 1 ======
79 | stage1_model_name: ${stage1_model_name}
80 |
81 | # ====== stage 2, 3 ======
82 | use_data_aug: true
83 | encoder_lr_scale: 1
84 | stddev_clip: 0.3
85 | # safe Q
86 |   safe_q_target_factor: 0.5 # a value of 1 disables the safe factor; 0 is a hard cutoff.
87 | safe_q_threshold: 200
88 | # pretanh penalty
89 | pretanh_threshold: 5
90 | pretanh_penalty: 0.001
91 |
92 | # ====== stage 2 ======
93 |   stage2_update_encoder: true # decides whether the encoder is frozen or finetuned in stage 2
94 | cql_weight: 1
95 | cql_temp: 1
96 | cql_n_random: 10
97 | stage2_std: 0.1
98 | stage2_bc_weight: 0 # ablation shows additional BC does not help performance
99 |
100 | # ====== stage 3 ======
101 | stage3_update_encoder: true
102 | num_expl_steps: 0 # number of random actions at start of stage 3
103 | # std decay
104 | std0: 0.01
105 | std1: 0.01
106 | std_n_decay: 500000
107 | # bc decay
108 | stage3_bc_lam0: 0 # ablation shows additional BC does not help performance
109 | stage3_bc_lam1: 0.95
110 |
111 | hydra:
112 | run: # this "dir" decides where the training logs are stored
113 | dir: /vrl3data/logs/exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${hydra.job.override_dirname}
114 | sweep:
115 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment}
116 | subdir: ${hydra.job.num}
117 | launcher:
118 | timeout_min: 4300
119 | cpus_per_task: 10
120 | gpus_per_node: 1
121 | tasks_per_node: 1
122 | mem_gb: 160
123 | nodes: 1
124 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm
125 |
--------------------------------------------------------------------------------
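The `task@_global_` entry in the defaults list merges the chosen file from `task/` (door, hammer, pen, relocate; shown next) into the top level of this config, so `task=relocate` on the command line swaps both `task_name` and the agent's `encoder_lr_scale`. An illustrative sketch (not in the repository) of inspecting the composed config with Hydra's compose API, assuming Hydra 1.1+ and the hydra-submitit-launcher plugin referenced above are installed, and that it is run from `src/`:

# sketch: compose and print the effective config without launching training
from hydra import compose, initialize
from omegaconf import OmegaConf

# newer Hydra versions may also want version_base=None here
with initialize(config_path='cfgs_adroit'):
    cfg = compose(config_name='config', overrides=['task=relocate', 'seed=3'])
print(cfg.task_name)                  # relocate-v0
print(cfg.agent.encoder_lr_scale)     # 0.01, merged in from task/relocate.yaml
print(OmegaConf.to_yaml(cfg))
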
/src/cfgs_adroit/task/door.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 4100000
2 | task_name: door-v0
3 | agent:
4 | encoder_lr_scale: 1
--------------------------------------------------------------------------------
/src/cfgs_adroit/task/hammer.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 4100000
2 | task_name: hammer-v0
3 | agent:
4 | encoder_lr_scale: 1
--------------------------------------------------------------------------------
/src/cfgs_adroit/task/pen.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 4100000
2 | task_name: pen-v0
3 | agent:
4 | encoder_lr_scale: 1
--------------------------------------------------------------------------------
/src/cfgs_adroit/task/relocate.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 4100000
2 | task_name: relocate-v0
3 | agent:
4 | encoder_lr_scale: 0.01
--------------------------------------------------------------------------------
/src/dmc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | from collections import deque
6 | from typing import Any, NamedTuple
7 |
8 | import dm_env
9 | import numpy as np
10 | from dm_control import manipulation, suite
11 | from dm_control.suite.wrappers import action_scale, pixels
12 | from dm_env import StepType, specs
13 |
14 | class ExtendedTimeStep(NamedTuple):
15 | step_type: Any
16 | reward: Any
17 | discount: Any
18 | observation: Any
19 | action: Any
20 |
21 | def first(self):
22 | return self.step_type == StepType.FIRST
23 |
24 | def mid(self):
25 | return self.step_type == StepType.MID
26 |
27 | def last(self):
28 | return self.step_type == StepType.LAST
29 |
30 | def __getitem__(self, attr):
31 | return getattr(self, attr)
32 |
33 |
34 | class ActionRepeatWrapper(dm_env.Environment):
35 | def __init__(self, env, num_repeats):
36 | self._env = env
37 | self._num_repeats = num_repeats
38 |
39 | def step(self, action):
40 | reward = 0.0
41 | discount = 1.0
42 | for i in range(self._num_repeats):
43 | time_step = self._env.step(action)
44 | reward += (time_step.reward or 0.0) * discount
45 | discount *= time_step.discount
46 | if time_step.last():
47 | break
48 |
49 | return time_step._replace(reward=reward, discount=discount)
50 |
51 | def observation_spec(self):
52 | return self._env.observation_spec()
53 |
54 | def action_spec(self):
55 | return self._env.action_spec()
56 |
57 | def reset(self):
58 | return self._env.reset()
59 |
60 | def __getattr__(self, name):
61 | return getattr(self._env, name)
62 |
63 |
64 | class FrameStackWrapper(dm_env.Environment):
65 | def __init__(self, env, num_frames, pixels_key='pixels'):
66 | self._env = env
67 | self._num_frames = num_frames
68 | self._frames = deque([], maxlen=num_frames)
69 | self._pixels_key = pixels_key
70 |
71 | wrapped_obs_spec = env.observation_spec()
72 | assert pixels_key in wrapped_obs_spec
73 |
74 | pixels_shape = wrapped_obs_spec[pixels_key].shape
75 | # remove batch dim
76 | if len(pixels_shape) == 4:
77 | pixels_shape = pixels_shape[1:]
78 | self._obs_spec = specs.BoundedArray(shape=np.concatenate(
79 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
80 | dtype=np.uint8,
81 | minimum=0,
82 | maximum=255,
83 | name='observation')
84 |
85 | def _transform_observation(self, time_step):
86 | assert len(self._frames) == self._num_frames
87 | obs = np.concatenate(list(self._frames), axis=0)
88 | return time_step._replace(observation=obs)
89 |
90 | def _extract_pixels(self, time_step):
91 | pixels = time_step.observation[self._pixels_key]
92 | # remove batch dim
93 | if len(pixels.shape) == 4:
94 | pixels = pixels[0]
95 |
96 | # transpose: 84 x 84 x 3 -> 3 x 84 x 84
97 | return pixels.transpose(2, 0, 1).copy()
98 |
99 | def reset(self):
100 | time_step = self._env.reset()
101 | pixels = self._extract_pixels(time_step)
102 | for _ in range(self._num_frames):
103 | self._frames.append(pixels)
104 | return self._transform_observation(time_step)
105 |
106 | def step(self, action):
107 | time_step = self._env.step(action)
108 | pixels = self._extract_pixels(time_step)
109 | self._frames.append(pixels)
110 | return self._transform_observation(time_step)
111 |
112 | def observation_spec(self):
113 | return self._obs_spec
114 |
115 | def action_spec(self):
116 | return self._env.action_spec()
117 |
118 | def __getattr__(self, name):
119 | return getattr(self._env, name)
120 |
121 |
122 | class ActionDTypeWrapper(dm_env.Environment):
123 | def __init__(self, env, dtype):
124 | self._env = env
125 | wrapped_action_spec = env.action_spec()
126 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
127 | dtype,
128 | wrapped_action_spec.minimum,
129 | wrapped_action_spec.maximum,
130 | 'action')
131 |
132 | def step(self, action):
133 | action = action.astype(self._env.action_spec().dtype)
134 | return self._env.step(action)
135 |
136 | def observation_spec(self):
137 | return self._env.observation_spec()
138 |
139 | def action_spec(self):
140 | return self._action_spec
141 |
142 | def reset(self):
143 | return self._env.reset()
144 |
145 | def __getattr__(self, name):
146 | return getattr(self._env, name)
147 |
148 |
149 | class ExtendedTimeStepWrapper(dm_env.Environment):
150 | def __init__(self, env):
151 | self._env = env
152 |
153 | def reset(self):
154 | time_step = self._env.reset()
155 | return self._augment_time_step(time_step)
156 |
157 | def step(self, action):
158 | time_step = self._env.step(action)
159 | return self._augment_time_step(time_step, action)
160 |
161 | def _augment_time_step(self, time_step, action=None):
162 | if action is None:
163 | action_spec = self.action_spec()
164 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
165 | return ExtendedTimeStep(observation=time_step.observation,
166 | step_type=time_step.step_type,
167 | action=action,
168 | reward=time_step.reward or 0.0,
169 | discount=time_step.discount or 1.0)
170 |
171 | def observation_spec(self):
172 | return self._env.observation_spec()
173 |
174 | def action_spec(self):
175 | return self._env.action_spec()
176 |
177 | def __getattr__(self, name):
178 | return getattr(self._env, name)
179 |
180 |
181 | def make(name, frame_stack, action_repeat, seed):
182 | domain, task = name.split('_', 1)
183 | # overwrite cup to ball_in_cup
184 | domain = dict(cup='ball_in_cup').get(domain, domain)
185 | # make sure reward is not visualized
186 | if (domain, task) in suite.ALL_TASKS:
187 | env = suite.load(domain,
188 | task,
189 | task_kwargs={'random': seed},
190 | visualize_reward=False)
191 | pixels_key = 'pixels'
192 | else:
193 | name = f'{domain}_{task}_vision'
194 | env = manipulation.load(name, seed=seed)
195 | pixels_key = 'front_close'
196 | # add wrappers
197 | env = ActionDTypeWrapper(env, np.float32)
198 | env = ActionRepeatWrapper(env, action_repeat)
199 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
200 |     # add renderings for classical tasks
201 | if (domain, task) in suite.ALL_TASKS:
202 | # zoom in camera for quadruped
203 | camera_id = dict(quadruped=2).get(domain, 0)
204 | render_kwargs = dict(height=84, width=84, camera_id=camera_id)
205 | env = pixels.Wrapper(env,
206 | pixels_only=True,
207 | render_kwargs=render_kwargs)
208 | # stack several frames
209 | env = FrameStackWrapper(env, frame_stack, pixels_key)
210 | env = ExtendedTimeStepWrapper(env)
211 | return env
212 |
--------------------------------------------------------------------------------
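For reference, a sketch (not in the repository) of what the assembled wrapper stack in `make` yields, assuming dm_control and a working MuJoCo backend:

# sketch: build a DMC task and take one step through the full wrapper stack
import numpy as np
from dmc import make

env = make('walker_walk', frame_stack=3, action_repeat=2, seed=1)
print(env.observation_spec().shape)   # (9, 84, 84): 3 stacked frames x 3 channels
time_step = env.reset()
spec = env.action_spec()
action = np.random.uniform(spec.minimum, spec.maximum,
                           size=spec.shape).astype(spec.dtype)
time_step = env.step(action)          # ExtendedTimeStep with action/reward/discount filled in
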
/src/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import csv
6 | import datetime
7 | from collections import defaultdict
8 |
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | from termcolor import colored
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
16 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
17 | ('episode_reward', 'R', 'float'),
18 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'),
19 | ('total_time', 'T', 'time')]
20 |
21 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
22 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
23 | ('episode_reward', 'R', 'float'),
24 | ('total_time', 'T', 'time')]
25 |
26 |
27 | class AverageMeter(object):
28 | def __init__(self):
29 | self._sum = 0
30 | self._count = 0
31 |
32 | def update(self, value, n=1):
33 | self._sum += value
34 | self._count += n
35 |
36 | def value(self):
37 | return self._sum / max(1, self._count)
38 |
39 |
40 | class MetersGroup(object):
41 | def __init__(self, csv_file_name, formating):
42 | self._csv_file_name = csv_file_name
43 | self._formating = formating
44 | self._meters = defaultdict(AverageMeter)
45 | self._csv_file = None
46 | self._csv_writer = None
47 |
48 | def log(self, key, value, n=1):
49 | self._meters[key].update(value, n)
50 |
51 | def _prime_meters(self):
52 | data = dict()
53 | for key, meter in self._meters.items():
54 | if key.startswith('train'):
55 | key = key[len('train') + 1:]
56 | else:
57 | key = key[len('eval') + 1:]
58 | key = key.replace('/', '_')
59 | data[key] = meter.value()
60 | return data
61 |
62 | def _remove_old_entries(self, data):
63 | rows = []
64 | with self._csv_file_name.open('r') as f:
65 | reader = csv.DictReader(f)
66 | for row in reader:
67 | if float(row['episode']) >= data['episode']:
68 | break
69 | rows.append(row)
70 | with self._csv_file_name.open('w') as f:
71 | writer = csv.DictWriter(f,
72 | fieldnames=sorted(data.keys()),
73 | restval=0.0)
74 | writer.writeheader()
75 | for row in rows:
76 | writer.writerow(row)
77 |
78 | def _dump_to_csv(self, data):
79 | if self._csv_writer is None:
80 | should_write_header = True
81 | if self._csv_file_name.exists():
82 | self._remove_old_entries(data)
83 | should_write_header = False
84 |
85 | self._csv_file = self._csv_file_name.open('a')
86 | self._csv_writer = csv.DictWriter(self._csv_file,
87 | fieldnames=sorted(data.keys()),
88 | restval=0.0)
89 | if should_write_header:
90 | self._csv_writer.writeheader()
91 |
92 | self._csv_writer.writerow(data)
93 | self._csv_file.flush()
94 |
95 | def _format(self, key, value, ty):
96 | if ty == 'int':
97 | value = int(value)
98 | return f'{key}: {value}'
99 | elif ty == 'float':
100 | return f'{key}: {value:.04f}'
101 | elif ty == 'time':
102 | value = str(datetime.timedelta(seconds=int(value)))
103 | return f'{key}: {value}'
104 | else:
105 |             raise ValueError(f'invalid format type: {ty}')
106 |
107 | def _dump_to_console(self, data, prefix):
108 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
109 | pieces = [f'| {prefix: <14}']
110 | for key, disp_key, ty in self._formating:
111 | value = data.get(key, 0)
112 | pieces.append(self._format(disp_key, value, ty))
113 | print(' | '.join(pieces))
114 |
115 | def dump(self, step, prefix):
116 | if len(self._meters) == 0:
117 | return
118 | data = self._prime_meters()
119 | data['frame'] = step
120 | self._dump_to_csv(data)
121 | self._dump_to_console(data, prefix)
122 | self._meters.clear()
123 |
124 |
125 | class Logger(object):
126 | def __init__(self, log_dir, use_tb, stage2_logger=False):
127 | self._log_dir = log_dir
128 | if not stage2_logger:
129 | self._train_mg = MetersGroup(log_dir / 'train.csv',
130 | formating=COMMON_TRAIN_FORMAT)
131 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
132 | formating=COMMON_EVAL_FORMAT)
133 | else:
134 | self._train_mg = MetersGroup(log_dir / 'train_stage2.csv',
135 | formating=COMMON_TRAIN_FORMAT)
136 | self._eval_mg = MetersGroup(log_dir / 'eval_stage2.csv',
137 | formating=COMMON_EVAL_FORMAT)
138 | if use_tb:
139 | self._sw = SummaryWriter(str(log_dir / 'tb'))
140 | else:
141 | self._sw = None
142 |
143 | def _try_sw_log(self, key, value, step):
144 | if self._sw is not None:
145 | self._sw.add_scalar(key, value, step)
146 |
147 | def log(self, key, value, step):
148 | assert key.startswith('train') or key.startswith('eval')
149 | if type(value) == torch.Tensor:
150 | value = value.item()
151 | self._try_sw_log(key, value, step)
152 | mg = self._train_mg if key.startswith('train') else self._eval_mg
153 | mg.log(key, value)
154 |
155 | def log_metrics(self, metrics, step, ty):
156 | for key, value in metrics.items():
157 | self.log(f'{ty}/{key}', value, step)
158 |
159 | def dump(self, step, ty=None):
160 | if ty is None or ty == 'eval':
161 | self._eval_mg.dump(step, 'eval')
162 | if ty is None or ty == 'train':
163 | self._train_mg.dump(step, 'train')
164 |
165 | def log_and_dump_ctx(self, step, ty):
166 | return LogAndDumpCtx(self, step, ty)
167 |
168 |
169 | class LogAndDumpCtx:
170 | def __init__(self, logger, step, ty):
171 | self._logger = logger
172 | self._step = step
173 | self._ty = ty
174 |
175 | def __enter__(self):
176 | return self
177 |
178 | def __call__(self, key, value):
179 | self._logger.log(f'{self._ty}/{key}', value, self._step)
180 |
181 | def __exit__(self, *args):
182 | self._logger.dump(self._step, self._ty)
183 |
--------------------------------------------------------------------------------
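The contract here: every key passed to `Logger.log` must start with 'train' or 'eval' (the prefix picks the meters group and is stripped for the CSV column), and `dump` averages the accumulated meters, writes one CSV row, and prints one console line. A small illustrative sketch (not in the repository; the directory name is hypothetical):

# sketch: minimal Logger round trip
from pathlib import Path
from logger import Logger

log_dir = Path('/tmp/vrl3_logger_demo')   # hypothetical scratch directory
log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(log_dir, use_tb=False)

for step in range(5):
    logger.log('train/episode_reward', float(step), step)
logger.dump(step=4, ty='train')           # averages meters -> train.csv + console

with logger.log_and_dump_ctx(step=4, ty='eval') as log:
    log('episode_reward', 3.5)            # flushed to eval.csv on context exit
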
/src/replay_buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import datetime
6 | import io
7 | import random
8 | import traceback
9 | from collections import defaultdict
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils.data import IterableDataset
15 |
16 |
17 | def episode_len(episode):
18 |     # subtract 1 because of the dummy first transition
19 | return next(iter(episode.values())).shape[0] - 1
20 |
21 |
22 | def save_episode(episode, fn):
23 | with io.BytesIO() as bs:
24 | np.savez_compressed(bs, **episode)
25 | bs.seek(0)
26 | with fn.open('wb') as f:
27 | f.write(bs.read())
28 |
29 |
30 | def load_episode(fn):
31 | with fn.open('rb') as f:
32 | episode = np.load(f)
33 | episode = {k: episode[k] for k in episode.keys()}
34 | return episode
35 |
36 |
37 | class ReplayBufferStorage:
38 | def __init__(self, data_specs, replay_dir):
39 | self._data_specs = data_specs
40 | self._replay_dir = replay_dir
41 | replay_dir.mkdir(exist_ok=True)
42 | self._current_episode = defaultdict(list)
43 | self._preload()
44 |
45 | def __len__(self):
46 | return self._num_transitions
47 |
48 | def add(self, time_step):
49 | for spec in self._data_specs:
50 | value = time_step[spec.name]
51 | if np.isscalar(value):
52 | value = np.full(spec.shape, value, spec.dtype)
53 | # print(spec.name, spec.shape, spec.dtype, value.shape, value.dtype)
54 | assert spec.shape == value.shape and spec.dtype == value.dtype
55 | self._current_episode[spec.name].append(value)
56 | if time_step.last():
57 | episode = dict()
58 | for spec in self._data_specs:
59 | value = self._current_episode[spec.name]
60 | episode[spec.name] = np.array(value, spec.dtype)
61 | self._current_episode = defaultdict(list)
62 | self._store_episode(episode)
63 |
64 | def _preload(self):
65 | self._num_episodes = 0
66 | self._num_transitions = 0
67 | for fn in self._replay_dir.glob('*.npz'):
68 | _, _, eps_len = fn.stem.split('_')
69 | self._num_episodes += 1
70 | self._num_transitions += int(eps_len)
71 |
72 | def _store_episode(self, episode):
73 | eps_idx = self._num_episodes
74 | eps_len = episode_len(episode)
75 | self._num_episodes += 1
76 | self._num_transitions += eps_len
77 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
78 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
79 | save_episode(episode, self._replay_dir / eps_fn)
80 |
81 |
82 | class ReplayBuffer(IterableDataset):
83 | def __init__(self, replay_dir, max_size, num_workers, nstep, discount,
84 | fetch_every, save_snapshot, is_adroit=False, return_next_action=False):
85 | self._replay_dir = replay_dir
86 | self._size = 0
87 | self._max_size = max_size
88 | self._num_workers = max(1, num_workers)
89 | self._episode_fns = []
90 | self._episodes = dict()
91 | self._nstep = nstep
92 | self._discount = discount
93 | self._fetch_every = fetch_every
94 | self._samples_since_last_fetch = fetch_every
95 | self._save_snapshot = save_snapshot
96 | self._is_adroit = is_adroit
97 | self._return_next_action = return_next_action
98 |
99 | def set_nstep(self, nstep):
100 | self._nstep = nstep
101 |
102 | def _sample_episode(self):
103 | eps_fn = random.choice(self._episode_fns)
104 | return self._episodes[eps_fn]
105 |
106 | def _store_episode(self, eps_fn):
107 | try:
108 | episode = load_episode(eps_fn)
109 | except:
110 | return False
111 | eps_len = episode_len(episode)
112 | while eps_len + self._size > self._max_size:
113 | early_eps_fn = self._episode_fns.pop(0)
114 | early_eps = self._episodes.pop(early_eps_fn)
115 | self._size -= episode_len(early_eps)
116 | early_eps_fn.unlink(missing_ok=True)
117 | self._episode_fns.append(eps_fn)
118 | self._episode_fns.sort()
119 | self._episodes[eps_fn] = episode
120 | self._size += eps_len
121 |
122 | if not self._save_snapshot:
123 | eps_fn.unlink(missing_ok=True)
124 | return True
125 |
126 | def _try_fetch(self):
127 | if self._samples_since_last_fetch < self._fetch_every:
128 | return
129 | self._samples_since_last_fetch = 0
130 | try:
131 | worker_id = torch.utils.data.get_worker_info().id
132 | except:
133 | worker_id = 0
134 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
135 | fetched_size = 0
136 | for eps_fn in eps_fns:
137 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
138 | if eps_idx % self._num_workers != worker_id:
139 | continue
140 | if eps_fn in self._episodes.keys():
141 | break
142 | if fetched_size + eps_len > self._max_size:
143 | break
144 | fetched_size += eps_len
145 | if not self._store_episode(eps_fn):
146 | break
147 |
148 | def _sample(self):
149 | try:
150 | self._try_fetch()
151 | except:
152 | traceback.print_exc()
153 | self._samples_since_last_fetch += 1
154 | episode = self._sample_episode()
155 | # add +1 for the first dummy transition
156 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
157 | obs = episode['observation'][idx - 1]
158 | action = episode['action'][idx]
159 | next_obs = episode['observation'][idx + self._nstep - 1]
160 | reward = np.zeros_like(episode['reward'][idx])
161 | discount = np.ones_like(episode['discount'][idx])
162 | for i in range(self._nstep):
163 | step_reward = episode['reward'][idx + i]
164 | reward += discount * step_reward
165 | discount *= episode['discount'][idx + i] * self._discount
166 |
167 | if self._return_next_action:
168 | next_action = episode['action'][idx + self._nstep - 1]
169 |
170 | if not self._is_adroit:
171 | if self._return_next_action:
172 | return (obs, action, reward, discount, next_obs, next_action)
173 | else:
174 | return (obs, action, reward, discount, next_obs)
175 | else:
176 | obs_sensor = episode['observation_sensor'][idx - 1]
177 | obs_sensor_next = episode['observation_sensor'][idx + self._nstep - 1]
178 | if self._return_next_action:
179 | return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next, next_action)
180 | else:
181 | return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next)
182 |
183 | def __iter__(self):
184 | while True:
185 | yield self._sample()
186 |
187 |
188 | def _worker_init_fn(worker_id):
189 | seed = np.random.get_state()[1][0] + worker_id
190 | np.random.seed(seed)
191 | random.seed(seed)
192 |
193 |
194 | def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
195 | save_snapshot, nstep, discount, fetch_every=1000, is_adroit=False, return_next_action=False):
196 | max_size_per_worker = max_size // max(1, num_workers)
197 |
198 | iterable = ReplayBuffer(replay_dir,
199 | max_size_per_worker,
200 | num_workers,
201 | nstep,
202 | discount,
203 | fetch_every=fetch_every,
204 | save_snapshot=save_snapshot,
205 | is_adroit=is_adroit,
206 | return_next_action=return_next_action)
207 |
208 | loader = torch.utils.data.DataLoader(iterable,
209 | batch_size=batch_size,
210 | num_workers=num_workers,
211 | pin_memory=True,
212 | worker_init_fn=_worker_init_fn)
213 | return loader
214 |
215 | def reinit_data_loader(data_loader, batch_size, num_workers):
216 | # reinit a data loader with a new batch size
217 | loader = torch.utils.data.DataLoader(data_loader.dataset,
218 | batch_size=batch_size,
219 | num_workers=num_workers,
220 | pin_memory=True,
221 | worker_init_fn=_worker_init_fn)
222 | return loader
--------------------------------------------------------------------------------
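How the two halves connect, as an illustrative sketch (not in the repository): `ReplayBufferStorage` writes one compressed .npz per finished episode into `replay_dir`, and `make_replay_loader` spins up worker processes that lazily pick those files up and emit n-step transitions, where the sampled reward is the discounted n-step sum and the sampled discount is the product of the stored per-step discounts with gamma^n. The spec names and shapes below are assumptions for a DMC-style pixel setup:

# sketch: wire episode storage and the n-step loader together
from pathlib import Path
from dm_env import specs
from replay_buffer import ReplayBufferStorage, make_replay_loader

replay_dir = Path('/tmp/vrl3_replay_demo')    # hypothetical scratch directory
data_specs = (
    specs.BoundedArray((9, 84, 84), 'uint8', 0, 255, 'observation'),
    specs.BoundedArray((6,), 'float32', -1.0, 1.0, 'action'),
    specs.Array((1,), 'float32', 'reward'),
    specs.Array((1,), 'float32', 'discount'),
)
storage = ReplayBufferStorage(data_specs, replay_dir)
# storage.add(time_step) would be called once per environment step;
# on a LAST step the whole episode is flushed to disk as one .npz file.
loader = make_replay_loader(replay_dir, max_size=1_000_000, batch_size=256,
                            num_workers=4, save_snapshot=False, nstep=3,
                            discount=0.99)
# obs, action, reward, discount, next_obs = next(iter(loader))  # once episodes exist
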
/src/rrl_local/rrl_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Rutav Shah, Indian Institute of Technology Kharagpur
2 | # Copyright (c) Facebook, Inc. and its affiliates
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models, transforms
8 | from torchvision.models import resnet34, resnet18
9 | from PIL import Image
10 |
11 | _encoders = {'resnet34' : resnet34, 'resnet18' : resnet18, }
12 | _transforms = {
13 | 'resnet34' :
14 | transforms.Compose([
15 | transforms.Resize(256),
16 | transforms.CenterCrop(224),
17 | transforms.ToTensor(),
18 | transforms.Normalize([0.485, 0.456, 0.406],
19 | [0.229, 0.224, 0.225])
20 | ]),
21 | 'resnet18' :
22 | transforms.Compose([
23 | transforms.Resize(256),
24 | transforms.CenterCrop(224),
25 | transforms.ToTensor(),
26 | transforms.Normalize([0.485, 0.456, 0.406],
27 | [0.229, 0.224, 0.225])
28 | ]),
29 | }
30 |
31 | class Encoder(nn.Module):
32 | def __init__(self, encoder_type):
33 | super(Encoder, self).__init__()
34 | self.encoder_type = encoder_type
35 | if self.encoder_type in _encoders :
36 | self.model = _encoders[encoder_type](pretrained=True)
37 | else :
38 | print("Please enter a valid encoder type")
39 |             raise ValueError(f"invalid encoder type: {encoder_type}")
40 | for param in self.model.parameters():
41 | param.requires_grad = False
42 | if self.encoder_type in _encoders :
43 | num_ftrs = self.model.fc.in_features
44 | self.num_ftrs = num_ftrs
45 | self.model.fc = Identity() # fc layer is replaced with identity
46 |
47 | def forward(self, x):
48 | x = self.model(x)
49 | return x
50 |
51 | # the transform is resize - center crop - normalize (imagenet normalize) No data aug here
52 | def get_transform(self):
53 | return _transforms[self.encoder_type]
54 |
55 | def get_features(self, x):
56 | with torch.no_grad():
57 | z = self.model(x)
58 | return z.cpu().data.numpy() ### Can't store everything in GPU :/
59 |
60 | class IdentityEncoder(nn.Module):
61 | def __init__(self):
62 | super(IdentityEncoder, self).__init__()
63 |
64 | def forward(self, x):
65 | return x
66 |
67 | def get_transform(self):
68 | return transforms.Compose([
69 | transforms.ToTensor(),
70 | ])
71 |
72 | def get_features(self, x):
73 | return x.reshape(-1)
74 |
75 | class Identity(nn.Module):
76 | def __init__(self):
77 | super(Identity, self).__init__()
78 |
79 | def forward(self, x):
80 | return x
81 |
--------------------------------------------------------------------------------
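A sketch of the Encoder contract (illustrative, not in the repository): images go through the resize/center-crop/ImageNet-normalize transform, and `get_features` returns frozen features as a NumPy array (512-dim for resnet18); torchvision downloads the pretrained weights on first use:

# sketch: embed one RGB frame with the frozen ResNet-18 encoder
import numpy as np
from PIL import Image
from rrl_local.rrl_encoder import Encoder

encoder = Encoder('resnet18')                 # pretrained, parameters frozen
frame = Image.fromarray(np.zeros((84, 84, 3), dtype=np.uint8))  # stand-in camera image
x = encoder.get_transform()(frame).unsqueeze(0)   # -> [1, 3, 224, 224] float tensor
z = encoder.get_features(x)                       # -> numpy array of shape (1, 512)
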
/src/rrl_local/rrl_multicam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Rutav Shah, Indian Institute of Technology Kharagpur
2 | # Copyright (c) Facebook, Inc. and its affiliates
3 |
4 | import gym
5 | # from abc import ABC
6 | import numpy as np
7 | from rrl_local.rrl_encoder import Encoder, IdentityEncoder
8 | from PIL import Image
9 | import torch
10 | from collections import deque
11 |
12 | _mj_envs = {'pen-v0', 'hammer-v0', 'door-v0', 'relocate-v0'}
13 |
14 | def make_encoder(encoder, encoder_type, device, is_eval=True) :
15 | if not encoder :
16 | if encoder_type == 'resnet34' or encoder_type == 'resnet18' :
17 | encoder = Encoder(encoder_type)
18 | elif encoder_type == 'identity' :
19 | encoder = IdentityEncoder()
20 | else :
21 | print("Please enter valid encoder_type.")
22 |             raise ValueError(f"invalid encoder_type: {encoder_type}")
23 | if is_eval:
24 | encoder.eval()
25 | encoder.to(device)
26 | return encoder
27 |
28 | class BasicAdroitEnv(gym.Env): # , ABC
29 | def __init__(self, env, cameras, latent_dim=512, hybrid_state=True, channels_first=False,
30 | height=84, width=84, test_image=False, num_repeats=1, num_frames=1, encoder_type=None, device=None):
31 | self._env = env
32 | self.env_id = env.env.unwrapped.spec.id
33 | self.device = device
34 |
35 | self._num_repeats = num_repeats
36 | self._num_frames = num_frames
37 | self._frames = deque([], maxlen=num_frames)
38 |
39 | self.encoder = None
40 | self.transforms = None
41 | self.encoder_type = encoder_type
42 | if encoder_type is not None:
43 | self.encoder = make_encoder(encoder=None, encoder_type=self.encoder_type, device=self.device, is_eval=True)
44 | self.transforms = self.encoder.get_transform()
45 |
46 | if test_image:
47 | print("======================adroit image test mode==============================")
48 | print("======================adroit image test mode==============================")
49 | print("======================adroit image test mode==============================")
50 | print("======================adroit image test mode==============================")
51 | self.test_image = test_image
52 |
53 | self.cameras = cameras
54 | self.latent_dim = latent_dim
55 | self.hybrid_state = hybrid_state
56 | self.channels_first = channels_first
57 | self.height = height
58 | self.width = width
59 | self.action_space = self._env.action_space
60 | self.env_kwargs = {'cameras' : cameras, 'latent_dim' : latent_dim, 'hybrid_state': hybrid_state,
61 | 'channels_first' : channels_first, 'height' : height, 'width' : width}
62 |
63 | shape = [3, self.width, self.height]
64 | self._observation_space = gym.spaces.Box(
65 | low=0, high=255, shape=shape, dtype=np.uint8
66 | )
67 | self.sim = env.env.sim
68 | self._env.spec.observation_dim = latent_dim
69 |
70 | if hybrid_state :
71 | if self.env_id in _mj_envs:
72 | self._env.spec.observation_dim += 24 # Assuming 24 states for adroit hand.
73 |
74 | self.spec = self._env.spec
75 | self.observation_dim = self.spec.observation_dim
76 | self.horizon = self._env.env.spec.max_episode_steps
77 |
78 | def get_obs(self,):
79 | # for our case, let's output the image, and then also the sensor features
80 | if self.env_id in _mj_envs :
81 | env_state = self._env.env.get_env_state()
82 | qp = env_state['qpos']
83 |
84 | if self.env_id == 'pen-v0':
85 | qp = qp[:-6]
86 | elif self.env_id == 'door-v0':
87 | qp = qp[4:-2]
88 | elif self.env_id == 'hammer-v0':
89 | qp = qp[2:-7]
90 | elif self.env_id == 'relocate-v0':
91 | qp = qp[6:-6]
92 |
93 |             imgs = [] # the number of images equals the number of cameras
94 |
95 | if self.encoder is not None:
96 | for cam in self.cameras :
97 | img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0)
98 | # img = env.env.sim.render(width=84, height=84, mode='offscreen')
99 | img = img[::-1, :, : ] # Image given has to be flipped
100 | if self.channels_first :
101 | img = img.transpose((2, 0, 1))
102 | #img = img.astype(np.uint8)
103 | img = Image.fromarray(img)
104 | img = self.transforms(img)
105 | imgs.append(img)
106 |
107 | inp_img = torch.stack(imgs).to(self.device) # [num_cam, C, H, W]
108 | z = self.encoder.get_features(inp_img).reshape(-1)
109 | # assert z.shape[0] == self.latent_dim, "Encoded feature length : {}, Expected : {}".format(z.shape[0], self.latent_dim)
110 | pixels = z
111 | else:
112 | if not self.test_image:
113 | for cam in self.cameras : # for each camera, render once
114 |                 img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO: device id, decide later
115 | # img = img[::-1, :, : ] # Image given has to be flipped
116 | if self.channels_first :
117 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height
118 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?)
119 | #img = img.astype(np.uint8)
120 | # img = Image.fromarray(img) # TODO is this necessary?
121 | imgs.append(img)
122 | else:
123 | img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8)
124 | imgs.append(img)
125 | pixels = np.concatenate(imgs, axis=0)
126 |
127 | # TODO below are what we originally had...
128 | # if not self.test_image:
129 | # for cam in self.cameras : # for each camera, render once
130 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later
131 | # # img = img[::-1, :, : ] # Image given has to be flipped
132 | # if self.channels_first :
133 | # img = img.transpose((2, 0, 1)) # then it's 3 x width x height
134 | # # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?)
135 | # #img = img.astype(np.uint8)
136 | # # img = Image.fromarray(img) # TODO is this necessary?
137 | # imgs.append(img)
138 | # else:
139 | # img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8)
140 | # imgs.append(img)
141 | # pixels = np.concatenate(imgs, axis=0)
142 |
143 | if not self.hybrid_state : # this defaults to True... so RRL uses hybrid state
144 | qp = None
145 |
146 | sensor_info = qp
147 | return pixels, sensor_info
148 |
149 | def get_env_infos(self):
150 | return self._env.get_env_infos()
151 | def set_seed(self, seed):
152 | return self._env.set_seed(seed)
153 |
154 | def get_stacked_pixels(self): #TODO fix it
155 | assert len(self._frames) == self._num_frames
156 | stacked_pixels = np.concatenate(list(self._frames), axis=0)
157 | return stacked_pixels
158 |
159 | def reset(self):
160 | self._env.reset()
161 | pixels, sensor_info = self.get_obs()
162 | for _ in range(self._num_frames):
163 | self._frames.append(pixels)
164 | stacked_pixels = self.get_stacked_pixels()
165 | return stacked_pixels, sensor_info
166 |
167 | def get_obs_for_first_state_but_without_reset(self):
168 | pixels, sensor_info = self.get_obs()
169 | for _ in range(self._num_frames):
170 | self._frames.append(pixels)
171 | stacked_pixels = self.get_stacked_pixels()
172 | return stacked_pixels, sensor_info
173 |
174 | def step(self, action):
175 | reward_sum = 0.0
176 | discount_prod = 1.0 # TODO pen can terminate early
177 | n_goal_achieved = 0
178 | for i_action in range(self._num_repeats):
179 | obs, reward, done, env_info = self._env.step(action)
180 | reward_sum += reward
181 | if env_info['goal_achieved'] == True:
182 | n_goal_achieved += 1
183 | if done:
184 | break
185 | env_info['n_goal_achieved'] = n_goal_achieved
186 | # now get stacked frames
187 | pixels, sensor_info = self.get_obs()
188 | self._frames.append(pixels)
189 | stacked_pixels = self.get_stacked_pixels()
190 | return [stacked_pixels, sensor_info], reward_sum, done, env_info
191 |
192 | def set_env_state(self, state):
193 | return self._env.set_env_state(state)
194 | def get_env_state(self):
195 |         return self._env.get_env_state()
196 |
197 | def evaluate_policy(self, policy,
198 | num_episodes=5,
199 | horizon=None,
200 | gamma=1,
201 | visual=False,
202 | percentile=[],
203 | get_full_dist=False,
204 | mean_action=False,
205 | init_env_state=None,
206 | terminate_at_done=True,
207 | seed=123):
208 | # TODO this needs to be rewritten
209 |
210 | self.set_seed(seed)
211 | horizon = self.horizon if horizon is None else horizon
212 | mean_eval, std, min_eval, max_eval = 0.0, 0.0, -1e8, -1e8
213 | ep_returns = np.zeros(num_episodes)
214 | self.encoder.eval()
215 |
216 | for ep in range(num_episodes):
217 | o = self.reset()
218 | if init_env_state is not None:
219 | self.set_env_state(init_env_state)
220 | t, done = 0, False
221 | while t < horizon and (done == False or terminate_at_done == False):
222 | self.render() if visual is True else None
223 |                 o = self.get_obs()  # get_obs takes no arguments
224 | a = policy.get_action(o)[1]['evaluation'] if mean_action is True else policy.get_action(o)[0]
225 | o, r, done, _ = self.step(a)
226 | ep_returns[ep] += (gamma ** t) * r
227 | t += 1
228 |
229 | mean_eval, std = np.mean(ep_returns), np.std(ep_returns)
230 | min_eval, max_eval = np.amin(ep_returns), np.amax(ep_returns)
231 | base_stats = [mean_eval, std, min_eval, max_eval]
232 |
233 | percentile_stats = []
234 | for p in percentile:
235 | percentile_stats.append(np.percentile(ep_returns, p))
236 |
237 | full_dist = ep_returns if get_full_dist is True else None
238 |
239 | return [base_stats, percentile_stats, full_dist]
240 |
241 | def get_pixels_with_width_height(self, w, h):
242 |         imgs = [] # the number of images equals the number of cameras
243 |
244 | for cam in self.cameras : # for each camera, render once
245 |             img = self._env.env.sim.render(width=w, height=h, mode='offscreen', camera_name=cam, device_id=0) # TODO: device id, decide later
246 | # img = img[::-1, :, : ] # Image given has to be flipped
247 | if self.channels_first :
248 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height
249 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?)
250 | #img = img.astype(np.uint8)
251 | # img = Image.fromarray(img) # TODO is this necessary?
252 | imgs.append(img)
253 |
254 | pixels = np.concatenate(imgs, axis=0)
255 | return pixels
256 |
257 |
258 | class BasicFrankaEnv(gym.Env):
259 | def __init__(self, env, cameras, latent_dim=512, hybrid_state=True, channels_first=False,
260 | height=84, width=84, test_image=False, num_repeats=1, num_frames=1, encoder_type=None, device=None):
261 | # the parameter env is basically the kitchen env now
262 | self._env = env
263 | self.env_id = env.env.unwrapped.spec.id
264 | self.device = device
265 |
266 | self._num_repeats = num_repeats
267 | self._num_frames = num_frames
268 | self._frames = deque([], maxlen=num_frames)
269 |
270 | self.encoder = None
271 | self.transforms = None
272 | self.encoder_type = encoder_type
273 | if encoder_type is not None:
274 | self.encoder = make_encoder(encoder=None, encoder_type=self.encoder_type, device=self.device, is_eval=True)
275 | self.transforms = self.encoder.get_transform()
276 |
277 | if test_image:
278 | print("======================adroit image test mode==============================")
279 | print("======================adroit image test mode==============================")
280 | print("======================adroit image test mode==============================")
281 | print("======================adroit image test mode==============================")
282 | self.test_image = test_image
283 |
284 | self.cameras = cameras
285 | self.latent_dim = latent_dim
286 | self.hybrid_state = hybrid_state
287 | self.channels_first = channels_first
288 | self.height = height
289 | self.width = width
290 | self.action_space = self._env.action_space
291 | self.env_kwargs = {'cameras' : cameras, 'latent_dim' : latent_dim, 'hybrid_state': hybrid_state,
292 | 'channels_first' : channels_first, 'height' : height, 'width' : width}
293 |
294 | shape = [3, self.width, self.height]
295 | self._observation_space = gym.spaces.Box(
296 | low=0, high=255, shape=shape, dtype=np.uint8
297 | )
298 | self.sim = env.env.sim
299 | self._env.spec.observation_dim = latent_dim
300 | self._env.spec.action_dim = 9 # TODO magic number
301 | self._env.spec.horizon = self._env.env.spec.max_episode_steps
302 | # print("==============")
303 | # print(dir(self._env))
304 | # print("high: ", self._env.action_space)
305 | # quit()
306 |
307 | if hybrid_state :
308 | if self.env_id in _mj_envs:
309 | self._env.spec.observation_dim += 24 # Assuming 24 states for adroit hand.
310 |
311 | self.spec = self._env.spec
312 | self.observation_dim = self.spec.observation_dim
313 | self.horizon = self._env.env.spec.max_episode_steps
314 |
315 | def get_obs(self,):
316 | # for our case, let's output the image, and then also the sensor features
317 |         imgs = [] # the number of images equals the number of cameras
318 |
319 | if self.encoder is not None:
320 | for cam in self.cameras :
321 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0)
322 | img = self._env.env.sim.render(width=84, height=84)
323 | # img = env.env.sim.render(width=84, height=84, mode='offscreen')
324 | img = img[::-1, :, : ] # Image given has to be flipped
325 | if self.channels_first :
326 | img = img.transpose((2, 0, 1))
327 | #img = img.astype(np.uint8)
328 | img = Image.fromarray(img)
329 | img = self.transforms(img)
330 | imgs.append(img)
331 |
332 | inp_img = torch.stack(imgs).to(self.device) # [num_cam, C, H, W]
333 | z = self.encoder.get_features(inp_img).reshape(-1)
334 | # assert z.shape[0] == self.latent_dim, "Encoded feature length : {}, Expected : {}".format(z.shape[0], self.latent_dim)
335 | pixels = z
336 | else:
337 | if not self.test_image:
338 | for cam in self.cameras : # for each camera, render once
339 | img = self._env.env.sim.render(width=84, height=84)
340 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later
341 | # img = img[::-1, :, : ] # Image given has to be flipped
342 | if self.channels_first :
343 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height
344 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?)
345 | #img = img.astype(np.uint8)
346 | # img = Image.fromarray(img) # TODO is this necessary?
347 | imgs.append(img)
348 | else:
349 | img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8)
350 | imgs.append(img)
351 | pixels = np.concatenate(imgs, axis=0)
352 |
353 | # TODO below are what we originally had...
354 | # if not self.test_image:
355 | # for cam in self.cameras : # for each camera, render once
356 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later
357 | # # img = img[::-1, :, : ] # Image given has to be flipped
358 | # if self.channels_first :
359 | # img = img.transpose((2, 0, 1)) # then it's 3 x width x height
360 | # # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?)
361 | # #img = img.astype(np.uint8)
362 | # # img = Image.fromarray(img) # TODO is this necessary?
363 | # imgs.append(img)
364 | # else:
365 | # img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8)
366 | # imgs.append(img)
367 | # pixels = np.concatenate(imgs, axis=0)
368 |
369 | if not self.hybrid_state : # this defaults to True... so RRL uses hybrid state
370 | qp = None
371 |
372 | sensor_info = qp
373 | return pixels, sensor_info
374 |
375 | def get_env_infos(self):
376 | return self._env.get_env_infos()
377 | def set_seed(self, seed):
378 | return self._env.set_seed(seed)
379 |
380 | def get_stacked_pixels(self): #TODO fix it
381 | assert len(self._frames) == self._num_frames
382 | stacked_pixels = np.concatenate(list(self._frames), axis=0)
383 | return stacked_pixels
384 |
385 | def reset(self):
386 | self._env.reset()
387 | pixels, sensor_info = self.get_obs()
388 | for _ in range(self._num_frames):
389 | self._frames.append(pixels)
390 | stacked_pixels = self.get_stacked_pixels()
391 | return stacked_pixels, sensor_info
392 |
393 | def get_obs_for_first_state_but_without_reset(self):
394 | pixels, sensor_info = self.get_obs()
395 | for _ in range(self._num_frames):
396 | self._frames.append(pixels)
397 | stacked_pixels = self.get_stacked_pixels()
398 | return stacked_pixels, sensor_info
399 |
400 | def step(self, action):
401 | reward_sum = 0.0
402 | discount_prod = 1.0 # TODO pen can terminate early
403 | n_goal_achieved = 0
404 | for i_action in range(self._num_repeats):
405 | obs, reward, done, env_info = self._env.step(action)
406 | reward_sum += reward
407 | if env_info['goal_achieved'] == True:
408 | n_goal_achieved += 1
409 | if done:
410 | break
411 | env_info['n_goal_achieved'] = n_goal_achieved
412 | # now get stacked frames
413 | pixels, sensor_info = self.get_obs()
414 | self._frames.append(pixels)
415 | stacked_pixels = self.get_stacked_pixels()
416 | return [stacked_pixels, sensor_info], reward_sum, done, env_info
417 |
418 | def set_env_state(self, state):
419 | return self._env.set_env_state(state)
420 | def get_env_state(self):
421 |         return self._env.get_env_state()
422 |
423 | def evaluate_policy(self, policy,
424 | num_episodes=5,
425 | horizon=None,
426 | gamma=1,
427 | visual=False,
428 | percentile=[],
429 | get_full_dist=False,
430 | mean_action=False,
431 | init_env_state=None,
432 | terminate_at_done=True,
433 | seed=123):
434 | # TODO this needs to be rewritten
435 |
436 | self.set_seed(seed)
437 | horizon = self.horizon if horizon is None else horizon
438 | mean_eval, std, min_eval, max_eval = 0.0, 0.0, -1e8, -1e8
439 | ep_returns = np.zeros(num_episodes)
440 | self.encoder.eval()
441 |
442 | for ep in range(num_episodes):
443 | o = self.reset()
444 | if init_env_state is not None:
445 | self.set_env_state(init_env_state)
446 | t, done = 0, False
447 | while t < horizon and (done == False or terminate_at_done == False):
448 | self.render() if visual is True else None
449 |                 o = self.get_obs()  # get_obs takes no arguments
450 | a = policy.get_action(o)[1]['evaluation'] if mean_action is True else policy.get_action(o)[0]
451 | o, r, done, _ = self.step(a)
452 | ep_returns[ep] += (gamma ** t) * r
453 | t += 1
454 |
455 | mean_eval, std = np.mean(ep_returns), np.std(ep_returns)
456 | min_eval, max_eval = np.amin(ep_returns), np.amax(ep_returns)
457 | base_stats = [mean_eval, std, min_eval, max_eval]
458 |
459 | percentile_stats = []
460 | for p in percentile:
461 | percentile_stats.append(np.percentile(ep_returns, p))
462 |
463 | full_dist = ep_returns if get_full_dist is True else None
464 |
465 | return [base_stats, percentile_stats, full_dist]
466 |
467 | def get_pixels_with_width_height(self, w, h):
468 |         imgs = [] # the number of images equals the number of cameras
469 |
470 | for cam in self.cameras : # for each camera, render once
471 |             img = self._env.env.sim.render(width=w, height=h, mode='offscreen', camera_name=cam, device_id=0) # TODO: device id, decide later
472 | # img = img[::-1, :, : ] # Image given has to be flipped
473 | if self.channels_first :
474 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height
475 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?)
476 | #img = img.astype(np.uint8)
477 | # img = Image.fromarray(img) # TODO is this necessary?
478 | imgs.append(img)
479 |
480 | pixels = np.concatenate(imgs, axis=0)
481 | return pixels
482 |
--------------------------------------------------------------------------------
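To make the shapes concrete, an illustrative sketch (not in the repository) of BasicAdroitEnv with one camera and three stacked frames; it assumes mjrl and mj_envs are installed so that GymEnv('hammer-v0') resolves:

# sketch: one reset/step through BasicAdroitEnv with a single camera
from mjrl.utils.gym_env import GymEnv
from rrl_local.rrl_multicam import BasicAdroitEnv

env = BasicAdroitEnv(GymEnv('hammer-v0'), cameras=['top'],
                     latent_dim=84 * 84 * 3, hybrid_state=True,
                     channels_first=True, num_repeats=2, num_frames=3)
pixels, sensor = env.reset()          # (9, 84, 84) uint8 stack, 24-dim qpos slice
(pixels, sensor), reward, done, info = env.step(env.action_space.sample())
print(info['n_goal_achieved'])        # goals achieved across the 2 repeated steps
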
/src/rrl_local/rrl_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Rutav Shah, Indian Institute of Technology Kharagpur
2 | # Copyright (c) Facebook, Inc. and its affiliates
3 |
4 | from mjrl.utils.gym_env import GymEnv
5 | from rrl_local.rrl_multicam import BasicAdroitEnv
6 | import os
7 |
8 | def make_basic_env(env, cam_list=[], from_pixels=False, hybrid_state=None, test_image=False, channels_first=False,
9 | num_repeats=1, num_frames=1):
10 | e = GymEnv(env)
11 | env_kwargs = None
12 | if from_pixels : # TODO here they first check if it's from pixels... if pixel and not resnet then use 84x84??
13 | height = 84
14 | width = 84
15 | latent_dim = height*width*len(cam_list)*3
16 | # RRL class instance is environment wrapper...
17 | e = BasicAdroitEnv(e, cameras=cam_list,
18 | height=height, width=width, latent_dim=latent_dim, hybrid_state=hybrid_state,
19 | test_image=test_image, channels_first=channels_first, num_repeats=num_repeats, num_frames=num_frames)
20 | env_kwargs = {'rrl_kwargs' : e.env_kwargs}
21 | # if not from pixels... then it's simpler
22 | return e, env_kwargs
23 |
24 | def make_env(env, cam_list=[], from_pixels=False, encoder_type=None, hybrid_state=None,) :
25 | # TODO why is encoder type put into the environment?
26 | e = GymEnv(env)
27 | env_kwargs = None
28 | if from_pixels : # TODO here they first check if it's from pixels... if pixel and not resnet then use 84x84??
29 | height = 84
30 | width = 84
31 | latent_dim = height*width*len(cam_list)*3
32 |
33 | if encoder_type and encoder_type == 'resnet34':
34 | assert from_pixels==True
35 | height = 256
36 | width = 256
37 | latent_dim = 512*len(cam_list) #TODO each camera provides an image?
38 | if from_pixels:
39 | # RRL class instance is environment wrapper...
40 | e = BasicAdroitEnv(e, cameras=cam_list, encoder_type=encoder_type,
41 | height=height, width=width, latent_dim=latent_dim, hybrid_state=hybrid_state)
42 | env_kwargs = {'rrl_kwargs' : e.env_kwargs}
43 | return e, env_kwargs
44 |
45 | def make_dir(dir_path):
46 | if not os.path.exists(dir_path):
47 | os.makedirs(dir_path)
48 |
49 |
50 | def preprocess_args(args):
51 | job_data = {}
52 | job_data['seed'] = args.seed
53 | job_data['env'] = args.env
54 | job_data['output'] = args.output
55 | job_data['from_pixels'] = args.from_pixels
56 | job_data['hybrid_state'] = args.hybrid_state
57 | job_data['stack_frames'] = args.stack_frames
58 | job_data['encoder_type'] = args.encoder_type
59 | job_data['cam1'] = args.cam1
60 | job_data['cam2'] = args.cam2
61 | job_data['cam3'] = args.cam3
62 | job_data['algorithm'] = args.algorithm
63 | job_data['num_cpu'] = args.num_cpu
64 | job_data['save_freq'] = args.save_freq
65 | job_data['eval_rollouts'] = args.eval_rollouts
66 | job_data['demo_file'] = args.demo_file
67 | job_data['bc_batch_size'] = args.bc_batch_size
68 | job_data['bc_epochs'] = args.bc_epochs
69 | job_data['bc_learn_rate'] = args.bc_learn_rate
70 | #job_data['policy_size'] = args.policy_size
71 | job_data['policy_size'] = tuple(map(int, args.policy_size.split(', ')))
72 | job_data['vf_batch_size'] = args.vf_batch_size
73 | job_data['vf_epochs'] = args.vf_epochs
74 | job_data['vf_learn_rate'] = args.vf_learn_rate
75 | job_data['rl_step_size'] = args.rl_step_size
76 | job_data['rl_gamma'] = args.rl_gamma
77 | job_data['rl_gae'] = args.rl_gae
78 | job_data['rl_num_traj'] = args.rl_num_traj
79 | job_data['rl_num_iter'] = args.rl_num_iter
80 | job_data['lam_0'] = args.lam_0
81 | job_data['lam_1'] = args.lam_1
82 | print(job_data)
83 | return job_data
84 |
85 |
86 |
--------------------------------------------------------------------------------
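A hedged usage sketch for make_basic_env above: door-v0 is one of the Adroit tasks
used in this repo, the camera name is a placeholder, and mj_envs/mjrl must be
installed for this to run.

    from rrl_local.rrl_utils import make_basic_env

    # wrap door-v0 as a pixel-based env with one camera,
    # 2x action repeat and a 3-frame stack
    env, env_kwargs = make_basic_env('door-v0', cam_list=['top'], from_pixels=True,
                                     hybrid_state=True, test_image=False,
                                     channels_first=True, num_repeats=2, num_frames=3)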
/src/stage1_models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn as nn
7 |
8 | """
9 | most of the code here is modified from torchvision.models.resnet
10 | """
11 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
16 | """3x3 convolution with padding"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=dilation, groups=groups, bias=False, dilation=dilation)
19 |
20 | def conv1x1(in_planes, out_planes, stride=1):
21 | """1x1 convolution"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
28 | base_width=64, dilation=1, norm_layer=None):
29 | super(BasicBlock, self).__init__()
30 | if norm_layer is None:
31 | norm_layer = nn.BatchNorm2d
32 | if groups != 1 or base_width != 64:
33 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
34 | if dilation > 1:
35 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
36 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = norm_layer(planes)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.conv2 = conv3x3(planes, planes)
41 | self.bn2 = norm_layer(planes)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | identity = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 |
55 | if self.downsample is not None:
56 | identity = self.downsample(x)
57 |
58 | out += identity
59 | out = self.relu(out)
60 |
61 | return out
62 |
63 | class OneLayerBlock(nn.Module):
64 | # similar to BasicBlock, but shallower
65 | expansion = 1
66 |
67 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
68 | base_width=64, dilation=1, norm_layer=None):
69 | super(OneLayerBlock, self).__init__()
70 | if norm_layer is None:
71 | norm_layer = nn.BatchNorm2d
72 | if groups != 1 or base_width != 64:
73 | raise ValueError('OneLayerBlock only supports groups=1 and base_width=64')
74 | if dilation > 1:
75 | raise NotImplementedError("Dilation > 1 not supported in OneLayerBlock")
76 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
77 | self.conv1 = conv3x3(inplanes, planes, stride)
78 | self.bn1 = norm_layer(planes)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.stride = stride
81 |
82 | def forward(self, x):
83 | out = self.conv1(x)
84 | out = self.bn1(out)
85 | out = self.relu(out)
86 | return out
87 |
88 | class Bottleneck(nn.Module):
89 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
90 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
91 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
92 | # This variant is also known as ResNet V1.5 and improves accuracy according to
93 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
94 |
95 | expansion = 4
96 |
97 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
98 | base_width=64, dilation=1, norm_layer=None):
99 | super(Bottleneck, self).__init__()
100 | if norm_layer is None:
101 | norm_layer = nn.BatchNorm2d
102 | width = int(planes * (base_width / 64.)) * groups
103 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
104 | self.conv1 = conv1x1(inplanes, width)
105 | self.bn1 = norm_layer(width)
106 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
107 | self.bn2 = norm_layer(width)
108 | self.conv3 = conv1x1(width, planes * self.expansion)
109 | self.bn3 = norm_layer(planes * self.expansion)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.downsample = downsample
112 | self.stride = stride
113 |
114 | def forward(self, x):
115 | identity = x
116 |
117 | out = self.conv1(x)
118 | out = self.bn1(out)
119 | out = self.relu(out)
120 |
121 | out = self.conv2(out)
122 | out = self.bn2(out)
123 | out = self.relu(out)
124 |
125 | out = self.conv3(out)
126 | out = self.bn3(out)
127 |
128 | if self.downsample is not None:
129 | identity = self.downsample(x)
130 |
131 | out += identity
132 | out = self.relu(out)
133 |
134 | return out
135 |
136 | def drq_weight_init(m):
137 | # weight init scheme used in DrQv2
138 | if isinstance(m, nn.Linear):
139 | nn.init.orthogonal_(m.weight.data)
140 | if hasattr(m.bias, 'data'):
141 | m.bias.data.fill_(0.0)
142 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
143 | gain = nn.init.calculate_gain('relu')
144 | nn.init.orthogonal_(m.weight.data, gain)
145 | if hasattr(m.bias, 'data'):
146 | m.bias.data.fill_(0.0)
147 |
148 | class Stage3ShallowEncoder(nn.Module):
149 | """
150 | this is the encoder architecture used in DrQv2
151 | """
152 | def __init__(self, obs_shape, n_channel):
153 | super().__init__()
154 | assert len(obs_shape) == 3
155 | self.repr_dim = n_channel * 35 * 35
156 | self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
157 | self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
158 | self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
159 | self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
160 | self.relu = nn.ReLU(inplace=True)
161 | self.apply(drq_weight_init)
162 |
163 | def _forward_impl(self, x):
164 | x = self.relu(self.conv1(x))
165 | x = self.relu(self.conv2(x))
166 | x = self.relu(self.conv3(x))
167 | x = self.relu(self.conv4(x))
168 | return x
169 |
170 | def forward(self, obs):
171 | o = obs
172 | h = self._forward_impl(o)
173 | h = h.view(h.shape[0], -1)
174 | return h
175 |
176 | class ResNet84(nn.Module):
177 | """
178 | default stage 1 encoder used by VRL3; modified from the standard PyTorch ResNet class,
179 | but more lightweight and much faster on 84x84 inputs.
180 | use "layers" to specify how deep the network is,
181 | and "start_num_channel" to control how wide it is.
182 | """
183 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
184 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
185 | norm_layer=None, start_num_channel=32):
186 | super(ResNet84, self).__init__()
187 | if norm_layer is None:
188 | norm_layer = nn.BatchNorm2d
189 | self._norm_layer = norm_layer
190 |
191 | self.start_num_channel = start_num_channel
192 | self.inplanes = start_num_channel
193 | self.dilation = 1
194 | if replace_stride_with_dilation is None:
195 | # each element in the tuple indicates if we should replace
196 | # the 2x2 stride with a dilated convolution instead
197 | replace_stride_with_dilation = [False, False, False]
198 | if len(replace_stride_with_dilation) != 3:
199 | raise ValueError("replace_stride_with_dilation should be None "
200 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
201 | self.groups = groups
202 | self.base_width = width_per_group
203 | self.conv1 = nn.Conv2d(3, self.inplanes, 3, stride=2)
204 | self.bn1 = norm_layer(self.inplanes)
205 | self.relu = nn.ReLU(inplace=True)
206 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
207 |
208 | self.layer1 = self._make_layer(block, start_num_channel, layers[0])
209 | self.layer2 = self._make_layer(block, start_num_channel * 2, layers[1], stride=2,
210 | dilate=replace_stride_with_dilation[0])
211 | self.layer3 = self._make_layer(block, start_num_channel * 4, layers[2], stride=2,
212 | dilate=replace_stride_with_dilation[1])
213 | self.layer4 = self._make_layer(block, start_num_channel * 8, layers[3], stride=2,
214 | dilate=replace_stride_with_dilation[2])
215 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
216 | self.fc = nn.Linear(start_num_channel * 8 * block.expansion, num_classes)
217 |
218 | for m in self.modules():
219 | if isinstance(m, nn.Conv2d):
220 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
221 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
222 | nn.init.constant_(m.weight, 1)
223 | nn.init.constant_(m.bias, 0)
224 |
225 | # Zero-initialize the last BN in each residual branch,
226 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
227 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
228 | if zero_init_residual:
229 | for m in self.modules():
230 | if isinstance(m, Bottleneck):
231 | nn.init.constant_(m.bn3.weight, 0)
232 | elif isinstance(m, BasicBlock):
233 | nn.init.constant_(m.bn2.weight, 0)
234 |
235 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
236 | # vrl3: if blocks == 0, fall back to OneLayerBlock to allow a smaller network
237 | if blocks == 0:
238 | block = OneLayerBlock
239 |
240 | norm_layer = self._norm_layer
241 | downsample = None
242 | previous_dilation = self.dilation
243 | if dilate:
244 | self.dilation *= stride
245 | stride = 1
246 | if stride != 1 or self.inplanes != planes * block.expansion:
247 | downsample = nn.Sequential(
248 | conv1x1(self.inplanes, planes * block.expansion, stride),
249 | norm_layer(planes * block.expansion),
250 | )
251 |
252 | layers = []
253 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
254 | self.base_width, previous_dilation, norm_layer))
255 | self.inplanes = planes * block.expansion
256 | for _ in range(1, blocks):
257 | layers.append(block(self.inplanes, planes, groups=self.groups,
258 | base_width=self.base_width, dilation=self.dilation,
259 | norm_layer=norm_layer))
260 |
261 | return nn.Sequential(*layers)
262 |
263 | def _forward_impl(self, x):
264 | x = self.conv1(x)
265 | x = self.bn1(x)
266 | x = self.relu(x)
267 |
268 | x = self.layer1(x)
269 | x = self.layer2(x)
270 | x = self.layer3(x)
271 | x = self.layer4(x)
272 | x = self.avgpool(x)
273 | x = torch.flatten(x, 1)
274 | x = self.fc(x)
275 |
276 | return x
277 |
278 | def get_feature_size(self):
279 | assert self.start_num_channel % 32 == 0
280 | multiplier = self.start_num_channel // 32
281 | size = 256 * multiplier
282 | return size
283 |
284 | def forward(self, x):
285 | return self._forward_impl(x)
286 |
287 | def get_features(self, x):
288 | x = self.conv1(x)
289 | # print("0", x.shape) # 32 x 41 x 41 = 53792
290 | x = self.bn1(x)
291 | x = self.relu(x)
292 |
293 | x = self.layer1(x)
294 | # print("1", x.shape) # 32 x 41 x 41= 53792
295 |
296 | x = self.layer2(x)
297 | # print("2", x.shape) # 64 x 21 x 21 = 28224
298 |
299 | x = self.layer3(x)
300 | # print("3", x.shape) # 128 x 11 x 11 = 15488
301 |
302 | x = self.layer4(x)
303 | # print("4", x.shape) # 256 x 6 x 6 = 9216
304 |
305 | x = self.avgpool(x)
306 | # print("pool", x.shape) # 256 x 1 x 1
307 |
308 | final_out = torch.flatten(x, 1)
309 | # print("flatten", x.shape) # 256
310 | return final_out
311 |
312 | class Identity(nn.Module):
313 | def __init__(self, input_placeholder=None):
314 | super(Identity, self).__init__()
315 |
316 | def forward(self, x):
317 | return x
318 |
319 |
--------------------------------------------------------------------------------
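A small sketch, mirroring how train_stage1.py builds its models, showing that the
default resnet10_32channel encoder produces a 256-dim feature on an 84x84 input:

    import torch
    from stage1_models import BasicBlock, ResNet84

    # [1, 1, 1, 1] gives a resnet10; start_num_channel=32 keeps it narrow
    model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=32)
    x = torch.rand(1, 3, 84, 84)
    feats = model.get_features(x)  # conv trunk + avgpool, skipping the fc head
    assert feats.shape == (1, model.get_feature_size())  # (1, 256)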
/src/testing/computation_time_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | # make sure mujoco and nvidia will be found
6 | os.environ['LD_LIBRARY_PATH'] = os.environ.get('LD_LIBRARY_PATH', '') + \
7 | ':/workspace/.mujoco/mujoco210/bin:/usr/local/nvidia/lib:/usr/lib/nvidia'
8 | os.environ['MUJOCO_PY_MUJOCO_PATH'] = '/workspace/.mujoco/mujoco210/'
9 | os.environ['MUJOCO_GL'] = 'egl'
10 | # set to glfw if trying to render locally with a monitor
11 | # os.environ['MUJOCO_GL'] = 'glfw'
12 |
13 | from mjrl.utils.gym_env import GymEnv
14 | import mujoco_py
15 | import mj_envs
16 | import time
17 |
18 | e = GymEnv('hammer-v0')
19 |
20 | # right after initialization, the first few renders are a bit slower,
21 | # so we first run 1000 renders without timing them
22 | for i in range(1000):
23 | img = e.env.sim.render(width=84, height=84, mode='offscreen', camera_name='top')
24 |
25 | # now measure how long 1000 renders take
26 | st = time.time()
27 | for i in range(1000):
28 | img = e.env.sim.render(width=84, height=84, mode='offscreen', camera_name='top')
29 | et = time.time()
30 |
31 | # when mujoco renders on the GPU (i.e. when things work correctly), this is fast:
32 | # it should take < 0.5 seconds
33 | # if your GPU is not working correctly, it might be 10x slower
34 | # if it is slow, something is wrong: either the program can't see your GPU,
35 | # or your GPU is visible but this (older) version of mujoco failed to use it
36 | time_used = et-st
37 | if time_used < 0.5:
38 | print("time used: %.3f seconds, looks correct!" % time_used)
39 | else:
40 | print("time used: %.3f seconds." % time_used)
41 | print("WARNING!!!!! SLOWER THAN EXPECTED, YOUR MUJOCO RENDERING MIGHT NOT WORK CORRECTLY!")
42 |
43 | print("Depends on hardware but typically should be < 0.5 seconds.")
44 |
45 |
46 |
--------------------------------------------------------------------------------
/src/testing/zzz.py:
--------------------------------------------------------------------------------
1 | txt = "For only {} {:.0f} {:.3f} dollars!"
2 | print(txt.format(1, 49.5555, 22))
--------------------------------------------------------------------------------
/src/train_adroit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os, psutil
5 | # make sure mujoco and nvidia will be found
6 | os.environ['LD_LIBRARY_PATH'] = os.environ.get('LD_LIBRARY_PATH', '') + \
7 | ':/workspace/.mujoco/mujoco210/bin:/usr/local/nvidia/lib:/usr/lib/nvidia'
8 | os.environ['MUJOCO_PY_MUJOCO_PATH'] = '/workspace/.mujoco/mujoco210/'
9 | import numpy as np
10 | import shutil
11 | import warnings
12 | warnings.filterwarnings('ignore', category=DeprecationWarning)
13 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
14 | import platform
15 | os.environ['MUJOCO_GL'] = 'egl'
16 | # set to glfw if trying to render locally with a monitor
17 | # os.environ['MUJOCO_GL'] = 'glfw'
18 | os.environ['EGL_DEVICE_ID'] = '0'
19 | from distutils.dir_util import copy_tree
20 | import torchvision.transforms as transforms
21 | from pathlib import Path
22 | import cv2
23 | import imageio
24 |
25 | import hydra
26 | import torch
27 | from dm_env import StepType, TimeStep, specs
28 |
29 | import utils
30 | from logger import Logger
31 | from replay_buffer import ReplayBufferStorage, make_replay_loader
32 | from video import TrainVideoRecorder, VideoRecorder
33 | import joblib
34 | import pickle
35 | import time
36 |
37 | torch.backends.cudnn.benchmark = True
38 |
39 | # TODO this part can be done during workspace setup
40 | ENV_TYPE = 'adroit'
41 | if ENV_TYPE == 'adroit':
42 | import mj_envs
43 | from mjrl.utils.gym_env import GymEnv
44 | from rrl_local.rrl_utils import make_basic_env, make_dir
45 | from adroit import AdroitEnv
46 | else:
47 | import dmc
48 | IS_ADROIT = (ENV_TYPE == 'adroit')
49 |
50 | def get_ram_info():
51 | # get info on RAM usage
52 | d = dict(psutil.virtual_memory()._asdict())
53 | for key in d:
54 | if key != "percent": # convert to GB
55 | d[key] = int(d[key] / 1024**3 * 100) / 100
56 | return d
57 |
58 | def make_agent(obs_spec, action_spec, cfg):
59 | cfg.obs_shape = obs_spec.shape
60 | cfg.action_shape = action_spec.shape
61 | return hydra.utils.instantiate(cfg)
62 |
63 | def print_stage2_time_est(time_used, curr_n_update, total_n_update):
64 | time_per_update = time_used / curr_n_update
65 | est_total_time = time_per_update * total_n_update
66 | est_time_remaining = est_total_time - time_used
67 | print("Stage 2 [{:.2f}%]. Frames:[{:.0f}/{:.0f}]K. Time:[{:.2f}/{:.2f}]hrs. Overall FPS: {}.".format(
68 | curr_n_update / total_n_update * 100, curr_n_update/1000, total_n_update/1000,
69 | time_used / 3600, est_total_time / 3600, int(curr_n_update / time_used)))
70 |
71 | def print_stage3_time_est(time_used, curr_n_frames, total_n_frames):
72 | time_per_update = time_used / curr_n_frames
73 | est_total_time = time_per_update * total_n_frames
74 | est_time_remaining = est_total_time - time_used
75 | print("Stage 3 [{:.2f}%]. Frames:[{:.0f}/{:.0f}]K. Time:[{:.2f}/{:.2f}]hrs. Overall FPS: {}.".format(
76 | curr_n_frames / total_n_frames * 100, curr_n_frames/1000, total_n_frames/1000,
77 | time_used / 3600, est_total_time / 3600, int(curr_n_frames / time_used)))
78 |
79 | class Workspace:
80 | def __init__(self, cfg):
81 | self.work_dir = Path.cwd()
82 | print("\n=== Training log stored to: ===")
83 | print(f'workspace: {self.work_dir}')
84 | self.direct_folder_name = os.path.basename(self.work_dir)
85 |
86 | self.cfg = cfg
87 | utils.set_seed_everywhere(cfg.seed)
88 | self.device = torch.device(cfg.device)
89 | self.replay_buffer_fetch_every = 1000
90 | if self.cfg.debug > 0: # if debug mode, then change hyperparameters for quick testing
91 | self.set_debug_hyperparameters()
92 | self.setup()
93 |
94 | self.agent = make_agent(self.train_env.observation_spec(),
95 | self.train_env.action_spec(),
96 | self.cfg.agent)
97 | self.timer = utils.Timer()
98 | self._global_step = 0
99 | self._global_episode = 0
100 |
101 | def set_debug_hyperparameters(self):
102 | self.cfg.num_seed_frames=1000 if self.cfg.num_seed_frames > 1000 else self.cfg.num_seed_frames
103 | self.cfg.agent.num_expl_steps=500 if self.cfg.agent.num_expl_steps > 1000 else self.cfg.agent.num_expl_steps
104 | if self.cfg.replay_buffer_num_workers > 1:
105 | self.cfg.replay_buffer_num_workers = 1
106 | self.cfg.num_eval_episodes = 1
107 | self.cfg.replay_buffer_size = 30000
108 | self.cfg.batch_size = 8
109 | self.cfg.feature_dim = 8
110 | self.cfg.num_train_frames = 5050
111 | self.replay_buffer_fetch_every = 30
112 | self.cfg.stage2_n_update = 100
113 | self.cfg.num_demo = 3
114 | self.cfg.eval_every_frames = 3000
115 | self.cfg.agent.hidden_dim = 8
116 | self.cfg.agent.num_expl_steps = 500
117 | self.cfg.stage2_eval_every_frames = 50
118 |
119 | def setup(self):
120 | warnings.filterwarnings('ignore', category=DeprecationWarning)
121 |
122 | if self.cfg.save_models:
123 | assert self.cfg.action_repeat % 2 == 0
124 |
125 | # create logger
126 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb)
127 | env_name = self.cfg.task_name
128 | env_type = 'adroit' if env_name in ('hammer-v0','door-v0','pen-v0','relocate-v0') else 'dmc'
129 | # assert env_name in ('hammer-v0','door-v0','pen-v0','relocate-v0',)
130 |
131 | if self.cfg.agent.encoder_lr_scale == 'auto':
132 | if env_name == 'relocate-v0':
133 | self.cfg.agent.encoder_lr_scale = 0.01
134 | else:
135 | self.cfg.agent.encoder_lr_scale = 1
136 |
137 | self.env_feature_type = self.cfg.env_feature_type
138 | if env_type == 'adroit':
139 | # reward rescale can either be added in the env or in the agent code when reward is used
140 | self.train_env = AdroitEnv(env_name, test_image=False, num_repeats=self.cfg.action_repeat,
141 | num_frames=self.cfg.frame_stack, env_feature_type=self.env_feature_type,
142 | device=self.device, reward_rescale=self.cfg.reward_rescale)
143 | self.eval_env = AdroitEnv(env_name, test_image=False, num_repeats=self.cfg.action_repeat,
144 | num_frames=self.cfg.frame_stack, env_feature_type=self.env_feature_type,
145 | device=self.device, reward_rescale=self.cfg.reward_rescale)
146 |
147 | data_specs = (self.train_env.observation_spec(),
148 | self.train_env.observation_sensor_spec(),
149 | self.train_env.action_spec(),
150 | specs.Array((1,), np.float32, 'reward'),
151 | specs.Array((1,), np.float32, 'discount'),
152 | specs.Array((1,), np.int8, 'n_goal_achieved'),
153 | specs.Array((1,), np.float32, 'time_limit_reached'),
154 | )
155 | else:
156 | self.train_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack,
157 | self.cfg.action_repeat, self.cfg.seed)
158 | self.eval_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack,
159 | self.cfg.action_repeat, self.cfg.seed)
160 |
161 | data_specs = (self.train_env.observation_spec(),
162 | self.train_env.action_spec(),
163 | specs.Array((1,), np.float32, 'reward'),
164 | specs.Array((1,), np.float32, 'discount'))
165 |
166 | # create replay buffer
167 | self.replay_storage = ReplayBufferStorage(data_specs, self.work_dir / 'buffer')
168 |
169 |
170 |
171 | self.replay_loader = make_replay_loader(
172 | self.work_dir / 'buffer', self.cfg.replay_buffer_size,
173 | self.cfg.batch_size, self.cfg.replay_buffer_num_workers,
174 | self.cfg.save_snapshot, self.cfg.nstep, self.cfg.discount, self.replay_buffer_fetch_every,
175 | is_adroit=IS_ADROIT)
176 | self._replay_iter = None
177 |
178 | self.video_recorder = VideoRecorder(
179 | self.work_dir if self.cfg.save_video else None)
180 | self.train_video_recorder = TrainVideoRecorder(
181 | self.work_dir if self.cfg.save_train_video else None)
182 |
183 | def set_demo_buffer_nstep(self, nstep):
184 | self.replay_loader_demo.dataset._nstep = nstep
185 |
186 | @property
187 | def global_step(self):
188 | return self._global_step
189 |
190 | @property
191 | def global_episode(self):
192 | return self._global_episode
193 |
194 | @property
195 | def global_frame(self):
196 | return self.global_step * self.cfg.action_repeat
197 |
198 | @property
199 | def replay_iter(self):
200 | if self._replay_iter is None:
201 | self._replay_iter = iter(self.replay_loader)
202 | return self._replay_iter
203 |
204 | def eval_dmc(self):
205 | step, episode, total_reward = 0, 0, 0
206 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
207 |
208 | while eval_until_episode(episode):
209 | time_step = self.eval_env.reset()
210 | self.video_recorder.init(self.eval_env, enabled=(episode == 0))
211 | while not time_step.last():
212 | with torch.no_grad(), utils.eval_mode(self.agent):
213 | action = self.agent.act(time_step.observation,
214 | self.global_step,
215 | eval_mode=True)
216 | time_step = self.eval_env.step(action)
217 | self.video_recorder.record(self.eval_env)
218 | total_reward += time_step.reward
219 | step += 1
220 |
221 | episode += 1
222 | self.video_recorder.save(f'{self.global_frame}.mp4')
223 |
224 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
225 | log('episode_reward', total_reward / episode)
226 | log('episode_length', step * self.cfg.action_repeat / episode)
227 | log('episode', self.global_episode)
228 | log('step', self.global_step)
229 |
230 | def eval_adroit(self, force_number_episodes=None, do_log=True):
231 | if ENV_TYPE != 'adroit':
232 | return self.eval_dmc()
233 |
234 | step, episode, total_reward = 0, 0, 0
235 | n_eval_episode = force_number_episodes if force_number_episodes is not None else self.cfg.num_eval_episodes
236 | eval_until_episode = utils.Until(n_eval_episode)
237 | total_success = 0.0
238 | while eval_until_episode(episode):
239 | n_goal_achieved_total = 0
240 | time_step = self.eval_env.reset()
241 | self.video_recorder.init(self.eval_env, enabled=(episode == 0))
242 | while not time_step.last():
243 | with torch.no_grad(), utils.eval_mode(self.agent):
244 | observation = time_step.observation
245 | action = self.agent.act(observation,
246 | self.global_step,
247 | eval_mode=True,
248 | obs_sensor=time_step.observation_sensor)
249 | time_step = self.eval_env.step(action)
250 | n_goal_achieved_total += time_step.n_goal_achieved
251 | self.video_recorder.record(self.eval_env)
252 | total_reward += time_step.reward
253 | step += 1
254 |
255 | # here check if success for Adroit tasks. The threshold values come from the mj_envs code
256 | # e.g. https://github.com/ShahRutav/mj_envs/blob/5ee75c6e294dda47983eb4c60b6dd8f23a3f9aec/mj_envs/hand_manipulation_suite/pen_v0.py
257 | # one could also use the evaluate_success function from the Adroit envs, but that is more involved
258 | if self.cfg.task_name == 'pen-v0':
259 | threshold = 20
260 | else:
261 | threshold = 25
262 | if n_goal_achieved_total > threshold:
263 | total_success += 1
264 |
265 | episode += 1
266 | self.video_recorder.save(f'{self.global_frame}.mp4')
267 | success_rate_standard = total_success / n_eval_episode
268 | episode_reward_standard = total_reward / episode
269 | episode_length_standard = step * self.cfg.action_repeat / episode
270 |
271 | if do_log:
272 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
273 | log('episode_reward', episode_reward_standard)
274 | log('success_rate', success_rate_standard)
275 | log('episode_length', episode_length_standard)
276 | log('episode', self.global_episode)
277 | log('step', self.global_step)
278 |
279 | return episode_reward_standard, success_rate_standard
280 |
281 | def get_data_folder_path(self):
282 | amlt_data_path = os.getenv('AMLT_DATA_DIR', '')
283 | if amlt_data_path != '': # if on cluster
284 | return amlt_data_path
285 | else: # if not on cluster, return the data folder path specified in cfg
286 | return self.cfg.local_data_dir
287 |
288 | def get_demo_path(self, env_name):
289 | # given an environment name (with the -v0 suffix), return the path to its demo file
290 | data_folder_path = self.get_data_folder_path()
291 | demo_folder_path = os.path.join(data_folder_path, "demonstrations")
292 | demo_path = os.path.join(demo_folder_path, env_name + "_demos.pickle")
293 | return demo_path
294 |
295 | def load_demo(self, replay_storage, env_name, verbose=True):
296 | # will load demo data and put them into a replay storage
297 | demo_path = self.get_demo_path(env_name)
298 | if verbose:
299 | print("Trying to get demo data from:", demo_path)
300 |
301 | # get the raw state demo data, a list of length 25
302 | demo_data = pickle.load(open(demo_path, 'rb'))
303 | if self.cfg.num_demo >= 0:
304 | demo_data = demo_data[:self.cfg.num_demo]
305 |
306 | """
307 | the adroit demo data is in raw state, so we need to convert them into image data
308 | we init an env and run episodes with the stored actions in the demo data
309 | then put the image and sensor data into the replay buffer, also need to clip actions here
310 | this part is basically following the RRL code
311 | """
312 | demo_env = AdroitEnv(env_name, test_image=False, num_repeats=1, num_frames=self.cfg.frame_stack,
313 | env_feature_type=self.env_feature_type, device=self.device, reward_rescale=self.cfg.reward_rescale)
314 | demo_env.reset()
315 |
316 | total_data_count = 0
317 | for i_path in range(len(demo_data)):
318 | path = demo_data[i_path]
319 | demo_env.reset()
320 | demo_env.set_env_state(path['init_state_dict'])
321 | time_step = demo_env.get_current_obs_without_reset()
322 | replay_storage.add(time_step)
323 |
324 | ep_reward = 0
325 | ep_n_goal = 0
326 | for i_act in range(len(path['actions'])):
327 | total_data_count += 1
328 | action = path['actions'][i_act]
329 | action = action.astype(np.float32)
330 | # actions are clipped to [-1, 1] here, matching what the environment does internally
331 | action[action > 1] = 1
332 | action[action < -1] = -1
333 |
334 | # the demo data was collected without a time limit, so we force the step type explicitly
335 | if i_act == len(path['actions']) - 1:
336 | force_step_type = 'last'
337 | else:
338 | force_step_type = 'mid'
339 |
340 | time_step = demo_env.step(action, force_step_type=force_step_type)
341 | replay_storage.add(time_step)
342 |
343 | reward = time_step.reward
344 | ep_reward += reward
345 |
346 | goal_achieved = time_step.n_goal_achieved
347 | ep_n_goal += goal_achieved
348 | if verbose:
349 | print('demo trajectory %d, len: %d, return: %.2f, goal achieved steps: %d' %
350 | (i_path, len(path['actions']), ep_reward, ep_n_goal))
351 | if verbose:
352 | print("Demo data load finished, total data count:", total_data_count)
353 |
354 | def get_pretrained_model_path(self, stage1_model_name):
355 | # given a stage1 model name, return the path to the pretrained model
356 | data_folder_path = self.get_data_folder_path()
357 | model_folder_path = os.path.join(data_folder_path, "trained_models")
358 | model_path = os.path.join(model_folder_path, stage1_model_name + '_checkpoint.pth.tar')
359 | return model_path
360 |
361 | def train(self):
362 | train_start_time = time.time()
363 | print("\n=== Training started! ===")
364 | """=================================== LOAD PRETRAINED MODEL ==================================="""
365 | if self.cfg.stage1_use_pretrain:
366 | self.agent.load_pretrained_encoder(self.get_pretrained_model_path(self.cfg.stage1_model_name))
367 | self.agent.switch_to_RL_stages()
368 |
369 | """========================================= LOAD DATA ========================================="""
370 | if self.cfg.load_demo:
371 | self.load_demo(self.replay_storage, self.cfg.task_name)
372 | print("Model and data loading finished in %.2f hours." % ((time.time()-train_start_time) / 3600))
373 |
374 | """========================================== STAGE 2 =========================================="""
375 | print("\n=== Stage 2 started ===")
376 | stage2_start_time = time.time()
377 | stage2_n_update = self.cfg.stage2_n_update
378 | if stage2_n_update > 0:
379 | for i_stage2 in range(stage2_n_update):
380 | metrics = self.agent.update(self.replay_iter, i_stage2, stage=2, use_sensor=IS_ADROIT)
381 | if i_stage2 % self.cfg.stage2_eval_every_frames == 0:
382 | average_score, succ_rate = self.eval_adroit(force_number_episodes=self.cfg.stage2_num_eval_episodes,
383 | do_log=False)
384 | print('Stage 2 step %d, Q(s,a): %.2f, Q loss: %.2f, score: %.2f, succ rate: %.2f' %
385 | (i_stage2, metrics['critic_q1'], metrics['critic_loss'], average_score, succ_rate))
386 | if self.cfg.show_computation_time_est and i_stage2 > 0 and i_stage2 % self.cfg.show_time_est_interval == 0:
387 | print_stage2_time_est(time.time()-stage2_start_time, i_stage2+1, stage2_n_update)
388 | print("Stage 2 finished in %.2f hours." % ((time.time()-stage2_start_time) / 3600))
389 |
390 | """========================================== STAGE 3 =========================================="""
391 | print("\n=== Stage 3 started ===")
392 | stage3_start_time = time.time()
393 | # predicates
394 | train_until_step = utils.Until(self.cfg.num_train_frames, self.cfg.action_repeat)
395 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat)
396 | eval_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat)
397 |
398 | episode_step, episode_reward = 0, 0
399 | time_step = self.train_env.reset()
400 | self.replay_storage.add(time_step)
401 | self.train_video_recorder.init(time_step.observation)
402 | metrics = None
403 |
404 | episode_step_since_log, episode_reward_list, episode_frame_list = 0, [0], [0]
405 | self.timer.reset()
406 | while train_until_step(self.global_step):
407 | # if 1000 steps passed, do some logging
408 | if self.global_step % 1000 == 0 and metrics is not None:
409 | elapsed_time, total_time = self.timer.reset()
410 | episode_frame_since_log = episode_step_since_log * self.cfg.action_repeat
411 | with self.logger.log_and_dump_ctx(self.global_frame, ty='train') as log:
412 | log('fps', episode_frame_since_log / elapsed_time)
413 | log('total_time', total_time)
414 | log('episode_reward', np.mean(episode_reward_list))
415 | log('episode_length', np.mean(episode_frame_list))
416 | log('episode', self.global_episode)
417 | log('buffer_size', len(self.replay_storage))
418 | log('step', self.global_step)
419 | episode_step_since_log, episode_reward_list, episode_frame_list = 0, [0], [0]
420 | if self.cfg.show_computation_time_est and self.global_step > 0 and self.global_step % self.cfg.show_time_est_interval == 0:
421 | print_stage3_time_est(time.time() - stage3_start_time, self.global_frame + 1, self.cfg.num_train_frames)
422 |
423 | # if reached end of episode
424 | if time_step.last():
425 | self._global_episode += 1
426 | self.train_video_recorder.save(f'{self.global_frame}.mp4')
427 | # wait until all the metrics schema is populated
428 | if metrics is not None:
429 | # log stats
430 | episode_step_since_log += episode_step
431 | episode_reward_list.append(episode_reward)
432 | episode_frame = episode_step * self.cfg.action_repeat
433 | episode_frame_list.append(episode_frame)
434 |
435 | # reset env
436 | time_step = self.train_env.reset()
437 | self.replay_storage.add(time_step)
438 | self.train_video_recorder.init(time_step.observation)
439 | # try to save snapshot
440 | if self.cfg.save_snapshot:
441 | self.save_snapshot()
442 | episode_step, episode_reward = 0, 0
443 |
444 | # try to evaluate
445 | if eval_every_step(self.global_step):
446 | self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame)
447 | self.eval_adroit()
448 |
449 | # sample action
450 | if IS_ADROIT:
451 | obs_sensor = time_step.observation_sensor
452 | else:
453 | obs_sensor = None
454 | with torch.no_grad(), utils.eval_mode(self.agent):
455 | action = self.agent.act(time_step.observation,
456 | self.global_step,
457 | eval_mode=False,
458 | obs_sensor=obs_sensor)
459 |
460 | # update the agent
461 | if not seed_until_step(self.global_step):
462 | metrics = self.agent.update(self.replay_iter, self.global_step, stage=3, use_sensor=IS_ADROIT)
463 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
464 |
465 | # take env step
466 | time_step = self.train_env.step(action)
467 | episode_reward += time_step.reward
468 | self.replay_storage.add(time_step)
469 | self.train_video_recorder.record(time_step.observation)
470 | episode_step += 1
471 | self._global_step += 1
472 | """here we move the experiment files to azure blob"""
473 | if (self.global_step==1) or (self.global_step == 10000) or (self.global_step % 100000 == 0):
474 | try:
475 | self.copy_to_azure()
476 | except Exception as e:
477 | print(e)
478 |
479 | """here save model for later"""
480 | if self.cfg.save_models:
481 | if self.global_frame in (2, 100000, 500000, 1000000, 2000000, 4000000):
482 | self.save_snapshot(suffix=str(self.global_frame))
483 |
484 | try:
485 | self.copy_to_azure()
486 | except Exception as e:
487 | print(e)
488 | print("Stage 3 finished in %.2f hours." % ((time.time()-stage3_start_time) / 3600))
489 | print("All stages finished in %.2f hrs. Work dir:" % ((time.time()-train_start_time)/3600))
490 | print(self.work_dir)
491 |
492 | def save_snapshot(self, suffix=None):
493 | if suffix is None:
494 | save_name = 'snapshot.pt'
495 | else:
496 | save_name = 'snapshot' + suffix + '.pt'
497 | snapshot = self.work_dir / save_name
498 | keys_to_save = ['agent', 'timer', '_global_step', '_global_episode']
499 | payload = {k: self.__dict__[k] for k in keys_to_save}
500 | with snapshot.open('wb') as f:
501 | torch.save(payload, f)
502 | print("snapshot saved to:", str(snapshot))
503 |
504 | def load_snapshot(self):
505 | snapshot = self.work_dir / 'snapshot.pt'
506 | with snapshot.open('rb') as f:
507 | payload = torch.load(f)
508 | for k, v in payload.items():
509 | self.__dict__[k] = v
510 |
511 | def copy_to_azure(self):
512 | amlt_path = os.getenv('AMLT_OUTPUT_DIR', '')
513 | if amlt_path != '': # if on cluster
514 | container_log_path = self.work_dir
515 | amlt_path_to = os.path.join(amlt_path, self.direct_folder_name)
516 | copy_tree(str(container_log_path), amlt_path_to, update=1)
517 | # copytree(str(container_log_path), amlt_path_to, dirs_exist_ok=True, ignore=ignore_patterns('*.npy'))
518 | print("Data copied to:", amlt_path_to)
519 | # else: # if at local
520 | # container_log_path = self.work_dir
521 | # amlt_path_to = '/vrl3data/logs'
522 | # copy_tree(str(container_log_path), amlt_path_to, update=1)
523 |
524 | @hydra.main(config_path='cfgs_adroit', config_name='config')
525 | def main(cfg):
526 | # TODO potentially check the task name and decide which libs to load here?
527 | W = Workspace
528 | root_dir = Path.cwd()
529 | workspace = W(cfg)
530 | snapshot = root_dir / 'snapshot.pt'
531 | if snapshot.exists():
532 | print(f'resuming: {snapshot}')
533 | workspace.load_snapshot()
534 | workspace.train()
535 |
536 |
537 | if __name__ == '__main__':
538 | main()
--------------------------------------------------------------------------------
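The stage 3 loop above is paced by the utils.Until and utils.Every predicates, whose
definitions are not shown in this excerpt. The following is an illustrative
re-implementation of Until, consistent with how it is called here
(Until(num_frames, action_repeat)(step)), not necessarily the repo's exact code:

    class Until:
        # True while step * action_repeat is still below the frame budget
        def __init__(self, until, action_repeat=1):
            self._until = until
            self._action_repeat = action_repeat

        def __call__(self, step):
            if self._until is None:
                return True
            return step < self._until // self._action_repeat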
/src/train_stage1.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # this file is modified from the official PyTorch ImageNet training example
5 | # NOTE: stage 1 training code is currently being cleaned up
6 |
7 | import argparse
8 | import os
9 | import random
10 | import shutil
11 | import time
12 | import warnings
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.parallel
17 | import torch.backends.cudnn as cudnn
18 | import torch.distributed as dist
19 | import torch.optim
20 | import torch.multiprocessing as mp
21 | import torch.utils.data
22 | import torch.utils.data.distributed
23 | import torchvision.transforms as transforms
24 | import torchvision.datasets as datasets
25 | import torchvision.models as models
26 |
27 | # TODO use another config file to indicate the location of the training data and also where to save models...
28 | # should also add an option to just test the accuracy of models...
29 | # we probably can test this locally ....
30 |
31 | from stage1_models import BasicBlock, ResNet84
32 |
33 | rl_model_names = ['resnet6_32channel', 'resnet10_32channel', 'resnet18_32channel',
34 | 'resnet6_64channel', 'resnet10_64channel', 'resnet18_64channel',]
35 | model_names = sorted(name for name in models.__dict__
36 | if name.islower() and not name.startswith("__")
37 | and callable(models.__dict__[name])) + rl_model_names
38 |
39 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
40 | parser.add_argument('data', metavar='DIR',
41 | help='path to dataset')
42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet10_32channel',
43 | choices=model_names,
44 | help='model architecture: ' +
45 | ' | '.join(model_names) +
46 | ' (default: resnet10_32channel)')
47 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
48 | help='number of data loading workers (default: 4)')
49 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
50 | help='number of total epochs to run')
51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
52 | help='manual epoch number (useful on restarts)')
53 | parser.add_argument('-b', '--batch-size', default=256, type=int,
54 | metavar='N',
55 | help='mini-batch size (default: 256), this is the total '
56 | 'batch size of all GPUs on the current node when '
57 | 'using Data Parallel or Distributed Data Parallel')
58 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
59 | metavar='LR', help='initial learning rate', dest='lr')
60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
61 | help='momentum')
62 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
63 | metavar='W', help='weight decay (default: 1e-4)',
64 | dest='weight_decay')
65 | parser.add_argument('-p', '--print-freq', default=10, type=int,
66 | metavar='N', help='print frequency (default: 10)')
67 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
68 | help='path to latest checkpoint (default: none)')
69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
70 | help='evaluate model on validation set')
71 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
72 | help='use pre-trained model')
73 | parser.add_argument('--world-size', default=-1, type=int,
74 | help='number of nodes for distributed training')
75 | parser.add_argument('--rank', default=-1, type=int,
76 | help='node rank for distributed training')
77 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
78 | help='url used to set up distributed training')
79 | parser.add_argument('--dist-backend', default='nccl', type=str,
80 | help='distributed backend')
81 | parser.add_argument('--seed', default=None, type=int,
82 | help='seed for initializing training. ')
83 | parser.add_argument('--gpu', default=None, type=int,
84 | help='GPU id to use.')
85 | parser.add_argument('--multiprocessing-distributed', action='store_true',
86 | help='Use multi-processing distributed training to launch '
87 | 'N processes per node, which has N GPUs. This is the '
88 | 'fastest way to use PyTorch for either single node or '
89 | 'multi node data parallel training')
90 | parser.add_argument('--debug', default=0, type=int,
91 | help='1 for debug mode, 2 for super fast debug mode')
92 |
93 | best_acc1 = 0
94 | INPUT_SIZE = 84
95 | VAL_RESIZE = 100
96 |
97 | def main():
98 | print(model_names)
99 |
100 | args = parser.parse_args()
101 |
102 | if args.seed is not None:
103 | random.seed(args.seed)
104 | torch.manual_seed(args.seed)
105 | cudnn.deterministic = True
106 | warnings.warn('You have chosen to seed training. '
107 | 'This will turn on the CUDNN deterministic setting, '
108 | 'which can slow down your training considerably! '
109 | 'You may see unexpected behavior when restarting '
110 | 'from checkpoints.')
111 |
112 | if args.gpu is not None:
113 | warnings.warn('You have chosen a specific GPU. This will completely '
114 | 'disable data parallelism.')
115 |
116 | if args.dist_url == "env://" and args.world_size == -1:
117 | args.world_size = int(os.environ["WORLD_SIZE"])
118 |
119 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
120 |
121 | ngpus_per_node = torch.cuda.device_count()
122 | if args.multiprocessing_distributed:
123 | # Since we have ngpus_per_node processes per node, the total world_size
124 | # needs to be adjusted accordingly
125 | args.world_size = ngpus_per_node * args.world_size
126 | # Use torch.multiprocessing.spawn to launch distributed processes: the
127 | # main_worker process function
128 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
129 | else:
130 | # Simply call main_worker function
131 | main_worker(args.gpu, ngpus_per_node, args)
132 |
133 |
134 | def main_worker(gpu, ngpus_per_node, args):
135 | global best_acc1
136 | args.gpu = gpu
137 |
138 | if args.gpu is not None:
139 | print("Use GPU: {} for training".format(args.gpu))
140 |
141 | if args.distributed:
142 | if args.dist_url == "env://" and args.rank == -1:
143 | args.rank = int(os.environ["RANK"])
144 | if args.multiprocessing_distributed:
145 | # For multiprocessing distributed training, rank needs to be the
146 | # global rank among all the processes
147 | args.rank = args.rank * ngpus_per_node + gpu
148 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
149 | world_size=args.world_size, rank=args.rank)
150 | # create model
151 | if args.debug > 0:
152 | print("=> creating model for debug 2")
153 | # model = ResNet84(BasicBlock, [1, 1, 1, 1], num_classes=5) # [1, 1, 1, 1] would make a resnet10
154 | model = ResNet84(BasicBlock, [0, 0, 0, 0], num_classes=5) # [0, 0, 0, 0] makes a 6-layer convnet (5 conv layers plus the final fc)
155 | x = torch.rand((1, 3, 84, 84)).float()
156 | out = model(x)
157 | print(model)
158 | quit()
159 |
160 | # model = ResNetTest2(BasicBlock, [2, 2, 2, 2])
161 | #model = Drq4Encoder((3, 84, 84), n_channel, 200)
162 | else:
163 | if args.pretrained:
164 | print("=> using pre-trained model '{}'".format(args.arch))
165 | model = models.__dict__[args.arch](pretrained=True)
166 | else:
167 | print("=> creating model '{}'".format(args.arch))
168 | if args.arch in rl_model_names:
169 | if args.arch == 'resnet18_32channel':
170 | model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=32) # [2, 2, 2, 2] makes a resnet18
171 | elif args.arch == 'resnet10_32channel':
172 | model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=32) # [1, 1, 1, 1] makes a resnet10
173 | elif args.arch == 'resnet6_32channel':
174 | model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=32) # resnet6 (not strictly a ResNet: OneLayerBlock has no skip connection)
175 | elif args.arch == 'resnet18_64channel':
176 | model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=64)
177 | elif args.arch == 'resnet10_64channel':
178 | model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=64)
179 | elif args.arch == 'resnet6_64channel':
180 | model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=64)
181 | else:
182 | print("specialized model not yet implemented")
183 | quit()
184 | else:
185 | model = models.__dict__[args.arch]()
186 |
187 | if not torch.cuda.is_available():
188 | print('using CPU, this will be slow')
189 | elif args.distributed:
190 | print("distributed")
191 | # For multiprocessing distributed, DistributedDataParallel constructor
192 | # should always set the single device scope, otherwise,
193 | # DistributedDataParallel will use all available devices.
194 | if args.gpu is not None:
195 | torch.cuda.set_device(args.gpu)
196 | model.cuda(args.gpu)
197 | # When using a single GPU per process and per
198 | # DistributedDataParallel, we need to divide the batch size
199 | # ourselves based on the total number of GPUs we have
200 | args.batch_size = int(args.batch_size / ngpus_per_node)
201 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
202 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
203 | else:
204 | model.cuda()
205 | # DistributedDataParallel will divide and allocate batch_size to all
206 | # available GPUs if device_ids are not set
207 | model = torch.nn.parallel.DistributedDataParallel(model)
208 | elif args.gpu is not None:
209 | print("use gpu:", args.gpu)
210 | torch.cuda.set_device(args.gpu)
211 | model = model.cuda(args.gpu)
212 | else:
213 | print("data parallel")
214 | # DataParallel will divide and allocate batch_size to all available GPUs
215 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
216 | model.features = torch.nn.DataParallel(model.features)
217 | model.cuda()
218 | else:
219 | model = torch.nn.DataParallel(model).cuda()
220 |
221 | # define loss function (criterion) and optimizer
222 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
223 |
224 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
225 | momentum=args.momentum,
226 | weight_decay=args.weight_decay)
227 |
228 | # optionally resume from a checkpoint
229 | if args.resume:
230 | if os.path.isfile(args.resume):
231 | print("=> loading checkpoint '{}'".format(args.resume))
232 | if args.gpu is None:
233 | checkpoint = torch.load(args.resume)
234 | else:
235 | # Map model to be loaded to specified single gpu.
236 | loc = 'cuda:{}'.format(args.gpu)
237 | checkpoint = torch.load(args.resume, map_location=loc)
238 | args.start_epoch = checkpoint['epoch']
239 | best_acc1 = checkpoint['best_acc1']
240 | if args.gpu is not None:
241 | # best_acc1 may be from a checkpoint from a different GPU
242 | best_acc1 = best_acc1.to(args.gpu)
243 | model.load_state_dict(checkpoint['state_dict'])
244 | optimizer.load_state_dict(checkpoint['optimizer'])
245 | print("=> loaded checkpoint '{}' (epoch {})"
246 | .format(args.resume, checkpoint['epoch']))
247 | else:
248 | print("=> no checkpoint found at '{}'".format(args.resume))
249 |
250 | cudnn.benchmark = True
251 |
252 | # Data loading code
253 | traindir = os.path.join(args.data, 'train')
254 | valdir = os.path.join(args.data, 'val')
255 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
256 | std=[0.229, 0.224, 0.225])
257 |
258 | print("train directory is:", traindir)
259 | print("val directory is:", valdir)
260 |
261 | train_dataset = datasets.ImageFolder(
262 | traindir,
263 | transforms.Compose([
264 | transforms.RandomResizedCrop(INPUT_SIZE),
265 | transforms.RandomHorizontalFlip(),
266 | transforms.ToTensor(),
267 | normalize,
268 | ]))
269 |
270 | print("data set ready")
271 |
272 | if args.distributed:
273 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
274 | else:
275 | train_sampler = None
276 |
277 | train_loader = torch.utils.data.DataLoader(
278 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
279 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
280 |
281 | val_loader = torch.utils.data.DataLoader(
282 | datasets.ImageFolder(valdir, transforms.Compose([
283 | # transforms.Resize(256),
284 | transforms.Resize(VAL_RESIZE),
285 | transforms.CenterCrop(INPUT_SIZE),
286 | transforms.ToTensor(),
287 | normalize,
288 | ])),
289 | batch_size=args.batch_size, shuffle=False,
290 | num_workers=args.workers, pin_memory=True)
291 |
292 | if args.evaluate:
293 | validate(val_loader, model, criterion, args)
294 | return
295 |
296 | for epoch in range(args.start_epoch, args.epochs):
297 | print(epoch)
298 | epoch_start_time = time.time()
299 |
300 | if args.distributed:
301 | train_sampler.set_epoch(epoch)
302 | adjust_learning_rate(optimizer, epoch, args)
303 |
304 | # train for one epoch
305 | train(train_loader, model, criterion, optimizer, epoch, args)
306 |
307 | # evaluate on validation set
308 | acc1 = validate(val_loader, model, criterion, args)
309 |
310 | # remember best acc@1 and save checkpoint
311 | is_best = acc1 > best_acc1
312 | best_acc1 = max(acc1, best_acc1)
313 |
314 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
315 | and args.rank % ngpus_per_node == 0):
316 | save_checkpoint({
317 | 'epoch': epoch + 1,
318 | 'arch': args.arch,
319 | 'state_dict': model.state_dict(),
320 | 'best_acc1': best_acc1,
321 | 'optimizer' : optimizer.state_dict(),
322 | }, is_best,
323 | save_name_prefix=args.arch)
324 |
325 | epoch_end_time = time.time() - epoch_start_time
326 | print("epoch finished in %.3f hour" % (epoch_end_time/3600))
327 |
328 | def train(train_loader, model, criterion, optimizer, epoch, args):
329 | batch_time = AverageMeter('Time', ':6.3f')
330 | data_time = AverageMeter('Data', ':6.3f')
331 | losses = AverageMeter('Loss', ':.4e')
332 | top1 = AverageMeter('Acc@1', ':6.2f')
333 | top5 = AverageMeter('Acc@5', ':6.2f')
334 | progress = ProgressMeter(
335 | len(train_loader),
336 | [batch_time, data_time, losses, top1, top5],
337 | prefix="Epoch: [{}]".format(epoch))
338 |
339 | # switch to train mode
340 | model.train()
341 |
342 | end = time.time()
343 | for i, (images, target) in enumerate(train_loader):
344 | # measure data loading time
345 | data_time.update(time.time() - end)
346 |
347 | if args.gpu is not None:
348 | images = images.cuda(args.gpu, non_blocking=True)
349 | if torch.cuda.is_available():
350 | target = target.cuda(args.gpu, non_blocking=True)
351 |
352 | # compute output
353 | output = model(images)
354 | loss = criterion(output, target)
355 |
356 | # measure accuracy and record loss
357 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
358 | losses.update(loss.item(), images.size(0))
359 | top1.update(acc1[0], images.size(0))
360 | top5.update(acc5[0], images.size(0))
361 |
362 | # compute gradient and do SGD step
363 | optimizer.zero_grad()
364 | loss.backward()
365 | optimizer.step()
366 |
367 | # measure elapsed time
368 | batch_time.update(time.time() - end)
369 | end = time.time()
370 |
371 | if i % args.print_freq == 0:
372 | progress.display(i)
373 |
374 |
375 | def validate(val_loader, model, criterion, args):
376 | batch_time = AverageMeter('Time', ':6.3f')
377 | losses = AverageMeter('Loss', ':.4e')
378 | top1 = AverageMeter('Acc@1', ':6.2f')
379 | top5 = AverageMeter('Acc@5', ':6.2f')
380 | progress = ProgressMeter(
381 | len(val_loader),
382 | [batch_time, losses, top1, top5],
383 | prefix='Test: ')
384 |
385 | # switch to evaluate mode
386 | model.eval()
387 |
388 | with torch.no_grad():
389 | end = time.time()
390 | for i, (images, target) in enumerate(val_loader):
391 | if args.gpu is not None:
392 | images = images.cuda(args.gpu, non_blocking=True)
393 | if torch.cuda.is_available():
394 | target = target.cuda(args.gpu, non_blocking=True)
395 |
396 | # compute output
397 | output = model(images)
398 | loss = criterion(output, target)
399 |
400 | # measure accuracy and record loss
401 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
402 | losses.update(loss.item(), images.size(0))
403 | top1.update(acc1[0], images.size(0))
404 | top5.update(acc5[0], images.size(0))
405 |
406 | # measure elapsed time
407 | batch_time.update(time.time() - end)
408 | end = time.time()
409 |
410 | if i % args.print_freq == 0:
411 | progress.display(i)
412 |
413 | # TODO: this should also be done with the ProgressMeter
414 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
415 | .format(top1=top1, top5=top5))
416 |
417 | return top1.avg
418 |
419 |
420 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_name_prefix = ''):
421 | save_name = save_name_prefix + '_' + filename
422 | torch.save(state, save_name)
423 | if is_best:
424 | best_model_save_name = save_name_prefix + '_' + 'model_best.pth.tar'
425 | shutil.copyfile(save_name, best_model_save_name)
426 |
427 | class AverageMeter(object):
428 | """Computes and stores the average and current value"""
429 | def __init__(self, name, fmt=':f'):
430 | self.name = name
431 | self.fmt = fmt
432 | self.reset()
433 |
434 | def reset(self):
435 | self.val = 0
436 | self.avg = 0
437 | self.sum = 0
438 | self.count = 0
439 |
440 | def update(self, val, n=1):
441 | self.val = val
442 | self.sum += val * n
443 | self.count += n
444 | self.avg = self.sum / self.count
445 |
446 | def __str__(self):
447 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
448 | return fmtstr.format(**self.__dict__)
449 |
450 |
451 | class ProgressMeter(object):
452 | def __init__(self, num_batches, meters, prefix=""):
453 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
454 | self.meters = meters
455 | self.prefix = prefix
456 |
457 | def display(self, batch):
458 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
459 | entries += [str(meter) for meter in self.meters]
460 | print('\t'.join(entries))
461 |
462 | def _get_batch_fmtstr(self, num_batches):
463 | num_digits = len(str(num_batches))
464 | fmt = '{:' + str(num_digits) + 'd}'
465 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
466 |
467 |
468 | def adjust_learning_rate(optimizer, epoch, args):
469 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
470 | lr = args.lr * (0.1 ** (epoch // 30))
471 | for param_group in optimizer.param_groups:
472 | param_group['lr'] = lr
473 |
474 |
475 | def accuracy(output, target, topk=(1,)):
476 | """Computes the accuracy over the k top predictions for the specified values of k"""
477 | with torch.no_grad():
478 | maxk = max(topk)
479 | batch_size = target.size(0)
480 |
481 | _, pred = output.topk(maxk, 1, True, True)
482 | pred = pred.t()
483 | correct = pred.eq(target.view(1, -1).expand_as(pred))
484 |
485 | res = []
486 | for k in topk:
487 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
488 | res.append(correct_k.mul_(100.0 / batch_size))
489 | return res
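# a minimal usage sketch for accuracy(), assuming hypothetical tensors:
#   output = torch.randn(32, 1000)          # logits for 32 samples
#   target = torch.randint(0, 1000, (32,))  # ground-truth class indices
#   acc1, acc5 = accuracy(output, target, topk=(1, 5))
#   # acc1 and acc5 are 1-element tensors holding percentages in [0, 100]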
490 |
491 |
492 | if __name__ == '__main__':
493 | main()
--------------------------------------------------------------------------------
/src/transfer_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from torchvision import datasets, models, transforms
5 |
6 | def set_parameter_requires_grad(model, n_resblock_finetune):
7 | assert n_resblock_finetune in (0, 1, 2, 3, 4, 5)
8 | for param in model.parameters():
9 | param.requires_grad = False
10 |
11 | for name, param in model.named_parameters():
12 | condition = (n_resblock_finetune >= 1 and 'layer4' in name) or (n_resblock_finetune >= 2 and 'layer3' in name) or \
13 | (n_resblock_finetune >= 3 and 'layer2' in name) or (n_resblock_finetune >= 4 and 'layer1' in name) or \
14 | (n_resblock_finetune >= 5)
15 |
16 | if condition:
17 | param.requires_grad = True
18 |
19 | for name, param in model.named_parameters():
20 | if 'bn' in name:
21 | param.requires_grad = False
22 |
23 | def initialize_model(model_name, n_resblock_finetune, use_pretrained=True):
24 | # model_ft, input_size and feature_size are model-specific; each is set
25 | # in the corresponding branch below.
26 | model_ft = None
27 | input_size = 0
28 |
29 | if model_name == "resnet18":
30 | """ Resnet18
31 | """
32 | model_ft = models.resnet18(pretrained=use_pretrained)
33 | set_parameter_requires_grad(model_ft, n_resblock_finetune)
34 | feature_size = model_ft.fc.in_features
35 | input_size = 224
36 | elif model_name == 'resnet34':
37 | model_ft = models.resnet34(pretrained=use_pretrained)
38 | set_parameter_requires_grad(model_ft, n_resblock_finetune)
39 | feature_size = model_ft.fc.in_features
40 | input_size = 224
41 | else:
42 | print("Invalid model name, exiting...")
43 | exit()
44 |
45 | return model_ft, input_size, feature_size
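# a minimal usage sketch, assuming an ImageNet-pretrained ResNet-18 where only
# the last residual stage is finetuned:
#   model, input_size, feature_size = initialize_model('resnet18', n_resblock_finetune=1)
#   trainable = [n for n, p in model.named_parameters() if p.requires_grad]
#   # feature_size == 512; only parameters under 'layer4' remain trainable
# caveat: the 'bn' check in set_parameter_requires_grad matches bn1/bn2-style
# names, but torchvision's downsample BatchNorms are named 'downsample.1' and
# are therefore not frozen by it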
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import random
6 | import re
7 | import time
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from omegaconf import OmegaConf
14 | from torch import distributions as pyd
15 | from torch.distributions.utils import _standard_normal
16 |
17 |
18 | class eval_mode:
19 | def __init__(self, *models):
20 | self.models = models
21 |
22 | def __enter__(self):
23 | self.prev_states = []
24 | for model in self.models:
25 | self.prev_states.append(model.training)
26 | model.train(False)
27 |
28 | def __exit__(self, *args):
29 | for model, state in zip(self.models, self.prev_states):
30 | model.train(state)
31 | return False
32 |
33 |
34 | def set_seed_everywhere(seed):
35 | torch.manual_seed(seed)
36 | if torch.cuda.is_available():
37 | torch.cuda.manual_seed_all(seed)
38 | np.random.seed(seed)
39 | random.seed(seed)
40 |
41 |
42 | def soft_update_params(net, target_net, tau):
43 | for param, target_param in zip(net.parameters(), target_net.parameters()):
44 | target_param.data.copy_(tau * param.data +
45 | (1 - tau) * target_param.data)
46 |
47 |
48 | def to_torch(xs, device):
49 | return tuple(torch.as_tensor(x, device=device) for x in xs)
50 |
51 |
52 | def weight_init(m):
53 | if isinstance(m, nn.Linear):
54 | nn.init.orthogonal_(m.weight.data)
55 | if hasattr(m.bias, 'data'):
56 | m.bias.data.fill_(0.0)
57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
58 | gain = nn.init.calculate_gain('relu')
59 | nn.init.orthogonal_(m.weight.data, gain)
60 | if hasattr(m.bias, 'data'):
61 | m.bias.data.fill_(0.0)
62 |
63 |
64 | class Until:
65 | def __init__(self, until, action_repeat=1):
66 | self._until = until
67 | self._action_repeat = action_repeat
68 |
69 | def __call__(self, step):
70 | if self._until is None:
71 | return True
72 | until = self._until // self._action_repeat
73 | return step < until
74 |
75 |
76 | class Every:
77 | def __init__(self, every, action_repeat=1):
78 | self._every = every
79 | self._action_repeat = action_repeat
80 |
81 | def __call__(self, step):
82 | if self._every is None:
83 | return False
84 | every = self._every // self._action_repeat
85 | if step % every == 0:
86 | return True
87 | return False
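# a minimal sketch of the two gating helpers above; with action_repeat=2,
# thresholds given in environment steps are converted to agent steps:
#   until = Until(1000, action_repeat=2)  # until(step) is True while step < 500
#   every = Every(100, action_repeat=2)   # every(step) is True when step % 50 == 0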
88 |
89 |
90 | class Timer:
91 | def __init__(self):
92 | self._start_time = time.time()
93 | self._last_time = time.time()
94 |
95 | def reset(self):
96 | elapsed_time = time.time() - self._last_time
97 | self._last_time = time.time()
98 | total_time = time.time() - self._start_time
99 | return elapsed_time, total_time
100 |
101 | def total_time(self):
102 | return time.time() - self._start_time
103 |
104 |
105 | class TruncatedNormal(pyd.Normal):
106 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
107 | super().__init__(loc, scale, validate_args=False)
108 | self.low = low
109 | self.high = high
110 | self.eps = eps
111 |
112 | def _clamp(self, x):
113 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
114 | x = x - x.detach() + clamped_x.detach()
115 | return x
116 |
117 | def sample(self, clip=None, sample_shape=torch.Size()):
118 | shape = self._extended_shape(sample_shape)
119 | eps = _standard_normal(shape,
120 | dtype=self.loc.dtype,
121 | device=self.loc.device)
122 | eps *= self.scale
123 | if clip is not None:
124 | eps = torch.clamp(eps, -clip, clip)
125 | x = self.loc + eps
126 | return self._clamp(x)
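# note: _clamp uses a straight-through trick, x - x.detach() + clamped_x.detach(),
# so the forward value is the clamped sample while gradients flow as if no
# clamping happened. a minimal sampling sketch:
#   dist = TruncatedNormal(torch.zeros(4), torch.ones(4) * 0.1)
#   a = dist.sample(clip=0.3)  # noise clipped to [-0.3, 0.3], result kept in (-1, 1)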
127 |
128 |
129 | def schedule(schdl, step):
130 | try:
131 | return float(schdl)
132 | except ValueError:
133 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
134 | if match:
135 | init, final, duration = [float(g) for g in match.groups()]
136 | mix = np.clip(step / duration, 0.0, 1.0)
137 | return (1.0 - mix) * init + mix * final
138 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
139 | if match:
140 | init, final1, duration1, final2, duration2 = [
141 | float(g) for g in match.groups()
142 | ]
143 | if step <= duration1:
144 | mix = np.clip(step / duration1, 0.0, 1.0)
145 | return (1.0 - mix) * init + mix * final1
146 | else:
147 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
148 | return (1.0 - mix) * final1 + mix * final2
149 | raise NotImplementedError(schdl)
150 |
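# illustrative evaluations of schedule():
#   schedule('0.1', step)                            -> 0.1 (constant)
#   schedule('linear(1.0,0.1,100)', 50)              -> 0.55
#   schedule('step_linear(1,0.5,100,0.1,100)', 150)  -> 0.3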
--------------------------------------------------------------------------------
/src/video.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import cv2
6 | import imageio
7 | import numpy as np
8 |
9 | class VideoRecorder:
10 | def __init__(self, root_dir, render_size=256, fps=20):
11 | if root_dir is not None:
12 | self.save_dir = root_dir / 'eval_video'
13 | self.save_dir.mkdir(exist_ok=True)
14 | else:
15 | self.save_dir = None
16 |
17 | self.render_size = render_size
18 | self.fps = fps
19 | self.frames = []
20 |
21 | def init(self, env, enabled=True):
22 | self.frames = []
23 | self.enabled = self.save_dir is not None and enabled
24 | self.record(env)
25 |
26 | def record(self, env):
27 | if self.enabled:
28 | if hasattr(env, 'physics'):
29 | frame = env.physics.render(height=self.render_size,
30 | width=self.render_size,
31 | camera_id=0)
32 | else:
33 | frame = env.render()
34 | self.frames.append(frame)
35 |
36 | def save(self, file_name):
37 | if self.enabled:
38 | path = self.save_dir / file_name
39 | imageio.mimsave(str(path), self.frames, fps=self.fps)
40 |
41 |
42 | class TrainVideoRecorder:
43 | def __init__(self, root_dir, render_size=256, fps=20):
44 | if root_dir is not None:
45 | self.save_dir = root_dir / 'train_video'
46 | self.save_dir.mkdir(exist_ok=True)
47 | else:
48 | self.save_dir = None
49 |
50 | self.render_size = render_size
51 | self.fps = fps
52 | self.frames = []
53 |
54 | def init(self, obs, enabled=True):
55 | self.frames = []
56 | self.enabled = self.save_dir is not None and enabled
57 | self.record(obs)
58 |
59 | def record(self, obs):
60 | if self.enabled:
61 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
62 | dsize=(self.render_size, self.render_size),
63 | interpolation=cv2.INTER_CUBIC)
64 | self.frames.append(frame)
65 |
66 | def save(self, file_name):
67 | if self.enabled:
68 | path = self.save_dir / file_name
69 | imageio.mimsave(str(path), self.frames, fps=self.fps)
70 |
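# a minimal usage sketch for TrainVideoRecorder, assuming a channel-first
# uint8 observation with stacked frames:
#   from pathlib import Path
#   rec = TrainVideoRecorder(Path('.'))
#   obs = np.random.randint(0, 255, (9, 84, 84), dtype=np.uint8)
#   rec.init(obs)         # keeps only the last 3 channels (latest frame)
#   rec.save('demo.mp4')  # writes ./train_video/demo.mp4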
--------------------------------------------------------------------------------
/src/vrl3_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torchvision import datasets, models, transforms
9 | from transfer_util import initialize_model
10 | from stage1_models import BasicBlock, ResNet84
11 | import os
12 | import copy
13 | from PIL import Image
14 | import platform
15 | from numbers import Number
16 | import utils
17 |
18 | class RandomShiftsAug(nn.Module):
19 | def __init__(self, pad):
20 | super().__init__()
21 | self.pad = pad
22 |
23 | def forward(self, x):
24 | n, c, h, w = x.size()
25 | assert h == w
26 | padding = tuple([self.pad] * 4)
27 | x = F.pad(x, padding, 'replicate')
28 | eps = 1.0 / (h + 2 * self.pad)
29 | arange = torch.linspace(-1.0 + eps,
30 | 1.0 - eps,
31 | h + 2 * self.pad,
32 | device=x.device,
33 | dtype=x.dtype)[:h]
34 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
35 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
36 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
37 |
38 | shift = torch.randint(0,
39 | 2 * self.pad + 1,
40 | size=(n, 1, 1, 2),
41 | device=x.device,
42 | dtype=x.dtype)
43 | shift *= 2.0 / (h + 2 * self.pad)
44 |
45 | grid = base_grid + shift
46 | return F.grid_sample(x,
47 | grid,
48 | padding_mode='zeros',
49 | align_corners=False)
50 |
51 | class Identity(nn.Module):
52 | def __init__(self, input_placeholder=None):
53 | super().__init__()
54 |
55 | def forward(self, x):
56 | return x
57 |
58 | class RLEncoder(nn.Module):
59 | def __init__(self, obs_shape, model_name, device):
60 | super().__init__()
61 | # a wrapper over a non-RL encoder model
62 | self.device = device
63 | assert len(obs_shape) == 3
64 | self.n_input_channel = obs_shape[0]
65 | assert self.n_input_channel % 3 == 0
66 | self.n_images = self.n_input_channel // 3
67 | self.model = self.init_model(model_name)
68 | self.model.fc = Identity()
69 | self.repr_dim = self.model.get_feature_size()
70 |
71 | self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406),
72 | (0.229, 0.224, 0.225))
73 | self.channel_mismatch = True
74 |
75 | def init_model(self, model_name):
76 | # model name is e.g. resnet6_32channel
77 | n_layer_string, n_channel_string = model_name.split('_')
78 | layer_string_to_layer_list = {
79 | 'resnet6': [0, 0, 0, 0],
80 | 'resnet10': [1, 1, 1, 1],
81 | 'resnet18': [2, 2, 2, 2],
82 | }
83 | channel_string_to_n_channel = {
84 | '32channel': 32,
85 | '64channel': 64,
86 | }
87 | layer_list = layer_string_to_layer_list[n_layer_string]
88 | start_num_channel = channel_string_to_n_channel[n_channel_string]
89 | return ResNet84(BasicBlock, layer_list, start_num_channel=start_num_channel).to(self.device)
90 |
91 | def expand_first_layer(self):
92 | # convolutional channel expansion to deal with input mismatch
93 | multiplier = self.n_images
94 | self.model.conv1.weight.data = self.model.conv1.weight.data.repeat(1, multiplier, 1, 1) / multiplier
95 | means = (0.485, 0.456, 0.406) * multiplier
96 | stds = (0.229, 0.224, 0.225) * multiplier
97 | self.normalize_op = transforms.Normalize(means, stds)
98 | self.channel_mismatch = False
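# note: repeating the pretrained conv1 weights across the input-channel axis
# and dividing by the multiplier means a stack of n identical RGB frames yields
# the same conv1 output as a single frame did before expansion, e.g.
#   encoder.expand_first_layer()  # conv1 now takes 3 * n_images input channels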
99 |
100 | def freeze_bn(self):
101 | # freeze batch norm layers (VRL3 ablation shows modifying how
102 | # batch norm is trained does not affect performance)
103 | for module in self.model.modules():
104 | if isinstance(module, nn.BatchNorm2d):
105 | if hasattr(module, 'weight'):
106 | module.weight.requires_grad_(False)
107 | if hasattr(module, 'bias'):
108 | module.bias.requires_grad_(False)
109 | module.eval()
110 |
111 | def get_parameters_that_require_grad(self):
112 | params = []
113 | for name, param in self.named_parameters():
114 | if param.requires_grad:
115 | params.append(param)
116 | return params
117 |
118 | def transform_obs_tensor_batch(self, obs):
119 | # transform obs batch before putting it into the pretrained resnet
120 | new_obs = self.normalize_op(obs.float()/255)
121 | return new_obs
122 |
123 | def _forward_impl(self, x):
124 | x = self.model.get_features(x)
125 | return x
126 |
127 | def forward(self, obs):
128 | o = self.transform_obs_tensor_batch(obs)
129 | h = self._forward_impl(o)
130 | return h
131 |
132 | class Stage3ShallowEncoder(nn.Module):
133 | def __init__(self, obs_shape, n_channel):
134 | super().__init__()
135 |
136 | assert len(obs_shape) == 3
137 | self.repr_dim = n_channel * 35 * 35
138 |
139 | self.n_input_channel = obs_shape[0]
140 | self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
141 | self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
142 | self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
143 | self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
144 | self.relu = nn.ReLU(inplace=True)
145 |
146 | # TODO: add a prediction head here so we can do contrastive learning...
147 |
148 | self.apply(utils.weight_init)
149 | self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406, 0.485, 0.456, 0.406, 0.485, 0.456, 0.406),
150 | (0.229, 0.224, 0.225, 0.229, 0.224, 0.225, 0.229, 0.224, 0.225))
151 |
152 | self.compress = nn.Sequential(nn.Linear(self.repr_dim, 50), nn.LayerNorm(50), nn.Tanh())
153 | self.pred_layer = nn.Linear(50, 50, bias=False)
154 |
155 | def transform_obs_tensor_batch(self, obs):
156 | # transform obs batch before feeding it into the encoder
157 | # correct order might be first augment, then resize, then normalize
158 | # obs = F.interpolate(obs, size=self.pretrained_model_input_size)
159 | new_obs = obs / 255.0 - 0.5
160 | # new_obs = self.normalize_op(new_obs)
161 | return new_obs
162 |
163 | def _forward_impl(self, x):
164 | x = self.relu(self.conv1(x))
165 | x = self.relu(self.conv2(x))
166 | x = self.relu(self.conv3(x))
167 | x = self.relu(self.conv4(x))
168 | return x
169 |
170 | def forward(self, obs):
171 | o = self.transform_obs_tensor_batch(obs)
172 | h = self._forward_impl(o)
173 | h = h.view(h.shape[0], -1)
174 | return h
175 |
176 | def get_anchor_output(self, obs, actions=None):
177 | # goes through the conv layers, then the compression layer, then a linear prediction head
178 | # used for the UL (unsupervised) update
179 | conv_out = self.forward(obs)
180 | compressed = self.compress(conv_out)
181 | pred = self.pred_layer(compressed)
182 | return pred, conv_out
183 |
184 | def get_positive_output(self, obs):
185 | # goes through the conv and compression layers only
186 | # used for the UL update
187 | conv_out = self.forward(obs)
188 | compressed = self.compress(conv_out)
189 | return compressed
190 |
191 | class Encoder(nn.Module):
192 | def __init__(self, obs_shape, n_channel):
193 | super().__init__()
194 |
195 | assert len(obs_shape) == 3
196 | self.repr_dim = n_channel * 35 * 35
197 |
198 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], n_channel, 3, stride=2),
199 | nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
200 | nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
201 | nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
202 | nn.ReLU())
203 |
204 | self.apply(utils.weight_init)
205 |
206 | def forward(self, obs):
207 | obs = obs / 255.0 - 0.5
208 | h = self.convnet(obs)
209 | h = h.view(h.shape[0], -1)
210 | return h
211 |
212 | class IdentityEncoder(nn.Module):
213 | def __init__(self, obs_shape):
214 | super().__init__()
215 |
216 | assert len(obs_shape) == 1
217 | self.repr_dim = obs_shape[0]
218 |
219 | def forward(self, obs):
220 | return obs
221 |
222 | class Actor(nn.Module):
223 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
224 | super().__init__()
225 |
226 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
227 | nn.LayerNorm(feature_dim), nn.Tanh())
228 |
229 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
230 | nn.ReLU(inplace=True),
231 | nn.Linear(hidden_dim, hidden_dim),
232 | nn.ReLU(inplace=True),
233 | nn.Linear(hidden_dim, action_shape[0]))
234 |
235 | self.action_shift = 0
236 | self.action_scale = 1
237 | self.apply(utils.weight_init)
238 |
239 | def forward(self, obs, std):
240 | h = self.trunk(obs)
241 |
242 | mu = self.policy(h)
243 | mu = torch.tanh(mu)
244 | mu = mu * self.action_scale + self.action_shift
245 | std = torch.ones_like(mu) * std
246 |
247 | dist = utils.TruncatedNormal(mu, std)
248 | return dist
249 |
250 | def forward_with_pretanh(self, obs, std):
251 | h = self.trunk(obs)
252 |
253 | mu = self.policy(h)
254 | pretanh = mu
255 | mu = torch.tanh(mu)
256 | mu = mu * self.action_scale + self.action_shift
257 | std = torch.ones_like(mu) * std
258 |
259 | dist = utils.TruncatedNormal(mu, std)
260 | return dist, pretanh
261 |
262 | class Critic(nn.Module):
263 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
264 | super().__init__()
265 |
266 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
267 | nn.LayerNorm(feature_dim), nn.Tanh())
268 |
269 | self.Q1 = nn.Sequential(
270 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
271 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
272 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
273 |
274 | self.Q2 = nn.Sequential(
275 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
276 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
277 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
278 |
279 | self.apply(utils.weight_init)
280 |
281 | def forward(self, obs, action):
282 | h = self.trunk(obs)
283 | h_action = torch.cat([h, action], dim=-1)
284 | q1 = self.Q1(h_action)
285 | q2 = self.Q2(h_action)
286 |
287 | return q1, q2
288 |
289 | class VRL3Agent:
290 | def __init__(self, obs_shape, action_shape, device, use_sensor, lr, feature_dim,
291 | hidden_dim, critic_target_tau, num_expl_steps,
292 | update_every_steps, stddev_clip, use_tb, use_data_aug, encoder_lr_scale,
293 | stage1_model_name, safe_q_target_factor, safe_q_threshold, pretanh_penalty, pretanh_threshold,
294 | stage2_update_encoder, cql_weight, cql_temp, cql_n_random, stage2_std, stage2_bc_weight,
295 | stage3_update_encoder, std0, std1, std_n_decay,
296 | stage3_bc_lam0, stage3_bc_lam1):
297 | self.device = device
298 | self.critic_target_tau = critic_target_tau
299 | self.update_every_steps = update_every_steps
300 | self.use_tb = use_tb
301 | self.num_expl_steps = num_expl_steps
302 |
303 | self.stage2_std = stage2_std
304 | self.stage2_update_encoder = stage2_update_encoder
305 |
306 | if std1 > std0:
307 | std1 = std0
308 | self.stddev_schedule = "linear(%s,%s,%s)" % (str(std0), str(std1), str(std_n_decay))
309 |
310 | self.stddev_clip = stddev_clip
311 | self.use_data_aug = use_data_aug
312 | self.safe_q_target_factor = safe_q_target_factor
313 | self.q_threshold = safe_q_threshold
314 | self.pretanh_penalty = pretanh_penalty
315 |
316 | self.cql_temp = cql_temp
317 | self.cql_weight = cql_weight
318 | self.cql_n_random = cql_n_random
319 |
320 | self.pretanh_threshold = pretanh_threshold
321 |
322 | self.stage2_bc_weight = stage2_bc_weight
323 | self.stage3_bc_lam0 = stage3_bc_lam0
324 | self.stage3_bc_lam1 = stage3_bc_lam1
325 |
326 | if stage3_update_encoder and encoder_lr_scale > 0 and len(obs_shape) > 1:
327 | self.stage3_update_encoder = True
328 | else:
329 | self.stage3_update_encoder = False
330 |
331 | self.encoder = RLEncoder(obs_shape, stage1_model_name, device).to(device)
332 |
333 | self.act_dim = action_shape[0]
334 |
335 | if use_sensor:
336 | downstream_input_dim = self.encoder.repr_dim + 24
337 | else:
338 | downstream_input_dim = self.encoder.repr_dim
339 |
340 | self.actor = Actor(downstream_input_dim, action_shape, feature_dim,
341 | hidden_dim).to(device)
342 | self.critic = Critic(downstream_input_dim, action_shape, feature_dim,
343 | hidden_dim).to(device)
344 | self.critic_target = Critic(downstream_input_dim, action_shape,
345 | feature_dim, hidden_dim).to(device)
346 | self.critic_target.load_state_dict(self.critic.state_dict())
347 |
348 | # optimizers
349 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
350 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
351 |
352 | encoder_lr = lr * encoder_lr_scale
353 | """ set up encoder optimizer """
354 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=encoder_lr)
355 | # data augmentation
356 | self.aug = RandomShiftsAug(pad=4)
357 | self.train()
358 | self.critic_target.train()
359 |
360 | def load_pretrained_encoder(self, model_path, verbose=True):
361 | if verbose:
362 | print("Trying to load pretrained model from:", model_path)
363 | checkpoint = torch.load(model_path, map_location=torch.device(self.device))
364 | state_dict = checkpoint['state_dict']
365 |
366 | pretrained_dict = {}
367 | # remove `module.` if model was pretrained with distributed mode
368 | for k, v in state_dict.items():
369 | if 'module.' in k:
370 | name = k[7:]
371 | else:
372 | name = k
373 | pretrained_dict[name] = v
374 | self.encoder.model.load_state_dict(pretrained_dict, strict=False)
375 | if verbose:
376 | print("Pretrained model loaded!")
377 |
378 | def switch_to_RL_stages(self, verbose=True):
379 | # run convolutional channel expansion to match input shape
380 | self.encoder.expand_first_layer()
381 | if verbose:
382 | print("Convolutional channel expansion finished: now can take in %d images as input." % self.encoder.n_images)
383 |
384 | def train(self, training=True):
385 | self.training = training
386 | self.encoder.train(training)
387 | self.actor.train(training)
388 | self.critic.train(training)
389 |
390 | def act(self, obs, step, eval_mode, obs_sensor=None, is_tensor_input=False, force_action_std=None):
391 | # eval_mode should be False when taking an exploration action in stage 3
392 | # eval_mode should be True when evaluating agent performance
393 | if force_action_std is None:
394 | stddev = utils.schedule(self.stddev_schedule, step)
395 | if step < self.num_expl_steps and not eval_mode:
396 | action = np.random.uniform(0, 1, (self.act_dim,)).astype(np.float32)
397 | return action
398 | else:
399 | stddev = force_action_std
400 |
401 | if is_tensor_input:
402 | obs = self.encoder(obs)
403 | else:
404 | obs = torch.as_tensor(obs, device=self.device)
405 | obs = self.encoder(obs.unsqueeze(0))
406 |
407 | if obs_sensor is not None:
408 | obs_sensor = torch.as_tensor(obs_sensor, device=self.device)
409 | obs_sensor = obs_sensor.unsqueeze(0)
410 | obs_combined = torch.cat([obs, obs_sensor], dim=1)
411 | else:
412 | obs_combined = obs
413 |
414 | dist = self.actor(obs_combined, stddev)
415 | if eval_mode:
416 | action = dist.mean
417 | else:
418 | action = dist.sample(clip=None)
419 | if step < self.num_expl_steps:
420 | action.uniform_(-1.0, 1.0)
421 | return action.cpu().numpy()[0]
422 |
423 | def update(self, replay_iter, step, stage, use_sensor):
424 | # for stages 2 and 3, we use the same function but with different hyperparameters
425 | assert stage in (2, 3)
426 | metrics = dict()
427 |
428 | if stage == 2:
429 | update_encoder = self.stage2_update_encoder
430 | stddev = self.stage2_std
431 | conservative_loss_weight = self.cql_weight
432 | bc_weight = self.stage2_bc_weight
433 |
434 | if stage == 3:
435 | if step % self.update_every_steps != 0:
436 | return metrics
437 | update_encoder = self.stage3_update_encoder
438 |
439 | stddev = utils.schedule(self.stddev_schedule, step)
440 | conservative_loss_weight = 0
441 |
442 | # compute stage 3 BC weight
443 | bc_data_per_iter = 40000
444 | i_iter = step // bc_data_per_iter
445 | bc_weight = self.stage3_bc_lam0 * self.stage3_bc_lam1 ** i_iter
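# illustrative decay: with stage3_bc_lam0=0.1 and stage3_bc_lam1=0.9
# (hypothetical values), bc_weight goes 0.1, 0.09, 0.081, ... every 40000 steps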
446 |
447 | # batch data
448 | batch = next(replay_iter)
449 | if use_sensor: # TODO might want to...?
450 | obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next = utils.to_torch(batch, self.device)
451 | else:
452 | obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)
453 | obs_sensor, obs_sensor_next = None, None
454 |
455 | # augment
456 | if self.use_data_aug:
457 | obs = self.aug(obs.float())
458 | next_obs = self.aug(next_obs.float())
459 | else:
460 | obs = obs.float()
461 | next_obs = next_obs.float()
462 |
463 | # encode
464 | if update_encoder:
465 | obs = self.encoder(obs)
466 | else:
467 | with torch.no_grad():
468 | obs = self.encoder(obs)
469 |
470 | with torch.no_grad():
471 | next_obs = self.encoder(next_obs)
472 |
473 | # concatenate obs with additional sensor observation if needed
474 | obs_combined = torch.cat([obs, obs_sensor], dim=1) if obs_sensor is not None else obs
475 | obs_next_combined = torch.cat([next_obs, obs_sensor_next], dim=1) if obs_sensor_next is not None else next_obs
476 |
477 | # update critic
478 | metrics.update(self.update_critic_vrl3(obs_combined, action, reward, discount, obs_next_combined,
479 | stddev, update_encoder, conservative_loss_weight))
480 |
481 | # update actor; following previous work, we do not use the actor gradient to update the encoder
482 | metrics.update(self.update_actor_vrl3(obs_combined.detach(), action, stddev, bc_weight,
483 | self.pretanh_penalty, self.pretanh_threshold))
484 |
485 | metrics['batch_reward'] = reward.mean().item()
486 |
487 | # update critic target networks
488 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
489 | return metrics
490 |
491 | def update_critic_vrl3(self, obs, action, reward, discount, next_obs, stddev, update_encoder, conservative_loss_weight):
492 | metrics = dict()
493 | batch_size = obs.shape[0]
494 |
495 | """
496 | STANDARD Q LOSS COMPUTATION:
497 | - get the standard Q loss first; this is the same as in other online RL methods
498 | - except for the safe Q technique, which controls how large the Q value can be
499 | """
500 | with torch.no_grad():
501 | dist = self.actor(next_obs, stddev)
502 | next_action = dist.sample(clip=self.stddev_clip)
503 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
504 | target_V = torch.min(target_Q1, target_Q2)
505 | target_Q = reward + (discount * target_V)
506 |
507 | if self.safe_q_target_factor < 1:
508 | target_Q[target_Q > (self.q_threshold + 1)] = self.q_threshold + (target_Q[target_Q > (self.q_threshold+1)] - self.q_threshold) ** self.safe_q_target_factor
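# illustrative: with safe_q_threshold=200 and safe_q_target_factor=0.5
# (hypothetical values), a raw target of 300 is squashed to 200 + 100**0.5 = 210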
509 |
510 | Q1, Q2 = self.critic(obs, action)
511 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
512 |
513 | """
514 | CONSERVATIVE Q LOSS COMPUTATION:
515 | - sample random actions, actions from policy and next actions from policy, as done in CQL authors' code
516 | (though this detail is not really discussed in the CQL paper)
517 | - only compute this loss when conservative loss weight > 0
518 | """
519 | if conservative_loss_weight > 0:
520 | random_actions = (torch.rand((batch_size * self.cql_n_random, self.act_dim), device=self.device) - 0.5) * 2
521 |
522 | dist = self.actor(obs, stddev)
523 | current_actions = dist.sample(clip=self.stddev_clip)
524 |
525 | dist = self.actor(next_obs, stddev)
526 | next_current_actions = dist.sample(clip=self.stddev_clip)
527 |
528 | # now get Q values for all these actions (for both Q networks)
529 | obs_repeat = obs.unsqueeze(1).repeat(1, self.cql_n_random, 1).view(obs.shape[0] * self.cql_n_random,
530 | obs.shape[1])
531 |
532 | Q1_rand, Q2_rand = self.critic(obs_repeat,
533 | random_actions)  # TODO: might want to double-check that the repeat logic here is correct
534 | Q1_rand = Q1_rand.view(obs.shape[0], self.cql_n_random)
535 | Q2_rand = Q2_rand.view(obs.shape[0], self.cql_n_random)
536 |
537 | Q1_curr, Q2_curr = self.critic(obs, current_actions)
538 | Q1_curr_next, Q2_curr_next = self.critic(obs, next_current_actions)
539 |
540 | # now concat all these Q values together
541 | Q1_cat = torch.cat([Q1_rand, Q1, Q1_curr, Q1_curr_next], 1)
542 | Q2_cat = torch.cat([Q2_rand, Q2, Q2_curr, Q2_curr_next], 1)
543 |
544 | cql_min_q1_loss = torch.logsumexp(Q1_cat / self.cql_temp,
545 | dim=1, ).mean() * conservative_loss_weight * self.cql_temp
546 | cql_min_q2_loss = torch.logsumexp(Q2_cat / self.cql_temp,
547 | dim=1, ).mean() * conservative_loss_weight * self.cql_temp
548 |
549 | """Subtract the log likelihood of data"""
550 | conservative_q_loss = cql_min_q1_loss + cql_min_q2_loss - (Q1.mean() + Q2.mean()) * conservative_loss_weight
551 | critic_loss_combined = critic_loss + conservative_q_loss
552 | else:
553 | critic_loss_combined = critic_loss
554 |
555 | # logging
556 | metrics['critic_target_q'] = target_Q.mean().item()
557 | metrics['critic_q1'] = Q1.mean().item()
558 | metrics['critic_q2'] = Q2.mean().item()
559 | metrics['critic_loss'] = critic_loss.item()
560 |
561 | # if needed, also update encoder with critic loss
562 | if update_encoder:
563 | self.encoder_opt.zero_grad(set_to_none=True)
564 | self.critic_opt.zero_grad(set_to_none=True)
565 | critic_loss_combined.backward()
566 | self.critic_opt.step()
567 | if update_encoder:
568 | self.encoder_opt.step()
569 |
570 | return metrics
571 |
572 | def update_actor_vrl3(self, obs, action, stddev, bc_weight, pretanh_penalty, pretanh_threshold):
573 | metrics = dict()
574 |
575 | """
576 | get standard actor loss
577 | """
578 | dist, pretanh = self.actor.forward_with_pretanh(obs, stddev)
579 | current_action = dist.sample(clip=self.stddev_clip)
580 | log_prob = dist.log_prob(current_action).sum(-1, keepdim=True)
581 | Q1, Q2 = self.critic(obs, current_action)
582 | Q = torch.min(Q1, Q2)
583 | actor_loss = -Q.mean()
584 |
585 | """
586 | add BC loss
587 | """
588 | if bc_weight > 0:
589 | # get mean action with no action noise (though this might not be necessary)
590 | stddev_bc = 0
591 | dist_bc = self.actor(obs, stddev_bc)
592 | current_mean_action = dist_bc.sample(clip=self.stddev_clip)
593 | actor_loss_bc = F.mse_loss(current_mean_action, action) * bc_weight
594 | else:
595 | actor_loss_bc = torch.zeros(1, device=self.device)
596 |
597 | """
598 | add pretanh penalty (might not be necessary for Adroit)
599 | """
600 | pretanh_loss = 0
601 | if pretanh_penalty > 0:
602 | pretanh_loss = pretanh.abs() - pretanh_threshold
603 | pretanh_loss[pretanh_loss < 0] = 0
604 | pretanh_loss = (pretanh_loss ** 2).mean() * pretanh_penalty
605 |
606 | """
607 | combine actor losses and optimize
608 | """
609 | actor_loss_combined = actor_loss + actor_loss_bc + pretanh_loss
610 |
611 | self.actor_opt.zero_grad(set_to_none=True)
612 | actor_loss_combined.backward()
613 | self.actor_opt.step()
614 |
615 | metrics['actor_loss'] = actor_loss.item()
616 | metrics['actor_loss_bc'] = actor_loss_bc.item()
617 | metrics['actor_logprob'] = log_prob.mean().item()
618 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
619 | metrics['abs_pretanh'] = pretanh.abs().mean().item()
620 | metrics['max_abs_pretanh'] = pretanh.abs().max().item()
621 |
622 | return metrics
623 |
--------------------------------------------------------------------------------