├── .gitignore
├── LICENSE
├── README.md
├── maple.yml
├── maple
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── batch_rl_algorithm.py
│   │   ├── eval_util.py
│   │   ├── logging.py
│   │   ├── loss.py
│   │   ├── online_rl_algorithm.py
│   │   ├── rl_algorithm.py
│   │   ├── serializable.py
│   │   ├── tabulate.py
│   │   └── trainer.py
│   ├── data_management
│   │   ├── __init__.py
│   │   ├── env_replay_buffer.py
│   │   ├── normalizer.py
│   │   ├── obs_dict_replay_buffer.py
│   │   ├── online_vae_replay_buffer.py
│   │   ├── path_builder.py
│   │   ├── replay_buffer.py
│   │   ├── shared_obs_dict_replay_buffer.py
│   │   ├── simple_replay_buffer.py
│   │   └── split_buffer.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── env_utils.py
│   │   ├── make_env.py
│   │   ├── mujoco_env.py
│   │   ├── mujoco_image_env.py
│   │   ├── proxy_env.py
│   │   ├── vae_wrapper.py
│   │   ├── wrappers.py
│   │   └── wrappers
│   │       ├── __init__.py
│   │       ├── discretize_env.py
│   │       ├── history_env.py
│   │       ├── image_mujoco_env.py
│   │       ├── image_mujoco_env_with_obs.py
│   │       ├── normalized_box_env.py
│   │       ├── reward_wrapper_env.py
│   │       └── stack_observation_env.py
│   ├── exploration_strategies
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── epsilon_greedy.py
│   │   ├── gaussian_and_epsilon_strategy.py
│   │   ├── gaussian_strategy.py
│   │   └── ou_strategy.py
│   ├── launchers
│   │   ├── __init__.py
│   │   ├── conf.py
│   │   ├── launcher_util.py
│   │   ├── robosuite_launcher.py
│   │   └── visualization.py
│   ├── policies
│   │   ├── __init__.py
│   │   ├── argmax.py
│   │   ├── base.py
│   │   └── simple.py
│   ├── pythonplusplus.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── data_collector
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── contextual_path_collector.py
│   │   │   ├── joint_path_collector.py
│   │   │   ├── path_collector.py
│   │   │   ├── step_collector.py
│   │   │   └── vae_env.py
│   │   ├── in_place.py
│   │   ├── rollout_functions.py
│   │   └── util.py
│   ├── torch
│   │   ├── __init__.py
│   │   ├── conv_networks.py
│   │   ├── core.py
│   │   ├── data.py
│   │   ├── data_management
│   │   │   ├── __init__.py
│   │   │   └── normalizer.py
│   │   ├── distributions.py
│   │   ├── lvm
│   │   │   ├── __init__.py
│   │   │   ├── bear_vae.py
│   │   │   └── latent_variable_model.py
│   │   ├── modules.py
│   │   ├── networks
│   │   │   ├── __init__.py
│   │   │   ├── basic.py
│   │   │   ├── cnn.py
│   │   │   ├── custom.py
│   │   │   ├── dcnn.py
│   │   │   ├── feat_point_mlp.py
│   │   │   ├── image_state.py
│   │   │   ├── linear_transform.py
│   │   │   ├── mlp.py
│   │   │   ├── normalization.py
│   │   │   ├── pretrained_cnn.py
│   │   │   ├── stochastic
│   │   │   │   └── distribution_generator.py
│   │   │   └── two_headed_mlp.py
│   │   ├── pytorch_util.py
│   │   ├── sac
│   │   │   ├── __init__.py
│   │   │   ├── policies
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   ├── gaussian_policy.py
│   │   │   │   ├── lvm_policy.py
│   │   │   │   ├── pamdp_policy.py
│   │   │   │   └── policy_from_q.py
│   │   │   ├── sac.py
│   │   │   └── sac_hybrid.py
│   │   └── torch_rl_algorithm.py
│   └── util
│       ├── data_processing.py
│       ├── hyperparameter.py
│       ├── io.py
│       ├── ml_util.py
│       └── slurm_util.py
├── scripts
│   ├── eval.py
│   ├── run_experiment_from_doodad.py
│   └── train.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | mjkey.txt
2 | **/.DS_STORE
3 | **/*.pyc
4 | **/*.swp
5 | **/*.pdf
6 | maple/launchers/conf_private.py
7 | MANIFEST
8 | *.egg-info
9 | \.idea/
10 | data
11 | MUJOCO_LOG.TXT
12 | **/sbatch/*
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 UT Robot Perception and Learning Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MAPLE: Augmenting Reinforcement Learning with Behavior Primitives for Diverse Manipulation Tasks
2 |
3 | This is the official codebase for **Ma**nipulation **P**rimitive-augmented reinforcement **Le**arning (MAPLE), from the following paper:
4 |
5 | **Augmenting Reinforcement Learning with Behavior Primitives for Diverse Manipulation Tasks**
6 | [Soroush Nasiriany](http://snasiriany.me/), [Huihan Liu](https://huihanl.github.io/), [Yuke Zhu](https://www.cs.utexas.edu/~yukez/)
7 | [UT Austin Robot Perception and Learning Lab](https://rpl.cs.utexas.edu/)
8 | IEEE International Conference on Robotics and Automation (ICRA), 2022
9 | **[[Paper]](https://arxiv.org/abs/2110.03655)** **[[Project Website]](https://ut-austin-rpl.github.io/maple/)**
10 |
11 |
12 |
13 |
14 | This guide contains information about (1) [Installation](#installation), (2) [Running Experiments](#running-experiments), (3) [Setting Up Your Own Environments](#setting-up-your-own-environments), (4) [Acknowledgement](#acknowledgement), and (5) [Citation](#citation).
15 |
16 | ## Installation
17 | ### Download code
18 | - Current codebase: ```git clone https://github.com/UT-Austin-RPL/maple```
19 | - (for environments) the `maple` branch in robosuite: ```git clone -b maple https://github.com/ARISE-Initiative/robosuite```
20 |
21 | ### Setup robosuite
22 | 1. Download MuJoCo 2.0 (Linux and Mac OS X) and unzip its contents into `~/.mujoco/mujoco200`, and copy your MuJoCo license key to `~/.mujoco/mjkey.txt`. You can obtain a license key from [here](https://www.roboti.us/license.html).
23 | 2. (Linux) Set up additional dependencies: ```sudo apt install libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev software-properties-common net-tools xpra xserver-xorg-dev libglfw3-dev patchelf```
24 | 3. Add MuJoCo to library paths: `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin`
25 |
26 | ### Setup conda environment
27 | 1. Create the conda environment: `conda env create --name maple --file=maple.yml`
28 | 2. (if above fails) edit `maple.yml` to modify dependencies and then resume setup: `conda env update --name maple --file=maple.yml`
29 | 3. Activate the conda environment: `conda activate maple`
30 | 4. Finish maple setup: run `pip install -e .` from your maple repo path
31 | 5. Finish robosuite setup: run `pip install -e .` from your robosuite repo path
32 |
33 | ## Running Experiments
34 | Scripts for training policies and replaying policy checkpoints are located in `scripts/train.py` and `scripts/eval.py`, respectively.
35 |
36 | These experiment scripts use the following structure:
37 | ```
38 | base_variant = dict(
39 | # default hyperparam settings for all envs
40 | )
41 |
42 | env_params = {
43 | '<env_name_1>' : {
44 | # add/override default hyperparam settings for specific env
45 | # each setting is specified as a dictionary address (key),
46 | # followed by list of possible options (value).
47 | # Example in following line:
48 | # 'env_variant.controller_type': ['OSC_POSITION'],
49 | },
50 | '<env_name_2>' : {
51 | ...
52 | },
53 | }
54 | ```
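
For reference, a concrete (hypothetical) entry might look like the following; the `env_variant.controller_type` address comes from the comment above, while the other key names are purely illustrative and should be checked against the hyperparameters defined in `scripts/train.py`:

```python
env_params = {
    'stack': {
        # sweep two random seeds for this env (illustrative key name)
        'seed': [1, 2],
        # override the default controller for this env
        'env_variant.controller_type': ['OSC_POSITION'],
    },
}
```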
55 |
56 | ### Command Line Options
57 | See `parser` in `scripts/train.py` for a complete list of options. Some notable options:
58 | - `env`: the env to run (e.g. `stack`)
59 | - `label`: name for experiment
60 | - `debug`: run with lite options for debugging
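
For instance, a quick debug run on the `stack` env might look roughly like `python scripts/train.py --env stack --label my_first_run --debug`; the exact flag spellings are defined by the `parser` in `scripts/train.py`, so double-check them there.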
61 |
62 | ### Plotting Experiment Results
63 | During training, the results will be saved to a folder under `LOCAL_LOG_DIR/<exp_prefix>/<exp_name>/`.
64 | Inside this folder, the experiment results are stored in `progress.csv`. We recommend using [viskit](https://github.com/vitchyr/viskit) to plot the results.
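With viskit cloned locally, an invocation along the lines of `python viskit/frontend.py LOCAL_LOG_DIR/<exp_prefix>/` should bring up the plotting frontend; see the viskit README for the exact usage.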
65 |
66 | ## Setting Up Your Own Environments
67 | Note that this codebase is designed to work with robosuite environments only. For setting up your own environments, please follow [these examples](https://github.com/ARISE-Initiative/robosuite/tree/maple/robosuite/environments/manipulation) for reference. Notably, you will need to add the `skill_config` variable to the constructor, and define the keypoints for the affordance score by implementing the `_get_skill_info` function.
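
As a rough, hypothetical sketch (the base class, constructor signature, and the exact dictionary returned by `_get_skill_info` are assumptions here, not the actual robosuite `maple` branch API; treat the linked manipulation environments as the authoritative reference):

```python
from robosuite.environments.manipulation.single_arm_env import SingleArmEnv


class MyCustomTask(SingleArmEnv):
    """Hypothetical task environment, not part of the robosuite `maple` branch."""

    def __init__(self, robots, skill_config=None, **kwargs):
        # Forward skill_config to the parent constructor so the
        # behavior-primitive ("skill") controller can be configured.
        super().__init__(robots=robots, skill_config=skill_config, **kwargs)

    def _get_skill_info(self):
        # Return keypoints used to compute the affordance score.
        # The key names and structure below are illustrative only.
        obj_pos = self.sim.data.body_xpos[self.obj_body_id].copy()
        return {
            'grasp': [obj_pos],  # locations where grasping is meaningful
            'push': [obj_pos],   # locations where pushing is meaningful
            'reach': [obj_pos],  # locations where reaching is meaningful
        }
```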
68 |
69 | If you would like to know the inner workings of the primitives, refer to [`skill_controller.py`](https://github.com/ARISE-Initiative/robosuite/tree/maple/robosuite/controllers/skill_controller.py) and [`skills.py`](https://github.com/ARISE-Initiative/robosuite/tree/maple/robosuite/controllers/skills.py). Note that we use the term "skill" to refer to behavior primitives in the code.
70 |
71 | ## Acknowledgement
72 | Much of this codebase is directly based on [RLkit](https://github.com/vitchyr/rlkit), which itself is based on [rllab](https://github.com/rll/rllab).
73 | In addition, the environments were developed as a forked branch of [robosuite](https://github.com/ARISE-Initiative/robosuite) `v1.1.0`.
74 |
75 | ## Citation
76 | ```bibtex
77 | @inproceedings{nasiriany2022maple,
78 | title={Augmenting Reinforcement Learning with Behavior Primitives for Diverse Manipulation Tasks},
79 | author={Soroush Nasiriany and Huihan Liu and Yuke Zhu},
80 | booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
81 | year={2022}
82 | }
83 | ```
84 |
--------------------------------------------------------------------------------
/maple.yml:
--------------------------------------------------------------------------------
1 | name: maple
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - argon2-cffi=20.1.0=py38h27cfd23_1
7 | - async_generator=1.10=pyhd3eb1b0_0
8 | - backcall=0.2.0=pyhd3eb1b0_0
9 | - bleach=3.3.0=pyhd3eb1b0_0
10 | - ca-certificates=2021.1.19=h06a4308_1
11 | - certifi=2020.12.5=py38h06a4308_0
12 | - dbus=1.13.18=hb2f20db_0
13 | - decorator=4.4.2=pyhd3eb1b0_0
14 | - defusedxml=0.7.1=pyhd3eb1b0_0
15 | - entrypoints=0.3=py38_0
16 | - expat=2.3.0=h2531618_2
17 | - fontconfig=2.13.1=h6c09931_0
18 | - freetype=2.10.4=h5ab3b9f_0
19 | - glib=2.68.0=h36276a3_0
20 | - gst-plugins-base=1.14.0=h8213a91_2
21 | - gstreamer=1.14.0=h28cd5cc_2
22 | - icu=58.2=he6710b0_3
23 | - importlib-metadata=3.7.3=py38h06a4308_1
24 | - importlib_metadata=3.7.3=hd3eb1b0_1
25 | - ipykernel=5.3.4=py38h5ca1d4c_0
26 | - ipython=7.22.0=py38hb070fc8_0
27 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
28 | - ipywidgets=7.6.3=pyhd3eb1b0_1
29 | - jedi=0.17.0=py38_0
30 | - jpeg=9b=h024ee3a_2
31 | - jsonschema=3.2.0=py_2
32 | - jupyter=1.0.0=py38_7
33 | - jupyter_client=6.1.12=pyhd3eb1b0_0
34 | - jupyter_console=6.4.0=pyhd3eb1b0_0
35 | - jupyter_core=4.7.1=py38h06a4308_0
36 | - jupyterlab_pygments=0.1.2=py_0
37 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
38 | - ld_impl_linux-64=2.33.1=h53a641e_7
39 | - libedit=3.1.20191231=h14c3975_1
40 | - libffi=3.3=he6710b0_2
41 | - libgcc-ng=9.1.0=hdf63c60_0
42 | - libpng=1.6.37=hbc83047_0
43 | - libsodium=1.0.18=h7b6447c_0
44 | - libstdcxx-ng=9.1.0=hdf63c60_0
45 | - libuuid=1.0.3=h1bed415_2
46 | - libxcb=1.14=h7b6447c_0
47 | - libxml2=2.9.10=hb55368b_3
48 | - markupsafe=1.1.1=py38h7b6447c_0
49 | - mistune=0.8.4=py38h7b6447c_1000
50 | - nbclient=0.5.3=pyhd3eb1b0_0
51 | - nbconvert=6.0.7=py38_0
52 | - ncurses=6.2=he6710b0_1
53 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
54 | - notebook=6.3.0=py38h06a4308_0
55 | - openssl=1.1.1k=h27cfd23_0
56 | - packaging=20.9=pyhd3eb1b0_0
57 | - pandoc=2.12=h06a4308_0
58 | - pandocfilters=1.4.3=py38h06a4308_1
59 | - parso=0.8.1=pyhd3eb1b0_0
60 | - pcre=8.44=he6710b0_0
61 | - pexpect=4.8.0=pyhd3eb1b0_3
62 | - pickleshare=0.7.5=pyhd3eb1b0_1003
63 | - pip=20.2.2=py38_0
64 | - prometheus_client=0.9.0=pyhd3eb1b0_0
65 | - prompt-toolkit=3.0.17=pyh06a4308_0
66 | - prompt_toolkit=3.0.17=hd3eb1b0_0
67 | - ptyprocess=0.7.0=pyhd3eb1b0_2
68 | - pycparser=2.20=py_2
69 | - pygments=2.8.1=pyhd3eb1b0_0
70 | - pyparsing=2.4.7=pyhd3eb1b0_0
71 | - pyqt=5.9.2=py38h05f1152_4
72 | - pyrsistent=0.17.3=py38h7b6447c_0
73 | - python=3.8.5=h7579374_1
74 | - python-dateutil=2.8.1=pyhd3eb1b0_0
75 | - pyzmq=20.0.0=py38h2531618_1
76 | - qt=5.9.7=h5867ecd_1
77 | - qtconsole=5.0.3=pyhd3eb1b0_0
78 | - qtpy=1.9.0=py_0
79 | - readline=8.0=h7b6447c_0
80 | - send2trash=1.5.0=pyhd3eb1b0_1
81 | - setuptools=49.6.0=py38_0
82 | - sip=4.19.13=py38he6710b0_0
83 | - six=1.15.0=py38h06a4308_0
84 | - sqlite=3.33.0=h62c20be_0
85 | - terminado=0.9.4=py38h06a4308_0
86 | - testpath=0.4.4=pyhd3eb1b0_0
87 | - tk=8.6.10=hbc83047_0
88 | - tornado=6.1=py38h27cfd23_0
89 | - traitlets=5.0.5=pyhd3eb1b0_0
90 | - wcwidth=0.2.5=py_0
91 | - webencodings=0.5.1=py38_1
92 | - widgetsnbextension=3.5.1=py38_0
93 | - xz=5.2.5=h7b6447c_0
94 | - zeromq=4.3.4=h2531618_0
95 | - zipp=3.4.1=pyhd3eb1b0_0
96 | - zlib=1.2.11=h7b6447c_3
97 | - pip:
98 | - absl-py==0.12.0
99 | - addict==2.4.0
100 | - anyio==3.3.1
101 | - astunparse==1.6.3
102 | - attrs==20.2.0
103 | - babel==2.9.1
104 | - cachetools==4.2.1
105 | - cffi==1.14.2
106 | - chardet==3.0.4
107 | - clang==5.0
108 | - click==7.1.2
109 | - cloudpickle==1.3.0
110 | - cycler==0.10.0
111 | - cython==0.29.21
112 | - deprecation==2.1.0
113 | - dill==0.3.4
114 | - fasteners==0.15
115 | - flask==1.0.2
116 | - flatbuffers==1.12
117 | - future==0.18.2
118 | - gast==0.4.0
119 | - gitdb==4.0.2
120 | - gitdb2==2.0.6
121 | - gitpython==2.1.7
122 | - glfw==1.12.0
123 | - google-api-core==1.26.0
124 | - google-api-python-client==1.12.8
125 | - google-auth==1.26.1
126 | - google-auth-httplib2==0.0.4
127 | - google-auth-oauthlib==0.4.4
128 | - google-cloud-core==1.6.0
129 | - google-cloud-storage==1.36.0
130 | - google-crc32c==1.1.2
131 | - google-pasta==0.2.0
132 | - google-resumable-media==1.2.0
133 | - googleapis-common-protos==1.52.0
134 | - grpcio==1.39.0
135 | - gtimer==1.0.0b5
136 | - gym==0.17.2
137 | - h5py==3.1.0
138 | - httplib2==0.19.0
139 | - idna==2.10
140 | - imageio==2.9.0
141 | - imageio-ffmpeg==0.4.4
142 | - itsdangerous==1.1.0
143 | - jinja2==2.11.2
144 | - joblib==1.0.1
145 | - json5==0.9.6
146 | - jupyter-core==4.6.3
147 | - jupyter-packaging==0.10.4
148 | - jupyter-server==1.11.0
149 | - jupyterlab==3.1.12
150 | - jupyterlab-server==2.8.1
151 | - keras==2.6.0
152 | - keras-preprocessing==1.1.2
153 | - kiwisolver==1.3.2
154 | - llvmlite==0.32.1
155 | - lockfile==0.12.2
156 | - markdown==3.3.4
157 | - matplotlib==3.4.3
158 | - monotonic==1.5
159 | - moviepy==1.0.3
160 | - mujoco-py==2.0.2.9
161 | - nbclassic==0.3.1
162 | - nbformat==5.0.8
163 | - networkx==2.6.2
164 | - numba==0.49.1
165 | - numpy==1.19.5
166 | - nvisii==1.1.72
167 | - oauthlib==3.1.0
168 | - open3d==0.13.0
169 | - opencv-python==4.4.0.44
170 | - opt-einsum==3.3.0
171 | - pandas==1.3.3
172 | - pillow==8.3.2
173 | - plotly==3.4.2
174 | - proglog==0.1.9
175 | - progressbar2==3.53.1
176 | - protobuf==3.14.0
177 | - psutil==5.8.0
178 | - pyasn1==0.4.8
179 | - pyasn1-modules==0.2.8
180 | - pybullet==2.6.9
181 | - pyglet==1.5.0
182 | - python-utils==2.5.6
183 | - pytz==2020.4
184 | - pywavelets==1.1.1
185 | - pyyaml==5.4.1
186 | - requests==2.24.0
187 | - requests-oauthlib==1.3.0
188 | - requests-unixsocket==0.2.0
189 | - retrying==1.3.3
190 | - rsa==4.7
191 | - scikit-image==0.18.2
192 | - scikit-learn==0.24.2
193 | - scikit-video==1.1.11
194 | - scipy==1.5.2
195 | - smmap==3.0.5
196 | - smmap2==3.0.1
197 | - sniffio==1.2.0
198 | - tensorboard==2.6.0
199 | - tensorboard-data-server==0.6.1
200 | - tensorboard-plugin-wit==1.8.0
201 | - tensorboardx==2.4
202 | - tensorflow-estimator==2.6.0
203 | - tensorflow-gpu==2.6.0
204 | - termcolor==1.1.0
205 | - threadpoolctl==2.2.0
206 | - tifffile==2021.8.8
207 | - tomlkit==0.7.2
208 | - torch==1.6.0
209 | - torchvision==0.7.0
210 | - tqdm==4.61.1
211 | - typing-extensions==3.7.4.3
212 | - uritemplate==3.0.1
213 | - urllib3==1.25.10
214 | - websocket-client==1.2.1
215 | - werkzeug==1.0.1
216 | - wheel==0.37.0
217 | - wrapt==1.12.1
218 |
219 |
--------------------------------------------------------------------------------
/maple/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/__init__.py
--------------------------------------------------------------------------------
/maple/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General classes, functions, utilities that are used throughout maple.
3 | """
4 | from maple.core.logging import logger
5 |
6 | __all__ = ['logger']
7 |
8 |
--------------------------------------------------------------------------------
/maple/core/batch_rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import gtimer as gt
4 | from maple.core.rl_algorithm import BaseRLAlgorithm
5 | from maple.data_management.replay_buffer import ReplayBuffer
6 | from maple.samplers.data_collector import PathCollector
7 |
8 |
9 | class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
10 | def __init__(
11 | self,
12 | trainer,
13 | exploration_env,
14 | evaluation_env,
15 | exploration_data_collector: PathCollector,
16 | evaluation_data_collector: PathCollector,
17 | replay_buffer: ReplayBuffer,
18 | batch_size,
19 | max_path_length,
20 | num_epochs,
21 | num_eval_steps_per_epoch,
22 | num_expl_steps_per_train_loop,
23 | num_trains_per_train_loop,
24 | num_train_loops_per_epoch=1,
25 | min_num_steps_before_training=0,
26 | eval_epoch_freq=1,
27 | expl_epoch_freq=1,
28 | eval_only=False,
29 | no_training=False,
30 | ):
31 | super().__init__(
32 | trainer,
33 | exploration_env,
34 | evaluation_env,
35 | exploration_data_collector,
36 | evaluation_data_collector,
37 | replay_buffer,
38 | eval_epoch_freq=eval_epoch_freq,
39 | expl_epoch_freq=expl_epoch_freq,
40 | eval_only=eval_only,
41 | no_training=no_training,
42 | )
43 | self.batch_size = batch_size
44 | self.max_path_length = max_path_length
45 | self.num_epochs = num_epochs
46 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
47 | self.num_trains_per_train_loop = num_trains_per_train_loop
48 | self.num_train_loops_per_epoch = num_train_loops_per_epoch
49 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
50 | self.min_num_steps_before_training = min_num_steps_before_training
51 | gt.reset_root()
52 |
53 | def _train(self):
54 | if self.min_num_steps_before_training > 0 and not self._eval_only:
55 | init_expl_paths = self.expl_data_collector.collect_new_paths(
56 | self.max_path_length,
57 | self.min_num_steps_before_training,
58 | discard_incomplete_paths=True, #False,
59 | )
60 | self.replay_buffer.add_paths(init_expl_paths)
61 | self.expl_data_collector.end_epoch(-1)
62 |
63 | for epoch in gt.timed_for(
64 | range(self._start_epoch, self.num_epochs + 1),
65 | save_itrs=True,
66 | ):
67 | for pre_epoch_func in self.pre_epoch_funcs:
68 | pre_epoch_func(self, epoch)
69 |
70 | if epoch % self._eval_epoch_freq == 0:
71 | self.eval_data_collector.collect_new_paths(
72 | self.max_path_length,
73 | self.num_eval_steps_per_epoch,
74 | discard_incomplete_paths=True,
75 | )
76 | gt.stamp('evaluation sampling')
77 |
78 | if not self._eval_only:
79 | for _ in range(self.num_train_loops_per_epoch):
80 | if epoch % self._expl_epoch_freq == 0:
81 | new_expl_paths = self.expl_data_collector.collect_new_paths(
82 | self.max_path_length,
83 | self.num_expl_steps_per_train_loop,
84 | discard_incomplete_paths=True, #False,
85 | )
86 | gt.stamp('exploration sampling', unique=False)
87 |
88 | self.replay_buffer.add_paths(new_expl_paths)
89 | gt.stamp('data storing', unique=False)
90 |
91 | if not self._no_training:
92 | self.training_mode(True)
93 | for _ in range(self.num_trains_per_train_loop):
94 | train_data = self.replay_buffer.random_batch(
95 | self.batch_size)
96 | self.trainer.train(train_data)
97 | gt.stamp('training', unique=False)
98 | self.training_mode(False)
99 |
100 | self._end_epoch(epoch)
101 |
--------------------------------------------------------------------------------
/maple/core/eval_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Common evaluation utilities.
3 | """
4 |
5 | from collections import OrderedDict
6 | from numbers import Number
7 |
8 | import numpy as np
9 | import itertools
10 |
11 | import maple.pythonplusplus as ppp
12 |
13 |
14 | def get_generic_path_information(paths, stat_prefix=''):
15 | """
16 | Get an OrderedDict with a bunch of statistic names and values.
17 | """
18 | statistics = OrderedDict()
19 | returns = [sum(path["rewards"]) for path in paths]
20 |
21 | rewards = np.vstack([path["rewards"] for path in paths])
22 | statistics.update(create_stats_ordered_dict('Rewards', rewards,
23 | stat_prefix=stat_prefix))
24 | statistics.update(create_stats_ordered_dict('Returns', returns,
25 | stat_prefix=stat_prefix))
26 | actions = [path["actions"] for path in paths]
27 | if len(actions[0].shape) == 1:
28 | actions = np.hstack([path["actions"] for path in paths])
29 | else:
30 | actions = np.vstack([path["actions"] for path in paths])
31 | statistics.update(create_stats_ordered_dict(
32 | 'Actions', actions, stat_prefix=stat_prefix
33 | ))
34 | statistics['Num Paths'] = len(paths)
35 | statistics[stat_prefix + 'Average Returns'] = get_average_returns(paths)
36 |
37 | statistics[stat_prefix + 'Task Returns Sum'] = np.mean([path['reward_actions_sum'] for path in paths])
38 | statistics[stat_prefix + 'Task Returns Avg'] = np.mean([path['reward_actions_sum'] / path['path_length_actions'] for path in paths])
39 |
40 | statistics[stat_prefix + 'Num Rollout Success'] = get_num_rollout_success(paths)
41 |
42 | for info_key in ['env_infos', 'agent_infos']:
43 | if info_key in paths[0]:
44 | all_env_infos = [
45 | ppp.list_of_dicts__to__dict_of_lists(p[info_key])
46 | for p in paths
47 | ]
48 | for k in all_env_infos[0].keys():
49 | final_ks = np.array([info[k][-1] for info in all_env_infos])
50 | first_ks = np.array([info[k][0] for info in all_env_infos])
51 | all_ks = np.concatenate([info[k] for info in all_env_infos])
52 | statistics.update(create_stats_ordered_dict(
53 | stat_prefix + k,
54 | final_ks,
55 | stat_prefix='{}/final/'.format(info_key),
56 | ))
57 | statistics.update(create_stats_ordered_dict(
58 | stat_prefix + k,
59 | first_ks,
60 | stat_prefix='{}/initial/'.format(info_key),
61 | ))
62 | statistics.update(create_stats_ordered_dict(
63 | stat_prefix + k,
64 | all_ks,
65 | stat_prefix='{}/'.format(info_key),
66 | ))
67 |
68 | return statistics
69 |
70 |
71 | def get_average_returns(paths):
72 | returns = [sum(path["rewards"]) for path in paths]
73 | return np.mean(returns)
74 |
75 | def get_num_rollout_success(paths):
76 | num_success = 0
77 | for path in paths:
78 | if any([info.get('success', False) for info in path['env_infos']]):
79 | num_success += 1
80 | return num_success
81 |
82 |
83 | def create_stats_ordered_dict(
84 | name,
85 | data,
86 | stat_prefix=None,
87 | always_show_all_stats=True,
88 | exclude_max_min=True,
89 | ):
90 | if stat_prefix is not None:
91 | name = "{}{}".format(stat_prefix, name)
92 | if isinstance(data, Number):
93 | return OrderedDict({name: data})
94 |
95 | if len(data) == 0:
96 | return OrderedDict()
97 |
98 | if isinstance(data, tuple):
99 | ordered_dict = OrderedDict()
100 | for number, d in enumerate(data):
101 | sub_dict = create_stats_ordered_dict(
102 | "{0}_{1}".format(name, number),
103 | d,
104 | )
105 | ordered_dict.update(sub_dict)
106 | return ordered_dict
107 |
108 | if isinstance(data, list):
109 | try:
110 | iter(data[0])
111 | except TypeError:
112 | pass
113 | else:
114 | data = np.concatenate(data)
115 |
116 | if (isinstance(data, np.ndarray) and data.size == 1
117 | and not always_show_all_stats):
118 | return OrderedDict({name: float(data)})
119 |
120 | stats = OrderedDict([
121 | (name + ' Mean', np.mean(data)),
122 | (name + ' Std', np.std(data)),
123 | ])
124 | if not exclude_max_min:
125 | stats[name + ' Max'] = np.max(data)
126 | stats[name + ' Min'] = np.min(data)
127 | return stats
128 |
--------------------------------------------------------------------------------
/maple/core/loss.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from collections import OrderedDict
3 |
4 |
5 | LossStatistics = OrderedDict
6 |
7 |
8 | class LossFunction(object, metaclass=abc.ABCMeta):
9 | @abc.abstractmethod
10 | def compute_loss(self, batch, skip_statistics=False, **kwargs):
11 | """Returns loss and statistics given a batch of data.
12 | batch : Data to compute loss of
13 | skip_statistics: Whether statistics should be calculated. If True, then
14 | an empty dict is returned for the statistics.
15 |
16 | Returns: (loss, stats) tuple.
17 | """
18 | pass
19 |
--------------------------------------------------------------------------------
/maple/core/online_rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import gtimer as gt
4 | from maple.core.rl_algorithm import BaseRLAlgorithm
5 | from maple.data_management.replay_buffer import ReplayBuffer
6 | from maple.samplers.data_collector import (
7 | PathCollector,
8 | StepCollector,
9 | )
10 |
11 |
12 | class OnlineRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
13 | def __init__(
14 | self,
15 | trainer,
16 | exploration_env,
17 | evaluation_env,
18 | exploration_data_collector: StepCollector,
19 | evaluation_data_collector: PathCollector,
20 | replay_buffer: ReplayBuffer,
21 | batch_size,
22 | max_path_length,
23 | num_epochs,
24 | num_eval_steps_per_epoch,
25 | num_expl_steps_per_train_loop,
26 | num_trains_per_train_loop,
27 | num_train_loops_per_epoch=1,
28 | min_num_steps_before_training=0,
29 | ):
30 | super().__init__(
31 | trainer,
32 | exploration_env,
33 | evaluation_env,
34 | exploration_data_collector,
35 | evaluation_data_collector,
36 | replay_buffer,
37 | )
38 | self.batch_size = batch_size
39 | self.max_path_length = max_path_length
40 | self.num_epochs = num_epochs
41 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
42 | self.num_trains_per_train_loop = num_trains_per_train_loop
43 | self.num_train_loops_per_epoch = num_train_loops_per_epoch
44 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
45 | self.min_num_steps_before_training = min_num_steps_before_training
46 |
47 | assert self.num_trains_per_train_loop >= self.num_expl_steps_per_train_loop, \
48 | 'Online training presumes num_trains_per_train_loop >= num_expl_steps_per_train_loop'
49 |
50 | def _train(self):
51 | self.training_mode(False)
52 | if self.min_num_steps_before_training > 0:
53 | self.expl_data_collector.collect_new_steps(
54 | self.max_path_length,
55 | self.min_num_steps_before_training,
56 | discard_incomplete_paths=False,
57 | )
58 | init_expl_paths = self.expl_data_collector.get_epoch_paths()
59 | self.replay_buffer.add_paths(init_expl_paths)
60 | self.expl_data_collector.end_epoch(-1)
61 |
62 | gt.stamp('initial exploration', unique=True)
63 |
64 | num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop
65 | for epoch in gt.timed_for(
66 | range(self._start_epoch, self.num_epochs),
67 | save_itrs=True,
68 | ):
69 | self.eval_data_collector.collect_new_paths(
70 | self.max_path_length,
71 | self.num_eval_steps_per_epoch,
72 | discard_incomplete_paths=True,
73 | )
74 | gt.stamp('evaluation sampling')
75 |
76 | for _ in range(self.num_train_loops_per_epoch):
77 | for _ in range(self.num_expl_steps_per_train_loop):
78 | self.expl_data_collector.collect_new_steps(
79 | self.max_path_length,
80 | 1, # num steps
81 | discard_incomplete_paths=False,
82 | )
83 | gt.stamp('exploration sampling', unique=False)
84 |
85 | self.training_mode(True)
86 | for _ in range(num_trains_per_expl_step):
87 | train_data = self.replay_buffer.random_batch(
88 | self.batch_size)
89 | self.trainer.train(train_data)
90 | gt.stamp('training', unique=False)
91 | self.training_mode(False)
92 |
93 | new_expl_paths = self.expl_data_collector.get_epoch_paths()
94 | self.replay_buffer.add_paths(new_expl_paths)
95 | gt.stamp('data storing', unique=False)
96 |
97 | self._end_epoch(epoch)
98 |
--------------------------------------------------------------------------------
/maple/core/rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from collections import OrderedDict
3 |
4 | import gtimer as gt
5 |
6 | from maple.core import logger, eval_util
7 | from maple.core.logging import append_log
8 | from maple.data_management.replay_buffer import ReplayBuffer
9 | from maple.samplers.data_collector import DataCollector
10 |
11 |
12 | def _get_epoch_timings():
13 | times_itrs = gt.get_times().stamps.itrs
14 | times = OrderedDict()
15 | epoch_time = 0
16 | for key in sorted(times_itrs):
17 | time = times_itrs[key][-1]
18 | epoch_time += time
19 | times['time/{} (s)'.format(key)] = time
20 | times['time/epoch (s)'] = epoch_time
21 | times['time/total (s)'] = gt.get_times().total
22 | return times
23 |
24 |
25 | class BaseRLAlgorithm(object, metaclass=abc.ABCMeta):
26 | def __init__(
27 | self,
28 | trainer,
29 | exploration_env,
30 | evaluation_env,
31 | exploration_data_collector: DataCollector,
32 | evaluation_data_collector: DataCollector,
33 | replay_buffer: ReplayBuffer,
34 | eval_epoch_freq=1,
35 | expl_epoch_freq=1,
36 | eval_only=False,
37 | no_training=False,
38 | ):
39 | self.trainer = trainer
40 | self.expl_env = exploration_env
41 | self.eval_env = evaluation_env
42 | self.expl_data_collector = exploration_data_collector
43 | self.eval_data_collector = evaluation_data_collector
44 | self.replay_buffer = replay_buffer
45 | self._start_epoch = 0
46 |
47 | self.pre_epoch_funcs = []
48 | self.post_epoch_funcs = []
49 |
50 | self._eval_epoch_freq = eval_epoch_freq
51 | self._expl_epoch_freq = expl_epoch_freq
52 | self._eval_only = eval_only
53 |
54 | self._no_training = no_training
55 |
56 | def train(self, start_epoch=0):
57 | self._start_epoch = start_epoch
58 | self._train()
59 |
60 | def _train(self):
61 | """
62 | Train model.
63 | """
64 | raise NotImplementedError('_train must be implemented by inherited class')
65 |
66 | def _end_epoch(self, epoch):
67 | snapshot = self._get_snapshot()
68 | logger.save_itr_params(epoch, snapshot)
69 | gt.stamp('saving')
70 | self._log_stats(epoch)
71 |
72 | self.expl_data_collector.end_epoch(epoch)
73 | self.eval_data_collector.end_epoch(epoch)
74 | self.replay_buffer.end_epoch(epoch)
75 | self.trainer.end_epoch(epoch)
76 |
77 | for post_epoch_func in self.post_epoch_funcs:
78 | post_epoch_func(self, epoch)
79 |
80 | def _get_snapshot(self):
81 | snapshot = {}
82 | for k, v in self.trainer.get_snapshot().items():
83 | snapshot['trainer/' + k] = v
84 | for k, v in self.expl_data_collector.get_snapshot().items():
85 | snapshot['exploration/' + k] = v
86 | for k, v in self.eval_data_collector.get_snapshot().items():
87 | snapshot['evaluation/' + k] = v
88 | for k, v in self.replay_buffer.get_snapshot().items():
89 | snapshot['replay_buffer/' + k] = v
90 | return snapshot
91 |
92 | def _log_stats(self, epoch):
93 | dump_logs = False
94 | if not self._eval_only and epoch % self._expl_epoch_freq == 0:
95 | dump_logs = True
96 | if epoch % self._eval_epoch_freq == 0:
97 | dump_logs = True
98 | if not dump_logs:
99 | return
100 |
101 | logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
102 |
103 | """
104 | Replay Buffer
105 | """
106 | if not self._eval_only:
107 | logger.record_dict(
108 | self.replay_buffer.get_diagnostics(),
109 | prefix='replay_buffer/'
110 | )
111 |
112 | """
113 | Trainer
114 | """
115 | if not self._eval_only:
116 | logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
117 |
118 | """
119 | Exploration
120 | """
121 | if not self._eval_only:
122 | if epoch % self._expl_epoch_freq == 0:
123 | self._cur_expl_log = OrderedDict()
124 | self._cur_expl_log.update(
125 | self.expl_data_collector.get_diagnostics()
126 | )
127 | expl_paths = self.expl_data_collector.get_epoch_paths()
128 | if hasattr(self.expl_env, 'get_diagnostics'):
129 | self._cur_expl_log.update(
130 | self.expl_env.get_diagnostics(expl_paths),
131 | )
132 | self._cur_expl_log.update(
133 | eval_util.get_generic_path_information(expl_paths),
134 | )
135 | logger.record_dict(self._cur_expl_log, prefix='expl/')
136 |
137 | """
138 | Evaluation
139 | """
140 | if epoch % self._eval_epoch_freq == 0:
141 | self._cur_eval_log = OrderedDict()
142 | self._cur_eval_log.update(
143 | self.eval_data_collector.get_diagnostics()
144 | )
145 | eval_paths = self.eval_data_collector.get_epoch_paths()
146 | if hasattr(self.eval_env, 'get_diagnostics'):
147 | self._cur_eval_log.update(
148 | self.eval_env.get_diagnostics(eval_paths),
149 | )
150 | self._cur_eval_log.update(
151 | eval_util.get_generic_path_information(eval_paths),
152 | )
153 | logger.record_dict(self._cur_eval_log, prefix='eval/')
154 |
155 | """
156 | Misc
157 | """
158 | try:
159 | import os
160 | import psutil
161 | process = psutil.Process(os.getpid())
162 | k = 'process/RAM Usage (Mb)'
163 | v = int(process.memory_info().rss / 1000000)
164 | logger.record_tabular(k, v)
165 | logger.record_tabular('process/Num Threads', process.num_threads())
166 | except ImportError:
167 | pass
168 |
169 | gt.stamp('logging')
170 | logger.record_dict(_get_epoch_timings())
171 | logger.record_tabular('Dummy', 0)
172 | logger.record_tabular('Epoch', epoch)
173 | logger.dump_tabular(with_prefix=False, with_timestamp=False)
174 |
175 | @abc.abstractmethod
176 | def training_mode(self, mode):
177 | """
178 | Set training mode to `mode`.
179 | :param mode: If True, training will happen (e.g. set the dropout
180 | probabilities to not all ones).
181 | """
182 | pass
183 |
--------------------------------------------------------------------------------
/maple/core/serializable.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on rllab's serializable.py file
3 |
4 | https://github.com/rll/rllab
5 | """
6 |
7 | import inspect
8 | import sys
9 |
10 |
11 | class Serializable(object):
12 |
13 | def __init__(self, *args, **kwargs):
14 | self.__args = args
15 | self.__kwargs = kwargs
16 |
17 | def quick_init(self, locals_):
18 | if getattr(self, "_serializable_initialized", False):
19 | return
20 | if sys.version_info >= (3, 0):
21 | spec = inspect.getfullargspec(self.__init__)
22 | # Exclude the first "self" parameter
23 | if spec.varkw:
24 | kwargs = locals_[spec.varkw].copy()
25 | else:
26 | kwargs = dict()
27 | if spec.kwonlyargs:
28 | for key in spec.kwonlyargs:
29 | kwargs[key] = locals_[key]
30 | else:
31 | spec = inspect.getargspec(self.__init__)
32 | if spec.keywords:
33 | kwargs = locals_[spec.keywords]
34 | else:
35 | kwargs = dict()
36 | if spec.varargs:
37 | varargs = locals_[spec.varargs]
38 | else:
39 | varargs = tuple()
40 | in_order_args = [locals_[arg] for arg in spec.args][1:]
41 | self.__args = tuple(in_order_args) + varargs
42 | self.__kwargs = kwargs
43 | setattr(self, "_serializable_initialized", True)
44 |
45 | def __getstate__(self):
46 | return {"__args": self.__args, "__kwargs": self.__kwargs}
47 |
48 | def __setstate__(self, d):
49 | # convert all __args to keyword-based arguments
50 | if sys.version_info >= (3, 0):
51 | spec = inspect.getfullargspec(self.__init__)
52 | else:
53 | spec = inspect.getargspec(self.__init__)
54 | in_order_args = spec.args[1:]
55 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"]))
56 | self.__dict__.update(out.__dict__)
57 |
58 | @classmethod
59 | def clone(cls, obj, **kwargs):
60 | assert isinstance(obj, Serializable)
61 | d = obj.__getstate__()
62 | d["__kwargs"] = dict(d["__kwargs"], **kwargs)
63 | out = type(obj).__new__(type(obj))
64 | out.__setstate__(d)
65 | return out
66 |
--------------------------------------------------------------------------------
/maple/core/trainer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Trainer(object, metaclass=abc.ABCMeta):
5 | @abc.abstractmethod
6 | def train(self, data):
7 | pass
8 |
9 | def end_epoch(self, epoch):
10 | pass
11 |
12 | def get_snapshot(self):
13 | return {}
14 |
15 | def get_diagnostics(self):
16 | return {}
17 |
--------------------------------------------------------------------------------
/maple/data_management/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/data_management/__init__.py
--------------------------------------------------------------------------------
/maple/data_management/env_replay_buffer.py:
--------------------------------------------------------------------------------
1 | from gym.spaces import Discrete
2 |
3 | from maple.data_management.simple_replay_buffer import SimpleReplayBuffer
4 | from maple.envs.env_utils import get_dim
5 | import numpy as np
6 |
7 |
8 | class EnvReplayBuffer(SimpleReplayBuffer):
9 | def __init__(
10 | self,
11 | max_replay_buffer_size,
12 | env,
13 | env_info_sizes=None
14 | ):
15 | """
16 | :param max_replay_buffer_size:
17 | :param env:
18 | """
19 | self.env = env
20 | self._ob_space = env.observation_space
21 | self._action_space = env.action_space
22 |
23 | if env_info_sizes is None:
24 | if hasattr(env, 'info_sizes'):
25 | env_info_sizes = env.info_sizes
26 | else:
27 | env_info_sizes = dict()
28 |
29 | super().__init__(
30 | max_replay_buffer_size=max_replay_buffer_size,
31 | observation_dim=get_dim(self._ob_space),
32 | action_dim=get_dim(self._action_space),
33 | env_info_sizes=env_info_sizes
34 | )
35 |
36 | def add_sample(self, observation, action, reward, terminal,
37 | next_observation, **kwargs):
38 | if isinstance(self._action_space, Discrete):
39 | new_action = np.zeros(self._action_dim)
40 | new_action[action] = 1
41 | else:
42 | new_action = action
43 | return super().add_sample(
44 | observation=observation,
45 | action=new_action,
46 | reward=reward,
47 | next_observation=next_observation,
48 | terminal=terminal,
49 | **kwargs
50 | )
51 |
--------------------------------------------------------------------------------
/maple/data_management/normalizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on code from Marcin Andrychowicz
3 | """
4 | import numpy as np
5 |
6 |
7 | class Normalizer(object):
8 | def __init__(
9 | self,
10 | size,
11 | eps=1e-8,
12 | default_clip_range=np.inf,
13 | mean=0,
14 | std=1,
15 | ):
16 | self.size = size
17 | self.eps = eps
18 | self.default_clip_range = default_clip_range
19 | self.sum = np.zeros(self.size, np.float32)
20 | self.sumsq = np.zeros(self.size, np.float32)
21 | self.count = np.ones(1, np.float32)
22 | self.mean = mean + np.zeros(self.size, np.float32)
23 | self.std = std * np.ones(self.size, np.float32)
24 | self.synchronized = True
25 |
26 | def update(self, v):
27 | if v.ndim == 1:
28 | v = np.expand_dims(v, 0)
29 | assert v.ndim == 2
30 | assert v.shape[1] == self.size
31 | self.sum += v.sum(axis=0)
32 | self.sumsq += (np.square(v)).sum(axis=0)
33 | self.count[0] += v.shape[0]
34 | self.synchronized = False
35 |
36 | def normalize(self, v, clip_range=None):
37 | if not self.synchronized:
38 | self.synchronize()
39 | if clip_range is None:
40 | clip_range = self.default_clip_range
41 | mean, std = self.mean, self.std
42 | if v.ndim == 2:
43 | mean = mean.reshape(1, -1)
44 | std = std.reshape(1, -1)
45 | return np.clip((v - mean) / std, -clip_range, clip_range)
46 |
47 | def denormalize(self, v):
48 | if not self.synchronized:
49 | self.synchronize()
50 | mean, std = self.mean, self.std
51 | if v.ndim == 2:
52 | mean = mean.reshape(1, -1)
53 | std = std.reshape(1, -1)
54 | return mean + v * std
55 |
56 | def synchronize(self):
57 | self.mean[...] = self.sum / self.count[0]
58 | self.std[...] = np.sqrt(
59 | np.maximum(
60 | np.square(self.eps),
61 | self.sumsq / self.count[0] - np.square(self.mean)
62 | )
63 | )
64 | self.synchronized = True
65 |
66 |
67 | class IdentityNormalizer(object):
68 | def __init__(self, *args, **kwargs):
69 | pass
70 |
71 | def update(self, v):
72 | pass
73 |
74 | def normalize(self, v, clip_range=None):
75 | return v
76 |
77 | def denormalize(self, v):
78 | return v
79 |
80 |
81 | class FixedNormalizer(object):
82 | def __init__(
83 | self,
84 | size,
85 | default_clip_range=np.inf,
86 | mean=0,
87 | std=1,
88 | eps=1e-8,
89 | ):
90 | assert std > 0
91 | std = std + eps
92 | self.size = size
93 | self.default_clip_range = default_clip_range
94 | self.mean = mean + np.zeros(self.size, np.float32)
95 | self.std = std + np.zeros(self.size, np.float32)
96 | self.eps = eps
97 |
98 | def set_mean(self, mean):
99 | self.mean = mean + np.zeros(self.size, np.float32)
100 |
101 | def set_std(self, std):
102 | std = std + self.eps
103 | self.std = std + np.zeros(self.size, np.float32)
104 |
105 | def normalize(self, v, clip_range=None):
106 | if clip_range is None:
107 | clip_range = self.default_clip_range
108 | mean, std = self.mean, self.std
109 | if v.ndim == 2:
110 | mean = mean.reshape(1, -1)
111 | std = std.reshape(1, -1)
112 | return np.clip((v - mean) / std, -clip_range, clip_range)
113 |
114 | def denormalize(self, v):
115 | mean, std = self.mean, self.std
116 | if v.ndim == 2:
117 | mean = mean.reshape(1, -1)
118 | std = std.reshape(1, -1)
119 | return mean + v * std
120 |
121 | def copy_stats(self, other):
122 | self.set_mean(other.mean)
123 | self.set_std(other.std)
124 |
--------------------------------------------------------------------------------
/maple/data_management/path_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class PathBuilder(dict):
5 | """
6 | Usage:
7 | ```
8 | path_builder = PathBuilder()
9 | path_builder.add_all(
10 | observations=1,
11 | actions=2,
12 | next_observations=3,
13 | ...
14 | )
15 | path_builder.add_all(
16 | observations=4,
17 | actions=5,
18 | next_observations=6,
19 | ...
20 | )
21 |
22 | path = path_builder.get_all_stacked()
23 |
24 | path['observations']
25 | # output: [1, 4]
26 | path['actions']
27 | # output: [2, 5]
28 | ```
29 |
30 | Note that the key should be "actions" and not "action" since the
31 | resulting dictionary will have those keys.
32 | """
33 |
34 | def __init__(self):
35 | super().__init__()
36 | self._path_length = 0
37 |
38 | def add_all(self, **key_to_value):
39 | for k, v in key_to_value.items():
40 | if k not in self:
41 | self[k] = [v]
42 | else:
43 | self[k].append(v)
44 | self._path_length += 1
45 |
46 | def get_all_stacked(self):
47 | output_dict = dict()
48 | for k, v in self.items():
49 | output_dict[k] = stack_list(v)
50 | return output_dict
51 |
52 | def __len__(self):
53 | return self._path_length
54 |
55 |
56 | def stack_list(lst):
57 | if isinstance(lst[0], dict):
58 | return lst
59 | else:
60 | return np.array(lst)
61 |
--------------------------------------------------------------------------------
/maple/data_management/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class ReplayBuffer(object, metaclass=abc.ABCMeta):
5 | """
6 | A class used to save and replay data.
7 | """
8 |
9 | @abc.abstractmethod
10 | def add_sample(self, observation, action, reward, next_observation,
11 | terminal, **kwargs):
12 | """
13 | Add a transition tuple.
14 | """
15 | pass
16 |
17 | @abc.abstractmethod
18 | def terminate_episode(self):
19 | """
20 | Let the replay buffer know that the episode has terminated in case some
21 | special book-keeping has to happen.
22 | :return:
23 | """
24 | pass
25 |
26 | @abc.abstractmethod
27 | def num_steps_can_sample(self, **kwargs):
28 | """
29 | :return: # of unique items that can be sampled.
30 | """
31 | pass
32 |
33 | def add_path(self, path):
34 | """
35 | Add a path to the replay buffer.
36 |
37 | This default implementation naively goes through every step, but you
38 | may want to optimize this.
39 |
40 | NOTE: You should NOT call "terminate_episode" after calling add_path.
41 | It's assumed that this function handles the episode termination.
42 |
43 | :param path: Dict like one outputted by maple.samplers.util.rollout
44 | """
45 | for i, (
46 | obs,
47 | action,
48 | reward,
49 | next_obs,
50 | terminal,
51 | agent_info,
52 | env_info
53 | ) in enumerate(zip(
54 | path["observations"],
55 | path["actions"],
56 | path["rewards"],
57 | path["next_observations"],
58 | path["terminals"],
59 | path["agent_infos"],
60 | path["env_infos"],
61 | )):
62 | self.add_sample(
63 | observation=obs,
64 | action=action,
65 | reward=reward,
66 | next_observation=next_obs,
67 | terminal=terminal,
68 | agent_info=agent_info,
69 | env_info=env_info,
70 | )
71 | self.terminate_episode()
72 |
73 | def add_paths(self, paths):
74 | for path in paths:
75 | self.add_path(path)
76 |
77 | @abc.abstractmethod
78 | def random_batch(self, batch_size):
79 | """
80 | Return a batch of size `batch_size`.
81 | :param batch_size:
82 | :return:
83 | """
84 | pass
85 |
86 | def get_diagnostics(self):
87 | return {}
88 |
89 | def get_snapshot(self):
90 | return {}
91 |
92 | def end_epoch(self, epoch):
93 | return
94 |
95 |
--------------------------------------------------------------------------------
/maple/data_management/shared_obs_dict_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from maple.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
4 |
5 | import torch.multiprocessing as mp
6 | import ctypes
7 |
8 |
9 | class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer):
10 | """
11 | Same as an ObsDictRelabelingBuffer but the obs and next_obs are backed
12 | by multiprocessing arrays. The replay buffer size is also shared. The
13 | intended use case is for if one wants obs/next_obs to be shared between
14 | processes. Accesses are synchronized internally by locks (mp takes care
15 | of that). Technically, putting such large arrays in shared memory/requiring
16 | synchronized access can be extremely slow, but it seems ok empirically.
17 |
18 | This code also breaks a lot of functionality for the subprocess. For example,
19 | random_batch is incorrect as actions and _idx_to_future_obs_idx are not
20 | shared. If the subprocess needs all of the functionality, a mp.Array
21 | must be used for all numpy arrays in the replay buffer.
22 |
23 | """
24 |
25 | def __init__(
26 | self,
27 | *args,
28 | **kwargs
29 | ):
30 | self._shared_size = mp.Value(ctypes.c_long, 0)
31 | ObsDictRelabelingBuffer.__init__(self, *args, **kwargs)
32 |
33 | self._mp_array_info = {}
34 | self._shared_obs_info = {}
35 | self._shared_next_obs_info = {}
36 |
37 | for obs_key, obs_arr in self._obs.items():
38 | ctype = ctypes.c_double
39 | if obs_arr.dtype == np.uint8:
40 | ctype = ctypes.c_uint8
41 |
42 | self._shared_obs_info[obs_key] = (
43 | mp.Array(ctype, obs_arr.size),
44 | obs_arr.dtype,
45 | obs_arr.shape,
46 | )
47 | self._shared_next_obs_info[obs_key] = (
48 | mp.Array(ctype, obs_arr.size),
49 | obs_arr.dtype,
50 | obs_arr.shape,
51 | )
52 |
53 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
54 | self._next_obs[obs_key] = to_np(
55 | *self._shared_next_obs_info[obs_key])
56 | self._register_mp_array("_actions")
57 | self._register_mp_array("_terminals")
58 |
59 | def _register_mp_array(self, arr_instance_var_name):
60 | """
61 | Use this function to register an array to be shared. This will wipe arr.
62 | """
63 | assert hasattr(self, arr_instance_var_name), arr_instance_var_name
64 | arr = getattr(self, arr_instance_var_name)
65 |
66 | ctype = ctypes.c_double
67 | if arr.dtype == np.uint8:
68 | ctype = ctypes.c_uint8
69 |
70 | self._mp_array_info[arr_instance_var_name] = (
71 | mp.Array(ctype, arr.size), arr.dtype, arr.shape,
72 | )
73 | setattr(
74 | self,
75 | arr_instance_var_name,
76 | to_np(*self._mp_array_info[arr_instance_var_name])
77 | )
78 |
79 | def init_from_mp_info(
80 | self,
81 | mp_info,
82 | ):
83 | """
84 | The intended use is to have a subprocess serialize/copy a
85 | SharedObsDictRelabelingBuffer instance and call init_from on the
86 | instance's shared variables. This can't be done during serialization
87 | since multiprocessing shared objects can't be serialized and must be
88 | passed directly to the subprocess as an argument to the fork call.
89 | """
90 | shared_obs_info, shared_next_obs_info, mp_array_info, shared_size = mp_info
91 |
92 | self._shared_obs_info = shared_obs_info
93 | self._shared_next_obs_info = shared_next_obs_info
94 | self._mp_array_info = mp_array_info
95 | for obs_key in self._shared_obs_info.keys():
96 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
97 | self._next_obs[obs_key] = to_np(
98 | *self._shared_next_obs_info[obs_key])
99 |
100 | for arr_instance_var_name in self._mp_array_info.keys():
101 | setattr(
102 | self,
103 | arr_instance_var_name,
104 | to_np(*self._mp_array_info[arr_instance_var_name])
105 | )
106 | self._shared_size = shared_size
107 |
108 | def get_mp_info(self):
109 | return (
110 | self._shared_obs_info,
111 | self._shared_next_obs_info,
112 | self._mp_array_info,
113 | self._shared_size,
114 | )
115 |
116 | @property
117 | def _size(self):
118 | return self._shared_size.value
119 |
120 | @_size.setter
121 | def _size(self, size):
122 | self._shared_size.value = size
123 |
124 |
125 | def to_np(shared_arr, np_dtype, shape):
126 | return np.frombuffer(shared_arr.get_obj(), dtype=np_dtype).reshape(shape)
127 |
--------------------------------------------------------------------------------
/maple/data_management/simple_replay_buffer.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | import warnings
5 |
6 | from maple.data_management.replay_buffer import ReplayBuffer
7 |
8 |
9 | class SimpleReplayBuffer(ReplayBuffer):
10 |
11 | def __init__(
12 | self,
13 | max_replay_buffer_size,
14 | observation_dim,
15 | action_dim,
16 | env_info_sizes,
17 | replace = True,
18 | ):
19 | self._observation_dim = observation_dim
20 | self._action_dim = action_dim
21 | self._max_replay_buffer_size = max_replay_buffer_size
22 | self._observations = np.zeros((max_replay_buffer_size, observation_dim))
23 | # It's a bit memory inefficient to save the observations twice,
24 | # but it makes the code *much* easier since you no longer have to
25 | # worry about termination conditions.
26 | self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
27 | self._actions = np.zeros((max_replay_buffer_size, action_dim))
28 | # Make everything a 2D np array to make it easier for other code to
29 | # reason about the shape of the data
30 | self._rewards = np.zeros((max_replay_buffer_size, 1))
31 | # self._terminals[i] = a terminal was received at time i
32 | self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
33 | # Define self._env_infos[key][i] to be the return value of env_info[key]
34 | # at time i
35 | self._env_infos = {}
36 | for key, size in env_info_sizes.items():
37 | self._env_infos[key] = np.zeros((max_replay_buffer_size, size))
38 | self._env_info_keys = env_info_sizes.keys()
39 |
40 | self._replace = replace
41 |
42 | self._top = 0
43 | self._size = 0
44 |
45 | def add_sample(self, observation, action, reward, next_observation,
46 | terminal, env_info, **kwargs):
47 | self._observations[self._top] = observation
48 | self._actions[self._top] = action
49 | self._rewards[self._top] = reward
50 | self._terminals[self._top] = terminal
51 | self._next_obs[self._top] = next_observation
52 |
53 | for key in self._env_info_keys:
54 | self._env_infos[key][self._top] = env_info[key]
55 | self._advance()
56 |
57 | def terminate_episode(self):
58 | pass
59 |
60 | def _advance(self):
61 | self._top = (self._top + 1) % self._max_replay_buffer_size
62 | if self._size < self._max_replay_buffer_size:
63 | self._size += 1
64 |
65 | def random_batch(self, batch_size):
66 | indices = np.random.choice(self._size, size=batch_size, replace=self._replace or self._size < batch_size)
67 | if not self._replace and self._size < batch_size:
68 | warnings.warn('Replace was set to false, but is temporarily set to true because batch size is larger than current size of replay.')
69 | batch = dict(
70 | observations=self._observations[indices],
71 | actions=self._actions[indices],
72 | rewards=self._rewards[indices],
73 | terminals=self._terminals[indices],
74 | next_observations=self._next_obs[indices],
75 | )
76 | for key in self._env_info_keys:
77 | assert key not in batch.keys()
78 | batch[key] = self._env_infos[key][indices]
79 | return batch
80 |
81 | def rebuild_env_info_dict(self, idx):
82 | return {
83 | key: self._env_infos[key][idx]
84 | for key in self._env_info_keys
85 | }
86 |
87 | def batch_env_info_dict(self, indices):
88 | return {
89 | key: self._env_infos[key][indices]
90 | for key in self._env_info_keys
91 | }
92 |
93 | def num_steps_can_sample(self):
94 | return self._size
95 |
96 | def get_diagnostics(self):
97 | return OrderedDict([
98 | ('size', self._size)
99 | ])
100 |
--------------------------------------------------------------------------------
/maple/data_management/split_buffer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from maple.data_management.replay_buffer import ReplayBuffer
4 |
5 |
6 | class SplitReplayBuffer(ReplayBuffer):
7 | """
8 | Split the data into a training and validation set.
9 | """
10 | def __init__(
11 | self,
12 | train_replay_buffer: ReplayBuffer,
13 | validation_replay_buffer: ReplayBuffer,
14 | fraction_paths_in_train,
15 | ):
16 | self.train_replay_buffer = train_replay_buffer
17 | self.validation_replay_buffer = validation_replay_buffer
18 | self.fraction_paths_in_train = fraction_paths_in_train
19 | self.replay_buffer = self.train_replay_buffer
20 |
21 | def add_sample(self, *args, **kwargs):
22 | self.replay_buffer.add_sample(*args, **kwargs)
23 |
24 | def add_path(self, path):
25 | self.replay_buffer.add_path(path)
26 | self._randomly_set_replay_buffer()
27 |
28 | def num_steps_can_sample(self):
29 | return min(
30 | self.train_replay_buffer.num_steps_can_sample(),
31 | self.validation_replay_buffer.num_steps_can_sample(),
32 | )
33 |
34 | def terminate_episode(self, *args, **kwargs):
35 | self.replay_buffer.terminate_episode(*args, **kwargs)
36 | self._randomly_set_replay_buffer()
37 |
38 | def _randomly_set_replay_buffer(self):
39 | if random.random() <= self.fraction_paths_in_train:
40 | self.replay_buffer = self.train_replay_buffer
41 | else:
42 | self.replay_buffer = self.validation_replay_buffer
43 |
44 | def get_replay_buffer(self, training=True):
45 | if training:
46 | return self.train_replay_buffer
47 | else:
48 | return self.validation_replay_buffer
49 |
50 | def random_batch(self, batch_size):
51 | return self.train_replay_buffer.random_batch(batch_size)
52 |
53 | def __getattr__(self, attrname):
54 | return getattr(self.replay_buffer, attrname)
55 |
56 | def __getstate__(self):
57 | # Do not save self.replay_buffer since it's a duplicate and seems to
58 | # cause joblib recursion issues.
59 | return dict(
60 | train_replay_buffer=self.train_replay_buffer,
61 | validation_replay_buffer=self.validation_replay_buffer,
62 | fraction_paths_in_train=self.fraction_paths_in_train,
63 | )
64 |
65 | def __setstate__(self, d):
66 | self.train_replay_buffer = d['train_replay_buffer']
67 | self.validation_replay_buffer = d['validation_replay_buffer']
68 | self.fraction_paths_in_train = d['fraction_paths_in_train']
69 | self.replay_buffer = self.train_replay_buffer
70 |
--------------------------------------------------------------------------------
/maple/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/envs/__init__.py
--------------------------------------------------------------------------------
/maple/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym.spaces import Box, Discrete, Tuple
4 |
5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
6 |
7 |
8 | def get_asset_full_path(file_name):
9 | return os.path.join(ENV_ASSET_DIR, file_name)
10 |
11 |
12 | def get_dim(space):
13 | if isinstance(space, Box):
14 | return space.low.size
15 | elif isinstance(space, Discrete):
16 | return space.n
17 | elif isinstance(space, Tuple):
18 | return sum(get_dim(subspace) for subspace in space.spaces)
19 | elif hasattr(space, 'flat_dim'):
20 | return space.flat_dim
21 | else:
22 | raise TypeError("Unknown space: {}".format(space))
23 |
24 |
25 | def mode(env, mode_type):
26 | try:
27 | getattr(env, mode_type)()
28 | except AttributeError:
29 | pass
30 |
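For reference, what get_dim returns for common gym spaces; the values follow directly from the branches above.

    import numpy as np
    from gym.spaces import Box, Discrete, Tuple

    box = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    disc = Discrete(5)
    tup = Tuple((box, disc))

    get_dim(box)   # 3  (entries in the Box)
    get_dim(disc)  # 5  (number of discrete choices)
    get_dim(tup)   # 8  (sum over subspaces: 3 + 5)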
--------------------------------------------------------------------------------
/maple/envs/make_env.py:
--------------------------------------------------------------------------------
1 | """
2 | This file provides a more uniform interface to gym.make(env_id) that handles
3 | imports and normalization.
4 | """
5 |
6 | import gym
7 |
8 | from maple.envs.wrappers import NormalizedBoxEnv
9 |
10 | DAPG_ENVS = [
11 | 'pen-v0', 'pen-sparse-v0', 'pen-notermination-v0', 'pen-binary-v0', 'pen-binary-old-v0',
12 | 'door-v0', 'door-sparse-v0', 'door-binary-v0', 'door-binary-old-v0',
13 | 'relocate-v0', 'relocate-sparse-v0', 'relocate-binary-v0', 'relocate-binary-old-v0',
14 | 'hammer-v0', 'hammer-sparse-v0', 'hammer-binary-v0',
15 | ]
16 |
17 | D4RL_ENVS = [
18 | "maze2d-open-v0", "maze2d-umaze-v0", "maze2d-medium-v0", "maze2d-large-v0",
19 | "maze2d-open-dense-v0", "maze2d-umaze-dense-v0", "maze2d-medium-dense-v0", "maze2d-large-dense-v0",
20 | "antmaze-umaze-v0", "antmaze-umaze-diverse-v0", "antmaze-medium-diverse-v0",
21 | "antmaze-medium-play-v0", "antmaze-large-diverse-v0", "antmaze-large-play-v0",
22 | "pen-human-v0", "pen-cloned-v0", "pen-expert-v0", "hammer-human-v0", "hammer-cloned-v0", "hammer-expert-v0",
23 | "door-human-v0", "door-cloned-v0", "door-expert-v0", "relocate-human-v0", "relocate-cloned-v0", "relocate-expert-v0",
24 | "halfcheetah-random-v0", "halfcheetah-medium-v0", "halfcheetah-expert-v0", "halfcheetah-mixed-v0", "halfcheetah-medium-expert-v0",
25 | "walker2d-random-v0", "walker2d-medium-v0", "walker2d-expert-v0", "walker2d-mixed-v0", "walker2d-medium-expert-v0",
26 | "hopper-random-v0", "hopper-medium-v0", "hopper-expert-v0", "hopper-mixed-v0", "hopper-medium-expert-v0"
27 | ]
28 |
29 | def make(env_id=None, env_class=None, env_kwargs=None, normalize_env=True):
30 | assert env_id or env_class
31 | if env_class:
32 | env = env_class(**env_kwargs)
33 | elif env_id in DAPG_ENVS:
34 | import mj_envs
35 | assert normalize_env == False
36 | env = gym.make(env_id)
37 | elif env_id in D4RL_ENVS:
38 | import d4rl
39 | assert normalize_env == False
40 | env = gym.make(env_id)
41 | elif env_id:
42 | env = gym.make(env_id)
43 | env = env.env # unwrap TimeLimit
44 |
45 | if normalize_env:
46 | env = NormalizedBoxEnv(env)
47 |
48 | return env
49 |
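A brief usage sketch for make(); the environment id is illustrative and assumes the corresponding package (gym's MuJoCo tasks here) is installed.

    from maple.envs.make_env import make

    # Standard gym id: the TimeLimit wrapper is removed and actions are
    # rescaled to [-1, 1] by NormalizedBoxEnv.
    env = make(env_id='HalfCheetah-v2')

    # Custom environment class (hypothetical), constructed directly from kwargs:
    # env = make(env_class=MyRobotEnv, env_kwargs=dict(horizon=500))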
--------------------------------------------------------------------------------
/maple/envs/mujoco_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import mujoco_py
5 | import numpy as np
6 | from gym.envs.mujoco import mujoco_env
7 |
8 | from maple.core.serializable import Serializable
9 |
10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
11 |
12 |
13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable):
14 | """
15 | My own wrapper around MujocoEnv.
16 |
17 |     The caller needs to declare the observation and action spaces, unless automatically_set_obs_and_action_space is True.
18 | """
19 | def __init__(
20 | self,
21 | model_path,
22 | frame_skip=1,
23 | model_path_is_local=True,
24 | automatically_set_obs_and_action_space=False,
25 | ):
26 | if model_path_is_local:
27 | model_path = get_asset_xml(model_path)
28 | if automatically_set_obs_and_action_space:
29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip)
30 | else:
31 | """
32 | Code below is copy/pasted from MujocoEnv's __init__ function.
33 | """
34 | if model_path.startswith("/"):
35 | fullpath = model_path
36 | else:
37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
38 | if not path.exists(fullpath):
39 | raise IOError("File %s does not exist" % fullpath)
40 | self.frame_skip = frame_skip
41 | self.model = mujoco_py.MjModel(fullpath)
42 | self.data = self.model.data
43 | self.viewer = None
44 |
45 | self.metadata = {
46 | 'render.modes': ['human', 'rgb_array'],
47 | 'video.frames_per_second': int(np.round(1.0 / self.dt))
48 | }
49 |
50 | self.init_qpos = self.model.data.qpos.ravel().copy()
51 | self.init_qvel = self.model.data.qvel.ravel().copy()
52 | self._seed()
53 |
54 | def init_serialization(self, locals):
55 | Serializable.quick_init(self, locals)
56 |
57 | def log_diagnostics(self, paths):
58 | pass
59 |
60 |
61 | def get_asset_xml(xml_name):
62 | return os.path.join(ENV_ASSET_DIR, xml_name)
63 |
--------------------------------------------------------------------------------
/maple/envs/mujoco_image_env.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from collections.__init__ import deque
6 |
7 | from gym import Env
8 | from gym.spaces import Box
9 |
10 | from maple.envs.wrappers import ProxyEnv
11 |
12 |
13 | class ImageMujocoEnv(ProxyEnv, Env):
14 | def __init__(self,
15 | wrapped_env,
16 | imsize=32,
17 | keep_prev=0,
18 | init_camera=None,
19 | camera_name=None,
20 | transpose=False,
21 | grayscale=False,
22 | normalize=False,
23 | ):
24 | import mujoco_py
25 | super().__init__(wrapped_env)
26 |
27 | self.imsize = imsize
28 | if grayscale:
29 | self.image_length = self.imsize * self.imsize
30 | else:
31 | self.image_length = 3 * self.imsize * self.imsize
32 | # This is torch format rather than PIL image
33 | self.image_shape = (self.imsize, self.imsize)
34 | # Flattened past image queue
35 | self.history_length = keep_prev + 1
36 | self.history = deque(maxlen=self.history_length)
37 | # init camera
38 | if init_camera is not None:
39 | sim = self._wrapped_env.sim
40 | viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1)
41 | init_camera(viewer.cam)
42 | sim.add_render_context(viewer)
43 | self.camera_name = camera_name # None means default camera
44 | self.transpose = transpose
45 | self.grayscale = grayscale
46 | self.normalize = normalize
47 | self._render_local = False
48 |
49 | self.observation_space = Box(low=0.0,
50 | high=1.0,
51 | shape=(
52 | self.image_length * self.history_length,))
53 |
54 | def step(self, action):
55 |         # image observations are returned as a flattened 1D array
56 | true_state, reward, done, info = super().step(action)
57 |
58 | observation = self._image_observation()
59 | self.history.append(observation)
60 | history = self._get_history().flatten()
61 | full_obs = self._get_obs(history, true_state)
62 | return full_obs, reward, done, info
63 |
64 | def reset(self, **kwargs):
65 | true_state = super().reset(**kwargs)
66 | self.history = deque(maxlen=self.history_length)
67 |
68 | observation = self._image_observation()
69 | self.history.append(observation)
70 | history = self._get_history().flatten()
71 | full_obs = self._get_obs(history, true_state)
72 | return full_obs
73 |
74 | def get_image(self):
75 | """TODO: this should probably consider history"""
76 | return self._image_observation()
77 |
78 | def _get_obs(self, history_flat, true_state):
79 |         # adds extra information from true_state into the image observation.
80 | # Used in ImageWithObsEnv.
81 | return history_flat
82 |
83 | def _image_observation(self):
84 | # returns the image as a torch format np array
85 | image_obs = self._wrapped_env.sim.render(width=self.imsize,
86 | height=self.imsize,
87 | camera_name=self.camera_name)
88 | if self._render_local:
89 | cv2.imshow('env', image_obs)
90 | cv2.waitKey(1)
91 | if self.grayscale:
92 | image_obs = Image.fromarray(image_obs).convert('L')
93 | image_obs = np.array(image_obs)
94 | if self.normalize:
95 | image_obs = image_obs / 255.0
96 | if self.transpose:
97 | image_obs = image_obs.transpose()
98 | return image_obs
99 |
100 | def _get_history(self):
101 | observations = list(self.history)
102 |
103 | obs_count = len(observations)
104 | for _ in range(self.history_length - obs_count):
105 | dummy = np.zeros(self.image_shape)
106 | observations.append(dummy)
107 | return np.c_[observations]
108 |
109 | def retrieve_images(self):
110 | # returns images in unflattened PIL format
111 | images = []
112 | for image_obs in self.history:
113 | pil_image = self.torch_to_pil(torch.Tensor(image_obs))
114 | images.append(pil_image)
115 | return images
116 |
117 | def split_obs(self, obs):
118 | # splits observation into image input and true observation input
119 | imlength = self.image_length * self.history_length
120 | obs_length = self.observation_space.low.size
121 | obs = obs.view(-1, obs_length)
122 | image_obs = obs.narrow(start=0,
123 | length=imlength,
124 | dimension=1)
125 | if obs_length == imlength:
126 | return image_obs, None
127 |
128 | fc_obs = obs.narrow(start=imlength,
129 | length=obs.shape[1] - imlength,
130 | dimension=1)
131 | return image_obs, fc_obs
132 |
133 | def enable_render(self):
134 | self._render_local = True
135 |
136 |
137 | class ImageMujocoWithObsEnv(ImageMujocoEnv):
138 | def __init__(self, env, **kwargs):
139 | super().__init__(env, **kwargs)
140 | self.observation_space = Box(low=0.0,
141 | high=1.0,
142 | shape=(
143 | self.image_length * self.history_length +
144 | self.wrapped_env.obs_dim,))
145 |
146 | def _get_obs(self, history_flat, true_state):
147 | return np.concatenate([history_flat,
148 | true_state])
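A sketch of the resulting observation size; `base_mujoco_env` is an assumed MuJoCo environment exposing `.sim` for offscreen rendering.

    # RGB 32x32 frames with keep_prev=1 (two stacked frames):
    # flattened observation length = 3 * 32 * 32 * 2 = 6144.
    env = ImageMujocoEnv(base_mujoco_env, imsize=32, keep_prev=1)
    assert env.observation_space.low.size == 3 * 32 * 32 * 2
    obs = env.reset()   # flattened image history, zero-padded until the queue fills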
--------------------------------------------------------------------------------
/maple/envs/proxy_env.py:
--------------------------------------------------------------------------------
1 | from gym import Env
2 |
3 |
4 | class ProxyEnv(Env):
5 | def __init__(self, wrapped_env):
6 | self._wrapped_env = wrapped_env
7 | self.action_space = self._wrapped_env.action_space
8 | self.observation_space = self._wrapped_env.observation_space
9 |
10 | @property
11 | def wrapped_env(self):
12 | return self._wrapped_env
13 |
14 | def reset(self, **kwargs):
15 | return self._wrapped_env.reset(**kwargs)
16 |
17 | def step(self, action):
18 | return self._wrapped_env.step(action)
19 |
20 | def render(self, *args, **kwargs):
21 | return self._wrapped_env.render(*args, **kwargs)
22 |
23 | @property
24 | def horizon(self):
25 | return self._wrapped_env.horizon
26 |
27 | def terminate(self):
28 | if hasattr(self.wrapped_env, "terminate"):
29 | self.wrapped_env.terminate()
30 |
31 | def __getattr__(self, attr):
32 | if attr == '_wrapped_env':
33 | raise AttributeError()
34 | return getattr(self._wrapped_env, attr)
35 |
36 | def __getstate__(self):
37 | """
38 | This is useful to override in case the wrapped env has some funky
39 | __getstate__ that doesn't play well with overriding __getattr__.
40 |
41 | The main problematic case is/was gym's EzPickle serialization scheme.
42 | :return:
43 | """
44 | return self.__dict__
45 |
46 | def __setstate__(self, state):
47 | self.__dict__.update(state)
48 |
49 | def __str__(self):
50 | return '{}({})'.format(type(self).__name__, self.wrapped_env)
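Because __getattr__ forwards unknown attributes to the wrapped environment, ProxyEnv-based wrappers stay transparent; a small sketch (the gym id is illustrative).

    import gym
    from maple.envs.proxy_env import ProxyEnv

    env = ProxyEnv(gym.make('CartPole-v1'))
    obs = env.reset()                 # delegated to the wrapped env
    print(env.action_space)           # copied from the wrapped env in __init__
    print(env._max_episode_steps)     # not defined on ProxyEnv; forwarded via __getattr__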
--------------------------------------------------------------------------------
/maple/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | from gym import Env
4 | from gym.spaces import Box
5 | from gym.spaces import Discrete
6 |
7 | from collections import deque
8 |
9 |
10 | class ProxyEnv(Env):
11 | def __init__(self, wrapped_env):
12 | self._wrapped_env = wrapped_env
13 | self.action_space = self._wrapped_env.action_space
14 | self.observation_space = self._wrapped_env.observation_space
15 |
16 | @property
17 | def wrapped_env(self):
18 | return self._wrapped_env
19 |
20 | def reset(self, **kwargs):
21 | return self._wrapped_env.reset(**kwargs)
22 |
23 | def step(self, action):
24 | return self._wrapped_env.step(action)
25 |
26 | def render(self, *args, **kwargs):
27 | return self._wrapped_env.render(*args, **kwargs)
28 |
29 | @property
30 | def horizon(self):
31 | return self._wrapped_env.horizon
32 |
33 | def terminate(self):
34 | if hasattr(self.wrapped_env, "terminate"):
35 | self.wrapped_env.terminate()
36 |
37 | def __getattr__(self, attr):
38 | if attr == '_wrapped_env':
39 | raise AttributeError()
40 | return getattr(self._wrapped_env, attr)
41 |
42 | def __getstate__(self):
43 | """
44 | This is useful to override in case the wrapped env has some funky
45 | __getstate__ that doesn't play well with overriding __getattr__.
46 |
47 | The main problematic case is/was gym's EzPickle serialization scheme.
48 | :return:
49 | """
50 | return self.__dict__
51 |
52 | def __setstate__(self, state):
53 | self.__dict__.update(state)
54 |
55 | def __str__(self):
56 | return '{}({})'.format(type(self).__name__, self.wrapped_env)
57 |
58 |
59 | class HistoryEnv(ProxyEnv, Env):
60 | def __init__(self, wrapped_env, history_len):
61 | super().__init__(wrapped_env)
62 | self.history_len = history_len
63 |
64 | high = np.inf * np.ones(
65 | self.history_len * self.observation_space.low.size)
66 | low = -high
67 | self.observation_space = Box(low=low,
68 | high=high,
69 | )
70 | self.history = deque(maxlen=self.history_len)
71 |
72 | def step(self, action):
73 | state, reward, done, info = super().step(action)
74 | self.history.append(state)
75 | flattened_history = self._get_history().flatten()
76 | return flattened_history, reward, done, info
77 |
78 | def reset(self, **kwargs):
79 |         state = super().reset(**kwargs)
80 | self.history = deque(maxlen=self.history_len)
81 | self.history.append(state)
82 | flattened_history = self._get_history().flatten()
83 | return flattened_history
84 |
85 | def _get_history(self):
86 | observations = list(self.history)
87 |
88 | obs_count = len(observations)
89 | for _ in range(self.history_len - obs_count):
90 | dummy = np.zeros(self._wrapped_env.observation_space.low.size)
91 | observations.append(dummy)
92 | return np.c_[observations]
93 |
94 |
95 | class DiscretizeEnv(ProxyEnv, Env):
96 | def __init__(self, wrapped_env, num_bins):
97 | super().__init__(wrapped_env)
98 | low = self.wrapped_env.action_space.low
99 | high = self.wrapped_env.action_space.high
100 | action_ranges = [
101 | np.linspace(low[i], high[i], num_bins)
102 | for i in range(len(low))
103 | ]
104 | self.idx_to_continuous_action = [
105 | np.array(x) for x in itertools.product(*action_ranges)
106 | ]
107 | self.action_space = Discrete(len(self.idx_to_continuous_action))
108 |
109 | def step(self, action):
110 | continuous_action = self.idx_to_continuous_action[action]
111 | return super().step(continuous_action)
112 |
113 |
114 | class NormalizedBoxEnv(ProxyEnv):
115 | """
116 |     Normalize actions to be in [-1, 1].
117 |
118 | Optionally normalize observations and scale reward.
119 | """
120 |
121 | def __init__(
122 | self,
123 | env,
124 | reward_scale=1.,
125 | obs_mean=None,
126 | obs_std=None,
127 | ):
128 | ProxyEnv.__init__(self, env)
129 | self._should_normalize = not (obs_mean is None and obs_std is None)
130 | if self._should_normalize:
131 | if obs_mean is None:
132 | obs_mean = np.zeros_like(env.observation_space.low)
133 | else:
134 | obs_mean = np.array(obs_mean)
135 | if obs_std is None:
136 | obs_std = np.ones_like(env.observation_space.low)
137 | else:
138 | obs_std = np.array(obs_std)
139 | self._reward_scale = reward_scale
140 | self._obs_mean = obs_mean
141 | self._obs_std = obs_std
142 | ub = np.ones(self._wrapped_env.action_space.shape)
143 | self.action_space = Box(-1 * ub, ub)
144 |
145 | def estimate_obs_stats(self, obs_batch, override_values=False):
146 | if self._obs_mean is not None and not override_values:
147 | raise Exception("Observation mean and std already set. To "
148 | "override, set override_values to True.")
149 | self._obs_mean = np.mean(obs_batch, axis=0)
150 | self._obs_std = np.std(obs_batch, axis=0)
151 |
152 | def _apply_normalize_obs(self, obs):
153 | return (obs - self._obs_mean) / (self._obs_std + 1e-8)
154 |
155 | def step(self, action):
156 | lb = self._wrapped_env.action_space.low
157 | ub = self._wrapped_env.action_space.high
158 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
159 | scaled_action = np.clip(scaled_action, lb, ub)
160 |
161 | wrapped_step = self._wrapped_env.step(scaled_action)
162 | next_obs, reward, done, info = wrapped_step
163 | if self._should_normalize:
164 | next_obs = self._apply_normalize_obs(next_obs)
165 | return next_obs, reward * self._reward_scale, done, info
166 |
167 | def __str__(self):
168 | return "Normalized: %s" % self._wrapped_env
169 |
170 |
--------------------------------------------------------------------------------
/maple/envs/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from maple.envs.wrappers.discretize_env import DiscretizeEnv
2 | from maple.envs.wrappers.history_env import HistoryEnv
3 | from maple.envs.wrappers.image_mujoco_env import ImageMujocoEnv
4 | from maple.envs.wrappers.image_mujoco_env_with_obs import ImageMujocoWithObsEnv
5 | from maple.envs.wrappers.normalized_box_env import NormalizedBoxEnv
6 | from maple.envs.proxy_env import ProxyEnv
7 | from maple.envs.wrappers.reward_wrapper_env import RewardWrapperEnv
8 | from maple.envs.wrappers.stack_observation_env import StackObservationEnv
9 |
10 |
11 | __all__ = [
12 | 'DiscretizeEnv',
13 | 'HistoryEnv',
14 | 'ImageMujocoEnv',
15 | 'ImageMujocoWithObsEnv',
16 | 'NormalizedBoxEnv',
17 | 'ProxyEnv',
18 | 'RewardWrapperEnv',
19 | 'StackObservationEnv',
20 | ]
--------------------------------------------------------------------------------
/maple/envs/wrappers/discretize_env.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import numpy as np
4 | from gym import Env
5 | from gym.spaces import Discrete
6 |
7 | from maple.envs.proxy_env import ProxyEnv
8 |
9 |
10 | class DiscretizeEnv(ProxyEnv, Env):
11 | def __init__(self, wrapped_env, num_bins):
12 | super().__init__(wrapped_env)
13 | low = self.wrapped_env.action_space.low
14 | high = self.wrapped_env.action_space.high
15 | action_ranges = [
16 | np.linspace(low[i], high[i], num_bins)
17 | for i in range(len(low))
18 | ]
19 | self.idx_to_continuous_action = [
20 | np.array(x) for x in itertools.product(*action_ranges)
21 | ]
22 | self.action_space = Discrete(len(self.idx_to_continuous_action))
23 |
24 | def step(self, action):
25 | continuous_action = self.idx_to_continuous_action[action]
26 | return super().step(continuous_action)
27 |
28 |
29 |
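The discretization is the Cartesian product of per-dimension bins, so the number of discrete actions is num_bins ** action_dim; a short sketch with an assumed continuous-action env.

    # A 2-D continuous action space with 3 bins per dimension gives 3 ** 2 = 9 actions.
    denv = DiscretizeEnv(continuous_env, num_bins=3)   # continuous_env is assumed
    assert denv.action_space.n == 3 ** continuous_env.action_space.low.size
    obs, reward, done, info = denv.step(denv.action_space.sample())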
--------------------------------------------------------------------------------
/maple/envs/wrappers/history_env.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import numpy as np
4 | from gym import Env
5 | from gym.spaces import Box
6 |
7 | from maple.envs.proxy_env import ProxyEnv
8 |
9 |
10 | class HistoryEnv(ProxyEnv, Env):
11 | def __init__(self, wrapped_env, history_len):
12 | super().__init__(wrapped_env)
13 | self.history_len = history_len
14 |
15 | high = np.inf * np.ones(
16 | self.history_len * self.observation_space.low.size)
17 | low = -high
18 | self.observation_space = Box(low=low,
19 | high=high,
20 | )
21 | self.history = deque(maxlen=self.history_len)
22 |
23 | def step(self, action):
24 | state, reward, done, info = super().step(action)
25 | self.history.append(state)
26 | flattened_history = self._get_history().flatten()
27 | return flattened_history, reward, done, info
28 |
29 | def reset(self, **kwargs):
30 |         state = super().reset(**kwargs)
31 | self.history = deque(maxlen=self.history_len)
32 | self.history.append(state)
33 | flattened_history = self._get_history().flatten()
34 | return flattened_history
35 |
36 | def _get_history(self):
37 | observations = list(self.history)
38 |
39 | obs_count = len(observations)
40 | for _ in range(self.history_len - obs_count):
41 | dummy = np.zeros(self._wrapped_env.observation_space.low.size)
42 | observations.append(dummy)
43 | return np.c_[observations]
44 |
45 |
46 |
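HistoryEnv concatenates the last history_len observations (zero-padded until the deque fills), so the observation dimension is multiplied by history_len; a short sketch with an assumed base env.

    henv = HistoryEnv(base_env, history_len=4)   # base_env is assumed
    obs = henv.reset()
    assert obs.size == 4 * base_env.observation_space.low.size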
--------------------------------------------------------------------------------
/maple/envs/wrappers/image_mujoco_env.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import cv2
3 | import mujoco_py
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from gym import Env
8 | from gym.spaces import Box
9 |
10 | from maple.envs.proxy_env import ProxyEnv
11 |
12 |
13 | class ImageMujocoEnv(ProxyEnv, Env):
14 | def __init__(self,
15 | wrapped_env,
16 | imsize=32,
17 | keep_prev=0,
18 | init_camera=None,
19 | camera_name=None,
20 | transpose=False,
21 | grayscale=False,
22 | normalize=False,
23 | ):
24 | super().__init__(wrapped_env)
25 |
26 | self.imsize = imsize
27 | if grayscale:
28 | self.image_length = self.imsize * self.imsize
29 | else:
30 | self.image_length = 3 * self.imsize * self.imsize
31 | # This is torch format rather than PIL image
32 | self.image_shape = (self.imsize, self.imsize)
33 | # Flattened past image queue
34 | self.history_length = keep_prev + 1
35 | self.history = deque(maxlen=self.history_length)
36 | # init camera
37 | if init_camera is not None:
38 | sim = self._wrapped_env.sim
39 | viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1)
40 | init_camera(viewer.cam)
41 | sim.add_render_context(viewer)
42 | self.camera_name = camera_name # None means default camera
43 | self.transpose = transpose
44 | self.grayscale = grayscale
45 | self.normalize = normalize
46 | self._render_local = False
47 |
48 | self.observation_space = Box(low=0.0,
49 | high=1.0,
50 | shape=(
51 | self.image_length * self.history_length,))
52 |
53 | def step(self, action):
54 |         # image observations are returned as a flattened 1D array
55 | true_state, reward, done, info = super().step(action)
56 |
57 | observation = self._image_observation()
58 | self.history.append(observation)
59 | history = self._get_history().flatten()
60 | full_obs = self._get_obs(history, true_state)
61 | return full_obs, reward, done, info
62 |
63 | def reset(self, **kwargs):
64 | true_state = super().reset(**kwargs)
65 | self.history = deque(maxlen=self.history_length)
66 |
67 | observation = self._image_observation()
68 | self.history.append(observation)
69 | history = self._get_history().flatten()
70 | full_obs = self._get_obs(history, true_state)
71 | return full_obs
72 |
73 | def get_image(self):
74 | """TODO: this should probably consider history"""
75 | return self._image_observation()
76 |
77 | def _get_obs(self, history_flat, true_state):
78 |         # adds extra information from true_state into the image observation.
79 | # Used in ImageWithObsEnv.
80 | return history_flat
81 |
82 | def _image_observation(self):
83 | # returns the image as a torch format np array
84 | image_obs = self._wrapped_env.sim.render(width=self.imsize,
85 | height=self.imsize,
86 | camera_name=self.camera_name)
87 | if self._render_local:
88 | cv2.imshow('env', image_obs)
89 | cv2.waitKey(1)
90 | if self.grayscale:
91 | image_obs = Image.fromarray(image_obs).convert('L')
92 | image_obs = np.array(image_obs)
93 | if self.normalize:
94 | image_obs = image_obs / 255.0
95 | if self.transpose:
96 | image_obs = image_obs.transpose()
97 | return image_obs
98 |
99 | def _get_history(self):
100 | observations = list(self.history)
101 |
102 | obs_count = len(observations)
103 | for _ in range(self.history_length - obs_count):
104 | dummy = np.zeros(self.image_shape)
105 | observations.append(dummy)
106 | return np.c_[observations]
107 |
108 | def retrieve_images(self):
109 | # returns images in unflattened PIL format
110 | images = []
111 | for image_obs in self.history:
112 | pil_image = self.torch_to_pil(torch.Tensor(image_obs))
113 | images.append(pil_image)
114 | return images
115 |
116 | def split_obs(self, obs):
117 | # splits observation into image input and true observation input
118 | imlength = self.image_length * self.history_length
119 | obs_length = self.observation_space.low.size
120 | obs = obs.view(-1, obs_length)
121 | image_obs = obs.narrow(start=0,
122 | length=imlength,
123 | dimension=1)
124 | if obs_length == imlength:
125 | return image_obs, None
126 |
127 | fc_obs = obs.narrow(start=imlength,
128 | length=obs.shape[1] - imlength,
129 | dimension=1)
130 | return image_obs, fc_obs
131 |
132 | def enable_render(self):
133 | self._render_local = True
134 |
135 |
136 |
--------------------------------------------------------------------------------
/maple/envs/wrappers/image_mujoco_env_with_obs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from maple.envs.wrappers.image_mujoco_env import ImageMujocoEnv
5 |
6 |
7 | class ImageMujocoWithObsEnv(ImageMujocoEnv):
8 | def __init__(self, env, **kwargs):
9 | super().__init__(env, **kwargs)
10 | self.observation_space = Box(
11 | low=0.0,
12 | high=1.0,
13 | shape=(self.image_length * self.history_length
14 | + self.wrapped_env.obs_dim,))
15 |
16 | def _get_obs(self, history_flat, true_state):
17 | return np.concatenate([history_flat, true_state])
18 |
--------------------------------------------------------------------------------
/maple/envs/wrappers/normalized_box_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from maple.envs.proxy_env import ProxyEnv
5 |
6 |
7 | class NormalizedBoxEnv(ProxyEnv):
8 | """
9 |     Normalize actions to be in [-1, 1].
10 |
11 | Optionally normalize observations and scale reward.
12 | """
13 |
14 | def __init__(
15 | self,
16 | env,
17 | reward_scale=1.,
18 | obs_mean=None,
19 | obs_std=None,
20 | ):
21 | ProxyEnv.__init__(self, env)
22 | self._should_normalize = not (obs_mean is None and obs_std is None)
23 | if self._should_normalize:
24 | if obs_mean is None:
25 | obs_mean = np.zeros_like(env.observation_space.low)
26 | else:
27 | obs_mean = np.array(obs_mean)
28 | if obs_std is None:
29 | obs_std = np.ones_like(env.observation_space.low)
30 | else:
31 | obs_std = np.array(obs_std)
32 | self._reward_scale = reward_scale
33 | self._obs_mean = obs_mean
34 | self._obs_std = obs_std
35 | ub = np.ones(self._wrapped_env.action_space.shape)
36 | self.action_space = Box(-1 * ub, ub)
37 |
38 | def estimate_obs_stats(self, obs_batch, override_values=False):
39 | if self._obs_mean is not None and not override_values:
40 | raise Exception("Observation mean and std already set. To "
41 | "override, set override_values to True.")
42 | self._obs_mean = np.mean(obs_batch, axis=0)
43 | self._obs_std = np.std(obs_batch, axis=0)
44 |
45 | def _apply_normalize_obs(self, obs):
46 | return (obs - self._obs_mean) / (self._obs_std + 1e-8)
47 |
48 | def step(self, action):
49 | lb = self._wrapped_env.action_space.low
50 | ub = self._wrapped_env.action_space.high
51 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
52 | scaled_action = np.clip(scaled_action, lb, ub)
53 |
54 | wrapped_step = self._wrapped_env.step(scaled_action)
55 | next_obs, reward, done, info = wrapped_step
56 | if self._should_normalize:
57 | next_obs = self._apply_normalize_obs(next_obs)
58 | return next_obs, reward * self._reward_scale, done, info
59 |
60 | def __str__(self):
61 | return "Normalized: %s" % self._wrapped_env
62 |
63 |
64 |
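The rescaling in step() is the affine map a -> lb + (a + 1) * 0.5 * (ub - lb); a worked example with concrete bounds.

    import numpy as np

    lb, ub = np.array([0.0, -2.0]), np.array([1.0, 2.0])
    a = np.array([-1.0, 0.5])                    # policy action in [-1, 1]
    scaled = lb + (a + 1.0) * 0.5 * (ub - lb)
    # scaled == [0.0, 1.0]: -1 maps to the lower bound; 0.5 maps 75% of the
    # way from lb to ub.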
--------------------------------------------------------------------------------
/maple/envs/wrappers/reward_wrapper_env.py:
--------------------------------------------------------------------------------
1 | from maple.envs.proxy_env import ProxyEnv
2 |
3 |
4 | class RewardWrapperEnv(ProxyEnv):
5 | """Substitute a different reward function"""
6 |
7 | def __init__(
8 | self,
9 | env,
10 | compute_reward_fn,
11 | ):
12 | ProxyEnv.__init__(self, env)
13 | self.spec = env.spec # hack for hand envs
14 | self.compute_reward_fn = compute_reward_fn
15 |
16 | def step(self, action):
17 | next_obs, reward, done, info = self._wrapped_env.step(action)
18 | info["env_reward"] = reward
19 | reward = self.compute_reward_fn(next_obs, reward, done, info)
20 | return next_obs, reward, done, info
21 |
--------------------------------------------------------------------------------
/maple/envs/wrappers/stack_observation_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from maple.envs.proxy_env import ProxyEnv
5 |
6 |
7 | class StackObservationEnv(ProxyEnv):
8 | """
9 | Env wrapper for passing history of observations as the new observation
10 | """
11 |
12 | def __init__(
13 | self,
14 | env,
15 | stack_obs=1,
16 | ):
17 | ProxyEnv.__init__(self, env)
18 | self.stack_obs = stack_obs
19 | low = env.observation_space.low
20 | high = env.observation_space.high
21 | self.obs_dim = low.size
22 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim))
23 | self.observation_space = Box(
24 | low=np.repeat(low, stack_obs),
25 | high=np.repeat(high, stack_obs),
26 | )
27 |
28 | def reset(self):
29 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim))
30 | next_obs = self._wrapped_env.reset()
31 | self._last_obs[-1, :] = next_obs
32 | return self._last_obs.copy().flatten()
33 |
34 | def step(self, action):
35 | next_obs, reward, done, info = self._wrapped_env.step(action)
36 | self._last_obs = np.vstack((
37 | self._last_obs[1:, :],
38 | next_obs
39 | ))
40 | return self._last_obs.copy().flatten(), reward, done, info
41 |
42 |
43 |
--------------------------------------------------------------------------------
/maple/exploration_strategies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/exploration_strategies/__init__.py
--------------------------------------------------------------------------------
/maple/exploration_strategies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from maple.policies.base import ExplorationPolicy
4 |
5 |
6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta):
7 | @abc.abstractmethod
8 | def get_action(self, t, observation, policy, **kwargs):
9 | pass
10 |
11 | def reset(self):
12 | pass
13 |
14 |
15 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta):
16 | @abc.abstractmethod
17 | def get_action_from_raw_action(self, action, **kwargs):
18 | pass
19 |
20 | def get_action(self, t, policy, *args, **kwargs):
21 | action, agent_info = policy.get_action(*args, **kwargs)
22 | return self.get_action_from_raw_action(action, t=t), agent_info
23 |
24 | def reset(self):
25 | pass
26 |
27 |
28 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy):
29 | def __init__(
30 | self,
31 | exploration_strategy: ExplorationStrategy,
32 | policy,
33 | ):
34 | self.es = exploration_strategy
35 | self.policy = policy
36 | self.t = 0
37 |
38 | def set_num_steps_total(self, t):
39 | self.t = t
40 |
41 | def get_action(self, *args, **kwargs):
42 | return self.es.get_action(self.t, self.policy, *args, **kwargs)
43 |
44 | def reset(self):
45 | self.es.reset()
46 | self.policy.reset()
47 |
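A sketch of wrapping a deterministic policy with an exploration strategy; `env`, `deterministic_policy`, `obs`, and `t` are assumed to come from the surrounding training loop.

    from maple.exploration_strategies.base import PolicyWrappedWithExplorationStrategy
    from maple.exploration_strategies.gaussian_strategy import GaussianStrategy

    es = GaussianStrategy(env.action_space, max_sigma=0.3, min_sigma=0.05)
    expl_policy = PolicyWrappedWithExplorationStrategy(es, deterministic_policy)

    expl_policy.set_num_steps_total(t)      # t drives the noise decay schedule
    action, agent_info = expl_policy.get_action(obs)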
--------------------------------------------------------------------------------
/maple/exploration_strategies/epsilon_greedy.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from maple.exploration_strategies.base import RawExplorationStrategy
4 |
5 |
6 | class EpsilonGreedy(RawExplorationStrategy):
7 | """
8 | Take a random discrete action with some probability.
9 | """
10 | def __init__(self, action_space, prob_random_action=0.1):
11 | self.prob_random_action = prob_random_action
12 | self.action_space = action_space
13 |
14 | def get_action_from_raw_action(self, action, **kwargs):
15 | if random.random() <= self.prob_random_action:
16 | return self.action_space.sample()
17 | return action
18 |
--------------------------------------------------------------------------------
/maple/exploration_strategies/gaussian_and_epsilon_strategy.py:
--------------------------------------------------------------------------------
1 | import random
2 | from maple.exploration_strategies.base import RawExplorationStrategy
3 | import numpy as np
4 |
5 |
6 | class GaussianAndEpsilonStrategy(RawExplorationStrategy):
7 | """
8 | With probability epsilon, take a completely random action.
9 |     With probability 1-epsilon, add Gaussian noise to the action taken by a
10 | deterministic policy.
11 | """
12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None,
13 | decay_period=1000000):
14 | assert len(action_space.shape) == 1
15 | if min_sigma is None:
16 | min_sigma = max_sigma
17 | self._max_sigma = max_sigma
18 | self._epsilon = epsilon
19 | self._min_sigma = min_sigma
20 | self._decay_period = decay_period
21 | self._action_space = action_space
22 |
23 | def get_action_from_raw_action(self, action, t=None, **kwargs):
24 | if random.random() < self._epsilon:
25 | return self._action_space.sample()
26 | else:
27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period)
28 | return np.clip(
29 | action + np.random.normal(size=len(action)) * sigma,
30 | self._action_space.low,
31 | self._action_space.high,
32 | )
33 |
--------------------------------------------------------------------------------
/maple/exploration_strategies/gaussian_strategy.py:
--------------------------------------------------------------------------------
1 | from maple.exploration_strategies.base import RawExplorationStrategy
2 | import numpy as np
3 |
4 |
5 | class GaussianStrategy(RawExplorationStrategy):
6 | """
7 | This strategy adds Gaussian noise to the action taken by the deterministic policy.
8 |
9 | Based on the rllab implementation.
10 | """
11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None,
12 | decay_period=1000000):
13 | assert len(action_space.shape) == 1
14 | self._max_sigma = max_sigma
15 | if min_sigma is None:
16 | min_sigma = max_sigma
17 | self._min_sigma = min_sigma
18 | self._decay_period = decay_period
19 | self._action_space = action_space
20 |
21 | def get_action_from_raw_action(self, action, t=None, **kwargs):
22 | sigma = (
23 | self._max_sigma - (self._max_sigma - self._min_sigma) *
24 | min(1.0, t * 1.0 / self._decay_period)
25 | )
26 | return np.clip(
27 | action + np.random.normal(size=len(action)) * sigma,
28 | self._action_space.low,
29 | self._action_space.high,
30 | )
31 |
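The noise scale interpolates linearly from max_sigma down to min_sigma over decay_period steps; a quick numeric check of the formula above.

    max_sigma, min_sigma, decay_period = 0.5, 0.1, 1000
    for t in (0, 500, 1000, 5000):
        sigma = max_sigma - (max_sigma - min_sigma) * min(1.0, t / decay_period)
        # t=0 -> 0.5, t=500 -> 0.3, t>=1000 -> 0.1 (clamped by the min())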
--------------------------------------------------------------------------------
/maple/exploration_strategies/ou_strategy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.random as nr
3 |
4 | from maple.exploration_strategies.base import RawExplorationStrategy
5 |
6 |
7 | class OUStrategy(RawExplorationStrategy):
8 | """
9 | This strategy implements the Ornstein-Uhlenbeck process, which adds
10 | time-correlated noise to the actions taken by the deterministic policy.
11 | The OU process satisfies the following stochastic differential equation:
12 | dxt = theta*(mu - xt)*dt + sigma*dWt
13 | where Wt denotes the Wiener process
14 |
15 | Based on the rllab implementation.
16 | """
17 |
18 | def __init__(
19 | self,
20 | action_space,
21 | mu=0,
22 | theta=0.15,
23 | max_sigma=0.3,
24 | min_sigma=None,
25 | decay_period=100000,
26 | ):
27 | if min_sigma is None:
28 | min_sigma = max_sigma
29 | self.mu = mu
30 | self.theta = theta
31 | self.sigma = max_sigma
32 | self._max_sigma = max_sigma
33 | if min_sigma is None:
34 | min_sigma = max_sigma
35 | self._min_sigma = min_sigma
36 | self._decay_period = decay_period
37 | self.dim = np.prod(action_space.low.shape)
38 | self.low = action_space.low
39 | self.high = action_space.high
40 | self.state = np.ones(self.dim) * self.mu
41 | self.reset()
42 |
43 | def reset(self):
44 | self.state = np.ones(self.dim) * self.mu
45 |
46 | def evolve_state(self):
47 | x = self.state
48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x))
49 | self.state = x + dx
50 | return self.state
51 |
52 | def get_action_from_raw_action(self, action, t=0, **kwargs):
53 | ou_state = self.evolve_state()
54 | self.sigma = (
55 | self._max_sigma
56 | - (self._max_sigma - self._min_sigma)
57 | * min(1.0, t * 1.0 / self._decay_period)
58 | )
59 | return np.clip(action + ou_state, self.low, self.high)
60 |
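The update in evolve_state() is the Euler discretization x_{t+1} = x_t + theta * (mu - x_t) + sigma * N(0, 1); a tiny sketch of the time-correlated noise it produces (the action space is illustrative).

    import numpy as np
    from gym.spaces import Box

    space = Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    ou = OUStrategy(space, theta=0.15, max_sigma=0.3)

    # Successive samples are correlated: each step nudges the previous noise toward mu.
    noisy_actions = [ou.get_action_from_raw_action(np.zeros(2), t=i) for i in range(5)]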
--------------------------------------------------------------------------------
/maple/launchers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains 'launchers', which are self-contained functions that take
3 | one dictionary and run a full experiment. The dictionary configures the
4 | experiment.
5 |
6 | It is important that the functions are completely self-contained (i.e. they
7 | import their own modules) so that they can be serialized.
8 | """
9 |
--------------------------------------------------------------------------------
/maple/launchers/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | Copy this file to conf_private.py and modify as needed.
3 | """
4 | import os
5 | from os.path import join
6 | import maple
7 | import robosuite
8 |
9 | """
10 | `doodad.mount.MountLocal` by default ignores directories called "data"
11 | If you're going to rename this directory and use EC2, then change
12 | `doodad.mount.MountLocal.filter_dir`
13 | """
14 | # The directory of the project, not source
15 | maple_project_dir = join(os.path.dirname(maple.__file__), os.pardir)
16 | robosuite_project_dir = join(os.path.dirname(robosuite.__file__), os.pardir)
17 | LOCAL_LOG_DIR = join(maple_project_dir, 'data')
18 |
19 | """
20 | ********************************************************************************
21 | ********************************************************************************
22 | ********************************************************************************
23 |
24 | You probably don't need to set all of the configurations below this line,
25 | unless you use AWS, GCP, Slurm, and/or Slurm on a remote server. I recommend
26 | ignoring most of these things and only using them on an as-needed basis.
27 |
28 | ********************************************************************************
29 | ********************************************************************************
30 | ********************************************************************************
31 | """
32 |
33 | """
34 | General doodad settings.
35 | """
36 | CODE_DIRS_TO_MOUNT = [
37 | maple_project_dir,
38 | robosuite_project_dir,
39 | # '/home/user/python/module/one', Add more paths as needed
40 | ]
41 |
42 | HOME = os.getenv('HOME') if os.getenv('HOME') is not None else os.getenv("USERPROFILE")
43 |
44 | DIR_AND_MOUNT_POINT_MAPPINGS = [
45 | dict(
46 | local_dir=join(HOME, '.mujoco/'),
47 | mount_point='/root/.mujoco',
48 | ),
49 | ]
50 | RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = (
51 | join(maple_project_dir, 'scripts', 'run_experiment_from_doodad.py')
52 | # '/home/user/path/to/maple/scripts/run_experiment_from_doodad.py'
53 | )
54 | """
55 | AWS Settings
56 | """
57 | # If not set, default will be chosen by doodad
58 | # AWS_S3_PATH = 's3://bucket/directory'
59 |
60 | # The docker image is looked up on dockerhub.com.
61 | DOODAD_DOCKER_IMAGE = "TODO"
62 | INSTANCE_TYPE = 'c4.large'
63 | SPOT_PRICE = 0.03
64 |
65 | GPU_DOODAD_DOCKER_IMAGE = "TODO"
66 | GPU_INSTANCE_TYPE = 'g2.2xlarge'
67 | GPU_SPOT_PRICE = 0.5
68 |
69 | # You can use AMI images with the docker images already installed.
70 | REGION_TO_GPU_AWS_IMAGE_ID = {
71 | 'us-west-1': "TODO",
72 | 'us-east-1': "TODO",
73 | }
74 |
75 | REGION_TO_GPU_AWS_AVAIL_ZONE = {
76 | 'us-east-1': "us-east-1b",
77 | }
78 |
79 | # This really shouldn't matter and in theory could be whatever
80 | OUTPUT_DIR_FOR_DOODAD_TARGET = '/tmp/doodad-output/'
81 |
82 |
83 | """
84 | Slurm Settings
85 | """
86 | SINGULARITY_IMAGE = '/home/PATH/TO/IMAGE.img'
87 | # This assumes you saved mujoco to $HOME/.mujoco
88 | SINGULARITY_PRE_CMDS = [
89 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mjpro150/bin'
90 | ]
91 | SLURM_CPU_CONFIG = dict(
92 | account_name='TODO',
93 | partition='savio',
94 | nodes=1,
95 | n_tasks=1,
96 | n_gpus=1,
97 | )
98 | SLURM_GPU_CONFIG = dict(
99 | account_name='TODO',
100 | partition='savio2_1080ti',
101 | nodes=1,
102 | n_tasks=1,
103 | n_gpus=1,
104 | )
105 |
106 |
107 | """
108 | Slurm Script Settings
109 |
110 | These are basically the same settings as above, but for the remote machine
111 | where you will be running the generated script.
112 | """
113 | SSS_CODE_DIRS_TO_MOUNT = [
114 | ]
115 | SSS_DIR_AND_MOUNT_POINT_MAPPINGS = [
116 | dict(
117 | local_dir='/global/home/users/USERNAME/.mujoco',
118 | mount_point='/root/.mujoco',
119 | ),
120 | ]
121 | SSS_LOG_DIR = '/global/scratch/USERNAME/doodad-log'
122 |
123 | SSS_IMAGE = '/global/scratch/USERNAME/TODO.img'
124 | SSS_RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = (
125 | '/global/home/users/USERNAME/path/to/maple/scripts'
126 | '/run_experiment_from_doodad.py'
127 | )
128 | SSS_PRE_CMDS = [
129 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/global/home/users/USERNAME'
130 | '/.mujoco/mjpro150/bin'
131 | ]
132 |
133 | """
134 | GCP Settings
135 | """
136 | GCP_IMAGE_NAME = 'TODO'
137 | GCP_GPU_IMAGE_NAME = 'TODO'
138 | GCP_BUCKET_NAME = 'TODO'
139 |
140 | GCP_DEFAULT_KWARGS = dict(
141 | zone='us-west2-c',
142 | instance_type='n1-standard-4',
143 | image_project='TODO',
144 | terminate=True,
145 | preemptible=True,
146 | gpu_kwargs=dict(
147 | gpu_model='nvidia-tesla-p4',
148 | num_gpu=1,
149 | )
150 | )
151 |
152 | try:
153 | from maple.launchers.conf_private import *
154 | except ImportError:
155 | print("No personal conf_private.py found.")
156 |
--------------------------------------------------------------------------------
/maple/policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/policies/__init__.py
--------------------------------------------------------------------------------
/maple/policies/argmax.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch argmax policy
3 | """
4 | import numpy as np
5 | from torch import nn
6 |
7 | import maple.torch.pytorch_util as ptu
8 | from maple.policies.base import Policy
9 |
10 |
11 | class ArgmaxDiscretePolicy(nn.Module, Policy):
12 | def __init__(self, qf):
13 | super().__init__()
14 | self.qf = qf
15 |
16 | def get_action(self, obs):
17 | obs = np.expand_dims(obs, axis=0)
18 | obs = ptu.from_numpy(obs).float()
19 | q_values = self.qf(obs).squeeze(0)
20 | q_values_np = ptu.get_numpy(q_values)
21 | return q_values_np.argmax(), {}
22 |
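A sketch of ArgmaxDiscretePolicy with a small Q-network; the network below is illustrative and only needs to map observations to per-action Q-values (it runs on CPU under the default pytorch_util device).

    import numpy as np
    from torch import nn

    qf = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 3))   # 4-dim obs, 3 actions
    policy = ArgmaxDiscretePolicy(qf)

    obs = np.zeros(4, dtype=np.float32)
    action, _ = policy.get_action(obs)   # index of the highest Q-value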
--------------------------------------------------------------------------------
/maple/policies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Policy(object, metaclass=abc.ABCMeta):
5 | """
6 | General policy interface.
7 | """
8 | @abc.abstractmethod
9 | def get_action(self, observation):
10 | """
11 |
12 | :param observation:
13 | :return: action, debug_dictionary
14 | """
15 | pass
16 |
17 | def reset(self):
18 | pass
19 |
20 |
21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta):
22 | def set_num_steps_total(self, t):
23 | pass
24 |
--------------------------------------------------------------------------------
/maple/policies/simple.py:
--------------------------------------------------------------------------------
1 | from maple.policies.base import Policy
2 |
3 |
4 | class RandomPolicy(Policy):
5 | """
6 |     Policy that samples a uniformly random action from the action space.
7 | """
8 |
9 | def __init__(self, action_space):
10 | self.action_space = action_space
11 |
12 | def get_action(self, obs):
13 | return self.action_space.sample(), {}
14 |
--------------------------------------------------------------------------------
/maple/samplers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/samplers/__init__.py
--------------------------------------------------------------------------------
/maple/samplers/data_collector/__init__.py:
--------------------------------------------------------------------------------
1 | from maple.samplers.data_collector.base import (
2 | DataCollector,
3 | PathCollector,
4 | StepCollector,
5 | )
6 | from maple.samplers.data_collector.path_collector import (
7 | MdpPathCollector,
8 | ObsDictPathCollector,
9 | GoalConditionedPathCollector,
10 | VAEWrappedEnvPathCollector,
11 | )
12 | from maple.samplers.data_collector.step_collector import (
13 | GoalConditionedStepCollector
14 | )
15 |
--------------------------------------------------------------------------------
/maple/samplers/data_collector/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class DataCollector(object, metaclass=abc.ABCMeta):
5 | def end_epoch(self, epoch):
6 | pass
7 |
8 | def get_diagnostics(self):
9 | return {}
10 |
11 | def get_snapshot(self):
12 | return {}
13 |
14 | @abc.abstractmethod
15 | def get_epoch_paths(self):
16 | pass
17 |
18 |
19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta):
20 | @abc.abstractmethod
21 | def collect_new_paths(
22 | self,
23 | max_path_length,
24 | num_steps,
25 | discard_incomplete_paths,
26 | ):
27 | pass
28 |
29 |
30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta):
31 | @abc.abstractmethod
32 | def collect_new_steps(
33 | self,
34 | max_path_length,
35 | num_steps,
36 | discard_incomplete_paths,
37 | ):
38 | pass
39 |
--------------------------------------------------------------------------------
/maple/samplers/data_collector/contextual_path_collector.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from maple.envs.contextual import ContextualEnv
4 | from maple.policies.base import Policy
5 | from maple.samplers.data_collector import MdpPathCollector
6 | from maple.samplers.rollout_functions import contextual_rollout
7 |
8 |
9 | class ContextualPathCollector(MdpPathCollector):
10 | def __init__(
11 | self,
12 | env: ContextualEnv,
13 | policy: Policy,
14 | max_num_epoch_paths_saved=None,
15 | observation_key='observation',
16 | context_keys_for_policy='context',
17 | render=False,
18 | render_kwargs=None,
19 | **kwargs
20 | ):
21 | rollout_fn = partial(
22 | contextual_rollout,
23 | context_keys_for_policy=context_keys_for_policy,
24 | observation_key=observation_key,
25 | )
26 | super().__init__(
27 | env, policy, max_num_epoch_paths_saved, render, render_kwargs,
28 | rollout_fn=rollout_fn,
29 | **kwargs
30 | )
31 | self._observation_key = observation_key
32 | self._context_keys_for_policy = context_keys_for_policy
33 |
34 | def get_snapshot(self):
35 | snapshot = super().get_snapshot()
36 | snapshot.update(
37 | observation_key=self._observation_key,
38 | context_keys_for_policy=self._context_keys_for_policy,
39 | )
40 | return snapshot
41 |
--------------------------------------------------------------------------------
/maple/samplers/data_collector/joint_path_collector.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Dict
3 |
4 | from maple.core.logging import add_prefix
5 | from maple.samplers.data_collector import PathCollector
6 |
7 |
8 | class JointPathCollector(PathCollector):
9 | def __init__(self, path_collectors: Dict[str, PathCollector]):
10 | self.path_collectors = path_collectors
11 |
12 | def collect_new_paths(self, max_path_length, num_steps,
13 | discard_incomplete_paths):
14 | paths = []
15 | for collector in self.path_collectors.values():
16 |             paths += collector.collect_new_paths(
17 | max_path_length, num_steps, discard_incomplete_paths
18 | )
19 | return paths
20 |
21 | def end_epoch(self, epoch):
22 | for collector in self.path_collectors.values():
23 | collector.end_epoch(epoch)
24 |
25 | def get_diagnostics(self):
26 | diagnostics = OrderedDict()
27 | for name, collector in self.path_collectors.items():
28 | diagnostics.update(
29 | add_prefix(collector.get_diagnostics(), name, divider='/'),
30 | )
31 | return diagnostics
32 |
33 | def get_snapshot(self):
34 | snapshot = {}
35 | for name, collector in self.path_collectors.items():
36 | snapshot.update(
37 | add_prefix(collector.get_snapshot(), name, divider='/'),
38 | )
39 | return snapshot
40 |
41 | def get_epoch_paths(self):
42 | paths = {}
43 | for name, collector in self.path_collectors.items():
44 | paths[name] = collector.get_epoch_paths()
45 | return paths
46 |
47 |
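A sketch of grouping an exploration and an evaluation collector; the individual MdpPathCollector instances are assumed to be constructed as in path_collector.py.

    joint = JointPathCollector({
        'exploration': expl_path_collector,
        'evaluation': eval_path_collector,
    })
    joint.collect_new_paths(max_path_length=100, num_steps=1000,
                            discard_incomplete_paths=False)
    diag = joint.get_diagnostics()   # keys prefixed per collector, e.g. 'exploration/...'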
--------------------------------------------------------------------------------
/maple/samplers/data_collector/path_collector.py:
--------------------------------------------------------------------------------
1 | from collections import deque, OrderedDict
2 | from functools import partial
3 |
4 | import numpy as np
5 |
6 | from maple.core.eval_util import create_stats_ordered_dict
7 | from maple.samplers.data_collector.base import PathCollector
8 | from maple.samplers.rollout_functions import rollout
9 |
10 |
11 | class MdpPathCollector(PathCollector):
12 | def __init__(
13 | self,
14 | env,
15 | policy,
16 | max_num_epoch_paths_saved=None,
17 | render=False,
18 | render_kwargs=None,
19 | rollout_fn=rollout,
20 | save_env_in_snapshot=True,
21 | rollout_fn_kwargs=None,
22 | ):
23 | if render_kwargs is None:
24 | render_kwargs = {}
25 | self._env = env
26 | self._policy = policy
27 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
28 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
29 | self._render = render
30 | self._render_kwargs = render_kwargs
31 | self._rollout_fn = rollout_fn
32 | if rollout_fn_kwargs is None:
33 | rollout_fn_kwargs = {}
34 | self._rollout_fn_kwargs = rollout_fn_kwargs
35 |
36 | self._num_steps_total = 0
37 | self._num_paths_total = 0
38 |
39 | self._num_actions_total = 0
40 |
41 | self._save_env_in_snapshot = save_env_in_snapshot
42 |
43 | def collect_new_paths(
44 | self,
45 | max_path_length,
46 | num_steps,
47 | discard_incomplete_paths,
48 | ):
49 | paths = []
50 | num_steps_collected = 0
51 | num_actions_collected = 0
52 | while num_steps_collected < num_steps:
53 | max_path_length_this_loop = min( # Do not go over num_steps
54 | max_path_length,
55 | num_steps - num_steps_collected,
56 | )
57 | if discard_incomplete_paths and (max_path_length_this_loop < max_path_length):
58 | break
59 | path = self._rollout_fn(
60 | self._env,
61 | self._policy,
62 | max_path_length=max_path_length_this_loop,
63 | render=self._render,
64 | render_kwargs=self._render_kwargs,
65 | **self._rollout_fn_kwargs
66 | )
67 | num_steps_collected += path['path_length']
68 | num_actions_collected += path['path_length_actions']
69 | paths.append(path)
70 | self._num_paths_total += len(paths)
71 | self._num_steps_total += num_steps_collected
72 | self._num_actions_total += num_actions_collected
73 | self._epoch_paths.extend(paths)
74 | return paths
75 |
76 | def get_epoch_paths(self):
77 | return self._epoch_paths
78 |
79 | def end_epoch(self, epoch):
80 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
81 |
82 | def get_diagnostics(self):
83 | # path_lens = [len(path['actions']) for path in self._epoch_paths]
84 | path_lens = [path['path_length'] for path in self._epoch_paths]
85 | stats = OrderedDict([
86 | ('num steps total', self._num_steps_total),
87 | ('num paths total', self._num_paths_total),
88 | ('num actions total', self._num_actions_total),
89 | ])
90 | stats.update(create_stats_ordered_dict(
91 | "path length",
92 | path_lens,
93 | always_show_all_stats=True,
94 | ))
95 | return stats
96 |
97 | def get_snapshot(self):
98 | snapshot_dict = dict(
99 | policy=self._policy,
100 | )
101 | if self._save_env_in_snapshot:
102 | snapshot_dict['env'] = self._env
103 | return snapshot_dict
104 |
105 |
106 | class GoalConditionedPathCollector(MdpPathCollector):
107 | def __init__(
108 | self,
109 | *args,
110 | observation_key='observation',
111 | desired_goal_key='desired_goal',
112 | goal_sampling_mode=None,
113 | **kwargs
114 | ):
115 | def obs_processor(o):
116 | return np.hstack((o[observation_key], o[desired_goal_key]))
117 |
118 | rollout_fn = partial(
119 | rollout,
120 | preprocess_obs_for_policy_fn=obs_processor,
121 | )
122 | super().__init__(*args, rollout_fn=rollout_fn, **kwargs)
123 | self._observation_key = observation_key
124 | self._desired_goal_key = desired_goal_key
125 | self._goal_sampling_mode = goal_sampling_mode
126 |
127 | def collect_new_paths(self, *args, **kwargs):
128 | self._env.goal_sampling_mode = self._goal_sampling_mode
129 | return super().collect_new_paths(*args, **kwargs)
130 |
131 | def get_snapshot(self):
132 | snapshot = super().get_snapshot()
133 | snapshot.update(
134 | observation_key=self._observation_key,
135 | desired_goal_key=self._desired_goal_key,
136 | )
137 | return snapshot
138 |
139 |
140 | class ObsDictPathCollector(MdpPathCollector):
141 | def __init__(
142 | self,
143 | *args,
144 | observation_key='observation',
145 | **kwargs
146 | ):
147 | def obs_processor(obs):
148 | return obs[observation_key]
149 |
150 | rollout_fn = partial(
151 | rollout,
152 | preprocess_obs_for_policy_fn=obs_processor,
153 | )
154 | super().__init__(*args, rollout_fn=rollout_fn, **kwargs)
155 | self._observation_key = observation_key
156 |
157 | def get_snapshot(self):
158 | snapshot = super().get_snapshot()
159 | snapshot.update(
160 | observation_key=self._observation_key,
161 | )
162 | return snapshot
163 |
164 |
165 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector):
166 | def __init__(
167 | self,
168 | env,
169 | policy,
170 | decode_goals=False,
171 | **kwargs
172 | ):
173 | """Expects env is VAEWrappedEnv"""
174 | super().__init__(env, policy, **kwargs)
175 | self._decode_goals = decode_goals
176 |
177 | def collect_new_paths(self, *args, **kwargs):
178 | self._env.decode_goals = self._decode_goals
179 | return super().collect_new_paths(*args, **kwargs)
180 |
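A sketch of collecting data with MdpPathCollector; `env` and `policy` are assumed, and the default rollout function is expected to report 'path_length' and 'path_length_actions' for each path (used by the bookkeeping above).

    from maple.samplers.data_collector.path_collector import MdpPathCollector

    collector = MdpPathCollector(env, policy, max_num_epoch_paths_saved=50)
    paths = collector.collect_new_paths(
        max_path_length=100,
        num_steps=1000,
        discard_incomplete_paths=False,
    )
    stats = collector.get_diagnostics()   # 'num steps total', 'path length Mean', ...
    collector.end_epoch(0)                # clears the per-epoch path buffer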
--------------------------------------------------------------------------------
/maple/samplers/data_collector/vae_env.py:
--------------------------------------------------------------------------------
1 | from maple.envs.vae_wrapper import VAEWrappedEnv
2 | from maple.samplers.data_collector import GoalConditionedPathCollector
3 |
4 |
5 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector):
6 | def __init__(
7 | self,
8 | goal_sampling_mode,
9 | env: VAEWrappedEnv,
10 | policy,
11 | decode_goals=False,
12 | **kwargs
13 | ):
14 | super().__init__(env, policy, **kwargs)
15 | self._goal_sampling_mode = goal_sampling_mode
16 | self._decode_goals = decode_goals
17 |
18 | def collect_new_paths(self, *args, **kwargs):
19 | self._env.goal_sampling_mode = self._goal_sampling_mode
20 | self._env.decode_goals = self._decode_goals
21 | return super().collect_new_paths(*args, **kwargs)
--------------------------------------------------------------------------------
/maple/samplers/in_place.py:
--------------------------------------------------------------------------------
1 | from maple.samplers.util import rollout
2 |
3 |
4 | class InPlacePathSampler(object):
5 | """
6 |     A sampler that does no serialization for sampling. Instead, it just uses
7 | the current policy and environment as-is.
8 |
9 | WARNING: This will affect the environment! So
10 | ```
11 | sampler = InPlacePathSampler(env, ...)
12 | sampler.obtain_samples # this has side-effects: env will change!
13 | ```
14 | """
15 | def __init__(self, env, policy, max_samples, max_path_length, render=False):
16 | self.env = env
17 | self.policy = policy
18 | self.max_path_length = max_path_length
19 | self.max_samples = max_samples
20 | self.render = render
21 | assert max_samples >= max_path_length, "Need max_samples >= max_path_length"
22 |
23 | def start_worker(self):
24 | pass
25 |
26 | def shutdown_worker(self):
27 | pass
28 |
29 | def obtain_samples(self):
30 | paths = []
31 | n_steps_total = 0
32 | while n_steps_total + self.max_path_length <= self.max_samples:
33 | path = rollout(
34 | self.env, self.policy, max_path_length=self.max_path_length,
35 | animated=self.render
36 | )
37 | paths.append(path)
38 | n_steps_total += len(path['observations'])
39 | return paths
40 |
--------------------------------------------------------------------------------
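
A minimal usage sketch for InPlacePathSampler (the `env` and `policy` below are stand-ins for objects created elsewhere, not names defined in this file); as the class docstring warns, sampling advances the environment in place:

```python
# Hypothetical usage; env is any gym-style env with reset()/step(), and policy
# is anything exposing get_action(obs) -> (action, agent_info).
from maple.samplers.in_place import InPlacePathSampler

sampler = InPlacePathSampler(
    env=env,
    policy=policy,
    max_samples=1000,      # total env steps to collect per call
    max_path_length=100,   # cap on each individual rollout
)
paths = sampler.obtain_samples()   # list of path dicts; env state is mutated
total_steps = sum(len(p['observations']) for p in paths)
```
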
/maple/samplers/rollout_functions.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import numpy as np
4 | import copy
5 |
6 | create_rollout_function = partial
7 |
8 |
9 | def multitask_rollout(
10 | env,
11 | agent,
12 | max_path_length=np.inf,
13 | render=False,
14 | render_kwargs=None,
15 | observation_key=None,
16 | desired_goal_key=None,
17 | get_action_kwargs=None,
18 | return_dict_obs=False,
19 | full_o_postprocess_func=None,
20 | ):
21 | if full_o_postprocess_func:
22 | def wrapped_fun(env, agent, o):
23 | full_o_postprocess_func(env, agent, observation_key, o)
24 | else:
25 | wrapped_fun = None
26 |
27 | def obs_processor(o):
28 | return np.hstack((o[observation_key], o[desired_goal_key]))
29 |
30 | paths = rollout(
31 | env,
32 | agent,
33 | max_path_length=max_path_length,
34 | render=render,
35 | render_kwargs=render_kwargs,
36 | get_action_kwargs=get_action_kwargs,
37 | preprocess_obs_for_policy_fn=obs_processor,
38 | full_o_postprocess_func=wrapped_fun,
39 | )
40 | if not return_dict_obs:
41 | paths['observations'] = paths['observations'][observation_key]
42 | return paths
43 |
44 |
45 | def contextual_rollout(
46 | env,
47 | agent,
48 | observation_key=None,
49 | context_keys_for_policy=None,
50 | obs_processor=None,
51 | **kwargs
52 | ):
53 | if context_keys_for_policy is None:
54 | context_keys_for_policy = ['context']
55 |
56 | if not obs_processor:
57 | def obs_processor(o):
58 | combined_obs = [o[observation_key]]
59 | for k in context_keys_for_policy:
60 | combined_obs.append(o[k])
61 | return np.concatenate(combined_obs, axis=0)
62 | paths = rollout(
63 | env,
64 | agent,
65 | preprocess_obs_for_policy_fn=obs_processor,
66 | **kwargs
67 | )
68 | return paths
69 |
70 |
71 | def rollout(
72 | env,
73 | agent,
74 | max_path_length=np.inf,
75 | render=False,
76 | render_kwargs=None,
77 | preprocess_obs_for_policy_fn=None,
78 | get_action_kwargs=None,
79 | return_dict_obs=False,
80 | full_o_postprocess_func=None,
81 | reset_callback=None,
82 | addl_info_func=None,
83 | image_obs_in_info=False,
84 | last_step_is_terminal=False,
85 | terminals_all_false=False,
86 | ):
87 | if render_kwargs is None:
88 | render_kwargs = {}
89 | if get_action_kwargs is None:
90 | get_action_kwargs = {}
91 | if preprocess_obs_for_policy_fn is None:
92 | preprocess_obs_for_policy_fn = lambda env, agent, o: o
93 | raw_obs = []
94 | raw_next_obs = []
95 | observations = []
96 | actions = []
97 | rewards = []
98 | terminals = []
99 | agent_infos = []
100 | env_infos = []
101 | next_observations = []
102 | addl_infos = []
103 | path_length = 0
104 | agent.reset()
105 | o = env.reset()
106 | if reset_callback:
107 | reset_callback(env, agent, o)
108 | if render:
109 | env.render(**render_kwargs)
110 | while path_length < max_path_length:
111 | raw_obs.append(o)
112 | o_for_agent = preprocess_obs_for_policy_fn(env, agent, o)
113 | a, agent_info = agent.get_action(o_for_agent, **get_action_kwargs)
114 |
115 | if full_o_postprocess_func:
116 | full_o_postprocess_func(env, agent, o)
117 |
118 | if addl_info_func:
119 | addl_infos.append(addl_info_func(env, agent, o, a))
120 |
121 | next_o, r, d, env_info = env.step(copy.deepcopy(a), image_obs_in_info=image_obs_in_info)
122 |
123 | new_path_length = path_length + env_info.get('num_ac_calls', 1)
124 |
125 | if new_path_length > max_path_length:
126 | break
127 | path_length = new_path_length
128 |
129 | if render:
130 | env.render(**render_kwargs)
131 | observations.append(o)
132 | rewards.append(r)
133 | if terminals_all_false:
134 | terminals.append(False)
135 | else:
136 | terminals.append(d)
137 | actions.append(a)
138 | next_observations.append(next_o)
139 | raw_next_obs.append(next_o)
140 | agent_infos.append(agent_info)
141 | env_infos.append(env_info)
142 | if d:
143 | break
144 | o = next_o
145 | actions = np.array(actions)
146 | if len(actions.shape) == 1:
147 | actions = np.expand_dims(actions, 1)
148 | observations = np.array(observations)
149 | next_observations = np.array(next_observations)
150 | if return_dict_obs:
151 | observations = raw_obs
152 | next_observations = raw_next_obs
153 | rewards = np.array(rewards)
154 | if len(rewards.shape) == 1:
155 | rewards = rewards.reshape(-1, 1)
156 |
157 | path_length_actions = np.sum(
158 | [info.get('num_ac_calls', 1) for info in env_infos]
159 | )
160 |
161 | reward_actions_sum = np.sum(
162 | [info.get('reward_actions', 0) for info in env_infos]
163 | )
164 |
165 | if last_step_is_terminal:
166 | terminals[-1] = True
167 |
168 | skill_names = []
169 | sc = env.env.skill_controller
170 | for i in range(len(actions)):
171 | ac = actions[i]
172 | skill_name = sc.get_skill_name_from_action(ac)
173 | skill_names.append(skill_name)
174 | success = env_infos[i].get('success', False)
175 | if success:
176 | break
177 |
178 | return dict(
179 | observations=observations,
180 | actions=actions,
181 | rewards=rewards,
182 | next_observations=next_observations,
183 | terminals=np.array(terminals).reshape(-1, 1),
184 | agent_infos=agent_infos,
185 | env_infos=env_infos,
186 | addl_infos=addl_infos,
187 | full_observations=raw_obs,
188 |         full_next_observations=raw_next_obs,
189 | path_length=path_length,
190 | path_length_actions=path_length_actions,
191 | reward_actions_sum=reward_actions_sum,
192 | skill_names=skill_names,
193 | max_path_length=max_path_length,
194 | )
195 |
196 |
197 | def deprecated_rollout(
198 | env,
199 | agent,
200 | max_path_length=np.inf,
201 | render=False,
202 | render_kwargs=None,
203 | ):
204 | """
205 |     The values for the following keys will be 2D arrays, with the
206 |     first dimension corresponding to the time dimension:
207 | - observations
208 | - actions
209 | - rewards
210 | - next_observations
211 | - terminals
212 |
213 |     The next two entries will be lists of dictionaries, with the index into
214 |     the list being the timestep index:
215 | - agent_infos
216 | - env_infos
217 | """
218 | if render_kwargs is None:
219 | render_kwargs = {}
220 | observations = []
221 | actions = []
222 | rewards = []
223 | terminals = []
224 | agent_infos = []
225 | env_infos = []
226 | o = env.reset()
227 | agent.reset()
228 | next_o = None
229 | path_length = 0
230 | if render:
231 | env.render(**render_kwargs)
232 | while path_length < max_path_length:
233 | a, agent_info = agent.get_action(o)
234 | next_o, r, d, env_info = env.step(a)
235 | observations.append(o)
236 | rewards.append(r)
237 | terminals.append(d)
238 | actions.append(a)
239 | agent_infos.append(agent_info)
240 | env_infos.append(env_info)
241 | path_length += 1
242 | if d:
243 | break
244 | o = next_o
245 | if render:
246 | env.render(**render_kwargs)
247 |
248 | actions = np.array(actions)
249 | if len(actions.shape) == 1:
250 | actions = np.expand_dims(actions, 1)
251 | observations = np.array(observations)
252 | if len(observations.shape) == 1:
253 | observations = np.expand_dims(observations, 1)
254 | next_o = np.array([next_o])
255 | next_observations = np.vstack(
256 | (
257 | observations[1:, :],
258 | np.expand_dims(next_o, 0)
259 | )
260 | )
261 | return dict(
262 | observations=observations,
263 | actions=actions,
264 | rewards=np.array(rewards).reshape(-1, 1),
265 | next_observations=next_observations,
266 | terminals=np.array(terminals).reshape(-1, 1),
267 | agent_infos=agent_infos,
268 | env_infos=env_infos,
269 | )
270 |
--------------------------------------------------------------------------------
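
A hedged sketch of calling the main `rollout` above: it assumes a MAPLE robosuite-style wrapper whose `step` accepts `image_obs_in_info` and which exposes `env.env.skill_controller`, plus a policy with `reset()` and `get_action(obs)`; the env/policy construction itself is omitted here.

```python
from maple.samplers.rollout_functions import rollout

path = rollout(env, policy, max_path_length=150, terminals_all_false=True)
print(path['path_length'])          # skill-weighted length (sums num_ac_calls per step)
print(path['path_length_actions'])  # total low-level action calls over the path
print(path['skill_names'])          # skill chosen at each decision step
print(path['rewards'].shape)        # (num_decision_steps, 1)
```
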
/maple/samplers/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def split_paths(paths):
4 | """
5 |     Stack multiple obs/actions/etc. from different paths.
6 |     :param paths: List of paths, where one path is something returned from
7 |     the rollout function in maple.samplers.rollout_functions.
8 | :return: Tuple. Every element will have shape batch_size X DIM, including
9 | the rewards and terminal flags.
10 | """
11 | rewards = [path["rewards"].reshape(-1, 1) for path in paths]
12 | terminals = [path["terminals"].reshape(-1, 1) for path in paths]
13 | actions = [path["actions"] for path in paths]
14 | obs = [path["observations"] for path in paths]
15 | next_obs = [path["next_observations"] for path in paths]
16 | rewards = np.vstack(rewards)
17 | terminals = np.vstack(terminals)
18 | obs = np.vstack(obs)
19 | actions = np.vstack(actions)
20 | next_obs = np.vstack(next_obs)
21 | assert len(rewards.shape) == 2
22 | assert len(terminals.shape) == 2
23 | assert len(obs.shape) == 2
24 | assert len(actions.shape) == 2
25 | assert len(next_obs.shape) == 2
26 | return rewards, terminals, obs, actions, next_obs
27 |
28 |
29 | def split_paths_to_dict(paths):
30 | rewards, terminals, obs, actions, next_obs = split_paths(paths)
31 | return dict(
32 | rewards=rewards,
33 | terminals=terminals,
34 | observations=obs,
35 | actions=actions,
36 | next_observations=next_obs,
37 | )
38 |
--------------------------------------------------------------------------------
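
A small self-contained example of `split_paths`, showing how per-path arrays are stacked into flat batch arrays (the fake paths here are synthetic, with obs_dim=3 and action_dim=2):

```python
import numpy as np
from maple.samplers.util import split_paths

def fake_path(T, obs_dim=3, action_dim=2):
    return dict(
        observations=np.zeros((T, obs_dim)),
        next_observations=np.zeros((T, obs_dim)),
        actions=np.zeros((T, action_dim)),
        rewards=np.zeros((T, 1)),
        terminals=np.zeros((T, 1)),
    )

rewards, terminals, obs, actions, next_obs = split_paths([fake_path(5), fake_path(7)])
assert obs.shape == (12, 3) and actions.shape == (12, 2) and rewards.shape == (12, 1)
```
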
/maple/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/__init__.py
--------------------------------------------------------------------------------
/maple/torch/core.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn as nn
6 |
7 | from maple.torch import pytorch_util as ptu
8 |
9 |
10 | class PyTorchModule(nn.Module, metaclass=abc.ABCMeta):
11 | """
12 |     Keeping this wrapper around to be a bit more future-proof.
13 | """
14 | pass
15 |
16 |
17 | def eval_np(module, *args, **kwargs):
18 | """
19 | Eval this module with a numpy interface
20 |
21 | Same as a call to __call__ except all Variable input/outputs are
22 | replaced with numpy equivalents.
23 |
24 | Assumes the output is either a single object or a tuple of objects.
25 | """
26 | torch_args = tuple(torch_ify(x) for x in args)
27 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
28 | outputs = module(*torch_args, **torch_kwargs)
29 | return elem_or_tuple_to_numpy(outputs)
30 |
31 |
32 | def torch_ify(np_array_or_other):
33 | if isinstance(np_array_or_other, np.ndarray):
34 | return ptu.from_numpy(np_array_or_other)
35 | else:
36 | return np_array_or_other
37 |
38 |
39 | def np_ify(tensor_or_other):
40 | if isinstance(tensor_or_other, torch.autograd.Variable):
41 | return ptu.get_numpy(tensor_or_other)
42 | else:
43 | return tensor_or_other
44 |
45 |
46 | def _elem_or_tuple_to_variable(elem_or_tuple):
47 | if isinstance(elem_or_tuple, tuple):
48 | return tuple(
49 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple
50 | )
51 | return ptu.from_numpy(elem_or_tuple).float()
52 |
53 |
54 | def elem_or_tuple_to_numpy(elem_or_tuple):
55 | if isinstance(elem_or_tuple, tuple):
56 | return tuple(np_ify(x) for x in elem_or_tuple)
57 | else:
58 | return np_ify(elem_or_tuple)
59 |
60 |
61 | def _filter_batch(np_batch):
62 | for k, v in np_batch.items():
63 |         if v.dtype == np.bool_:
64 | yield k, v.astype(int)
65 | else:
66 | yield k, v
67 |
68 |
69 | def np_to_pytorch_batch(np_batch):
70 | if isinstance(np_batch, dict):
71 | return {
72 | k: _elem_or_tuple_to_variable(x)
73 | for k, x in _filter_batch(np_batch)
74 | if x.dtype != np.dtype('O') # ignore object (e.g. dictionaries)
75 | }
76 | else:
77 | _elem_or_tuple_to_variable(np_batch)
78 |
--------------------------------------------------------------------------------
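
A short sketch of the numpy/torch helpers in this file: `eval_np` feeds numpy inputs through a torch module and returns numpy outputs, and `np_to_pytorch_batch` converts a replay-buffer-style dict of numpy arrays (the `nn.Linear` module below is just a stand-in):

```python
import numpy as np
import torch.nn as nn
from maple.torch.core import eval_np, np_to_pytorch_batch

net = nn.Linear(4, 2)
out = eval_np(net, np.zeros((10, 4), dtype=np.float32))   # numpy in, numpy out
assert out.shape == (10, 2)

batch = np_to_pytorch_batch({
    'observations': np.zeros((10, 4), dtype=np.float32),
    'terminals': np.zeros((10, 1), dtype=bool),   # bool arrays are cast to int first
})
```
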
/maple/torch/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset, Sampler
4 |
5 | # TODO: move this to more reasonable place
6 | from maple.data_management.obs_dict_replay_buffer import normalize_image
7 |
8 |
9 | class ImageDataset(Dataset):
10 |
11 | def __init__(self, images, should_normalize=True):
12 | super().__init__()
13 | self.dataset = images
14 | self.dataset_len = len(self.dataset)
15 | assert should_normalize == (images.dtype == np.uint8)
16 | self.should_normalize = should_normalize
17 |
18 | def __len__(self):
19 | return self.dataset_len
20 |
21 | def __getitem__(self, idxs):
22 | samples = self.dataset[idxs, :]
23 | if self.should_normalize:
24 | samples = normalize_image(samples)
25 | return np.float32(samples)
26 |
27 |
28 | class InfiniteRandomSampler(Sampler):
29 |
30 | def __init__(self, data_source):
31 | self.data_source = data_source
32 | self.iter = iter(torch.randperm(len(self.data_source)).tolist())
33 |
34 | def __iter__(self):
35 | return self
36 |
37 | def __next__(self):
38 | try:
39 | idx = next(self.iter)
40 | except StopIteration:
41 | self.iter = iter(torch.randperm(len(self.data_source)).tolist())
42 | idx = next(self.iter)
43 | return idx
44 |
45 | def __len__(self):
46 | return 2 ** 62
47 |
48 |
49 | class InfiniteWeightedRandomSampler(Sampler):
50 |
51 | def __init__(self, data_source, weights):
52 | assert len(data_source) == len(weights)
53 | assert len(weights.shape) == 1
54 | self.data_source = data_source
55 | # Always use CPU
56 | self._weights = torch.from_numpy(weights)
57 | self.iter = self._create_iterator()
58 |
59 | def update_weights(self, weights):
60 | self._weights = weights
61 | self.iter = self._create_iterator()
62 |
63 | def _create_iterator(self):
64 | return iter(
65 | torch.multinomial(
66 | self._weights, len(self._weights), replacement=True
67 | ).tolist()
68 | )
69 |
70 | def __iter__(self):
71 | return self
72 |
73 | def __next__(self):
74 | try:
75 | idx = next(self.iter)
76 | except StopIteration:
77 | self.iter = self._create_iterator()
78 | idx = next(self.iter)
79 | return idx
80 |
81 | def __len__(self):
82 | return 2 ** 62
83 |
--------------------------------------------------------------------------------
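
A sketch of wiring `ImageDataset` and `InfiniteRandomSampler` into a standard DataLoader; the uint8 image array here is synthetic, and the sampler reshuffles indefinitely so the loader never raises StopIteration:

```python
import numpy as np
from torch.utils.data import DataLoader
from maple.torch.data import ImageDataset, InfiniteRandomSampler

images = np.random.randint(0, 255, size=(100, 3 * 48 * 48), dtype=np.uint8)
dataset = ImageDataset(images, should_normalize=True)   # uint8 implies normalization
loader = DataLoader(dataset, batch_size=32, sampler=InfiniteRandomSampler(dataset))
batch = next(iter(loader))   # float32 batch of shape (32, 3*48*48)
```
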
/maple/torch/data_management/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/data_management/__init__.py
--------------------------------------------------------------------------------
/maple/torch/data_management/normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import maple.torch.pytorch_util as ptu
3 | import numpy as np
4 |
5 | from maple.data_management.normalizer import Normalizer, FixedNormalizer
6 |
7 |
8 | class TorchNormalizer(Normalizer):
9 | """
10 | Update with np array, but de/normalize pytorch Tensors.
11 | """
12 | def normalize(self, v, clip_range=None):
13 | if not self.synchronized:
14 | self.synchronize()
15 | if clip_range is None:
16 | clip_range = self.default_clip_range
17 | mean = ptu.from_numpy(self.mean)
18 | std = ptu.from_numpy(self.std)
19 | if v.dim() == 2:
20 |             # Unsqueeze along the batch dimension and use automatic broadcasting
21 | mean = mean.unsqueeze(0)
22 | std = std.unsqueeze(0)
23 | return torch.clamp((v - mean) / std, -clip_range, clip_range)
24 |
25 | def denormalize(self, v):
26 | if not self.synchronized:
27 | self.synchronize()
28 | mean = ptu.from_numpy(self.mean)
29 | std = ptu.from_numpy(self.std)
30 | if v.dim() == 2:
31 | mean = mean.unsqueeze(0)
32 | std = std.unsqueeze(0)
33 | return mean + v * std
34 |
35 |
36 | class TorchFixedNormalizer(FixedNormalizer):
37 | def normalize(self, v, clip_range=None):
38 | if clip_range is None:
39 | clip_range = self.default_clip_range
40 | mean = ptu.from_numpy(self.mean)
41 | std = ptu.from_numpy(self.std)
42 | if v.dim() == 2:
43 |             # Unsqueeze along the batch dimension and use automatic broadcasting
44 | mean = mean.unsqueeze(0)
45 | std = std.unsqueeze(0)
46 | return torch.clamp((v - mean) / std, -clip_range, clip_range)
47 |
48 | def normalize_scale(self, v):
49 | """
50 | Only normalize the scale. Do not subtract the mean.
51 | """
52 | std = ptu.from_numpy(self.std)
53 | if v.dim() == 2:
54 | std = std.unsqueeze(0)
55 | return v / std
56 |
57 | def denormalize(self, v):
58 | mean = ptu.from_numpy(self.mean)
59 | std = ptu.from_numpy(self.std)
60 | if v.dim() == 2:
61 | mean = mean.unsqueeze(0)
62 | std = std.unsqueeze(0)
63 | return mean + v * std
64 |
65 | def denormalize_scale(self, v):
66 | """
67 | Only denormalize the scale. Do not add the mean.
68 | """
69 | std = ptu.from_numpy(self.std)
70 | if v.dim() == 2:
71 | std = std.unsqueeze(0)
72 | return v * std
73 |
--------------------------------------------------------------------------------
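
A hedged usage sketch for `TorchNormalizer`, assuming the base `Normalizer` takes the feature size as its first argument and exposes `update` on numpy batches (as in rlkit); statistics are accumulated from numpy data while `normalize`/`denormalize` operate on torch tensors:

```python
import numpy as np
import maple.torch.pytorch_util as ptu
from maple.torch.data_management.normalizer import TorchNormalizer

norm = TorchNormalizer(size=10)
norm.update(np.random.randn(1000, 10))        # running mean/std from numpy data
x = ptu.from_numpy(np.random.randn(32, 10))   # torch tensor on the ptu device
x_norm = norm.normalize(x)                    # torch in, torch out
x_back = norm.denormalize(x_norm)
```
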
/maple/torch/lvm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/lvm/__init__.py
--------------------------------------------------------------------------------
/maple/torch/lvm/bear_vae.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn as nn
8 |
9 | import maple.torch.pytorch_util as ptu
10 | from maple.policies.base import ExplorationPolicy
11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy
12 | from maple.torch.distributions import (
13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull,
14 | )
15 | from maple.torch.networks import Mlp, CNN
16 | from maple.torch.networks.basic import MultiInputSequential
17 | from maple.torch.networks.stochastic.distribution_generator import (
18 | DistributionGenerator
19 | )
20 | from maple.torch.lvm.latent_variable_model import LatentVariableModel
21 |
22 |
23 | class VAEPolicy(LatentVariableModel):
24 | def __init__(
25 | self,
26 | obs_dim,
27 | action_dim,
28 | latent_dim,
29 | ):
30 | encoder = Encoder(obs_dim, latent_dim, action_dim)
31 | decoder = Decoder(obs_dim, latent_dim, action_dim)
32 | super().__init__(encoder, decoder)
33 |
34 | self.latent_dim = latent_dim
35 |
36 | def forward(self, state, action):
37 | z = F.relu(self.encoder.e1(torch.cat([state, action], 1)))
38 | z = F.relu(self.encoder.e2(z))
39 |
40 | mean = self.encoder.mean(z)
41 | # Clamped for numerical stability
42 | log_std = self.encoder.log_std(z).clamp(-4, 15)
43 | std = torch.exp(log_std)
44 | z = mean + std * ptu.from_numpy(
45 | np.random.normal(0, 1, size=(std.size())))
46 |
47 | u = self.decode(state, z)
48 |
49 | return u, mean, std
50 |
51 | def decode(self, state, z=None):
52 | if z is None:
53 | z = ptu.from_numpy(np.random.normal(0, 1, size=(
54 | state.size(0), self.latent_dim))).clamp(-0.5, 0.5)
55 |
56 | a = F.relu(self.decoder.d1(torch.cat([state, z], 1)))
57 | a = F.relu(self.decoder.d2(a))
58 | return torch.tanh(self.decoder.d3(a))
59 |
60 | def decode_multiple(self, state, z=None, num_decode=10):
61 | if z is None:
62 | z = ptu.from_numpy(np.random.normal(0, 1, size=(
63 | state.size(0), num_decode, self.latent_dim))).clamp(-0.5, 0.5)
64 |
65 | a = F.relu(self.decoder.d1(torch.cat(
66 | [state.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2), z],
67 | 2)))
68 | a = F.relu(self.decoder.d2(a))
69 | return torch.tanh(self.decoder.d3(a)), self.decoder.d3(a)
70 |
71 |
72 | class Encoder(nn.Module):
73 | def __init__(self, obs_dim, latent_dim, action_dim):
74 | super().__init__()
75 | self.latent_dim = latent_dim
76 |
77 | self.e1 = torch.nn.Linear(obs_dim + action_dim, 750)
78 | self.e2 = torch.nn.Linear(750, 750)
79 |
80 | self.mean = torch.nn.Linear(750, self.latent_dim)
81 | self.log_std = torch.nn.Linear(750, self.latent_dim)
82 |
83 | def forward(self, state, action):
84 | z = F.relu(self.e1(torch.cat([state, action], 1)))
85 | z = F.relu(self.e2(z))
86 |
87 | mean = self.mean(z)
88 | # Clamped for numerical stability
89 | log_std = self.log_std(z).clamp(-4, 15)
90 | std = torch.exp(log_std)
91 | return MultivariateDiagonalNormal(mean, std)
92 |
93 |
94 | class Decoder(nn.Module):
95 | def __init__(self, obs_dim, latent_dim, action_dim):
96 | super().__init__()
97 | self.latent_dim = latent_dim
98 |
99 | self.d1 = torch.nn.Linear(obs_dim + self.latent_dim, 750)
100 | self.d2 = torch.nn.Linear(750, 750)
101 | self.d3 = torch.nn.Linear(750, action_dim)
102 |
103 | def forward(self, state, z):
104 | a = F.relu(self.d1(torch.cat([state, z], 1)))
105 | a = F.relu(self.d2(a))
106 | return Delta(torch.tanh(self.d3(a)))
107 |
--------------------------------------------------------------------------------
/maple/torch/lvm/latent_variable_model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn as nn
8 |
9 | import maple.torch.pytorch_util as ptu
10 | from maple.policies.base import ExplorationPolicy
11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy
12 | from maple.torch.distributions import (
13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull,
14 | )
15 | from maple.torch.networks import Mlp, CNN
16 | from maple.torch.networks.basic import MultiInputSequential
17 | from maple.torch.networks.stochastic.distribution_generator import (
18 | DistributionGenerator
19 | )
20 | from maple.torch.sac.policies.base import (
21 | TorchStochasticPolicy,
22 | PolicyFromDistributionGenerator,
23 | MakeDeterministic,
24 | )
25 |
26 |
27 | class LatentVariableModel(nn.Module):
28 | def __init__(
29 | self,
30 | encoder,
31 | decoder,
32 | **kwargs
33 | ):
34 | super().__init__()
35 | self.encoder = encoder
36 | self.decoder = decoder
37 |
--------------------------------------------------------------------------------
/maple/torch/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains some self-contained modules.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class HuberLoss(nn.Module):
9 | def __init__(self, delta=1):
10 | super().__init__()
11 | self.huber_loss_delta1 = nn.SmoothL1Loss()
12 | self.delta = delta
13 |
14 | def forward(self, x, x_hat):
15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta)
16 | return loss * self.delta * self.delta
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | """
21 | Simple 1D LayerNorm.
22 | """
23 |
24 | def __init__(self, features, center=True, scale=False, eps=1e-6):
25 | super().__init__()
26 | self.center = center
27 | self.scale = scale
28 | self.eps = eps
29 | if self.scale:
30 | self.scale_param = nn.Parameter(torch.ones(features))
31 | else:
32 | self.scale_param = None
33 | if self.center:
34 | self.center_param = nn.Parameter(torch.zeros(features))
35 | else:
36 | self.center_param = None
37 |
38 | def forward(self, x):
39 | mean = x.mean(-1, keepdim=True)
40 | std = x.std(-1, keepdim=True)
41 | output = (x - mean) / (std + self.eps)
42 | if self.scale:
43 | output = output * self.scale_param
44 | if self.center:
45 | output = output + self.center_param
46 | return output
47 |
--------------------------------------------------------------------------------
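
A quick numerical check of the `HuberLoss` rescaling above: dividing by `delta` before `SmoothL1Loss` and multiplying by `delta**2` afterwards yields a quadratic penalty for errors smaller than `delta` and a linear one beyond it:

```python
import torch
from maple.torch.modules import HuberLoss

loss_fn = HuberLoss(delta=2.0)
x = torch.tensor([0.5, 5.0])
x_hat = torch.tensor([0.0, 0.0])
# per-element values: 0.5 * e**2 for |e| < delta, delta * (|e| - delta / 2) otherwise
expected = torch.tensor([0.5 * 0.5 ** 2, 2.0 * (5.0 - 1.0)]).mean()
assert torch.isclose(loss_fn(x, x_hat), expected)
```
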
/maple/torch/networks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General networks for pytorch.
3 |
4 | Algorithm-specific networks should go elsewhere.
5 | """
6 | from maple.torch.networks.basic import (
7 | Clamp, ConcatTuple, Detach, Flatten, FlattenEach, Split, Reshape,
8 | )
9 | from maple.torch.networks.cnn import BasicCNN, CNN, MergedCNN, CNNPolicy
10 | from maple.torch.networks.dcnn import DCNN, TwoHeadDCNN
11 | from maple.torch.networks.feat_point_mlp import FeatPointMlp
12 | from maple.torch.networks.image_state import ImageStatePolicy, ImageStateQ
13 | from maple.torch.networks.linear_transform import LinearTransform
14 | from maple.torch.networks.normalization import LayerNorm
15 | from maple.torch.networks.mlp import (
16 | Mlp, ConcatMlp, MlpPolicy, TanhMlpPolicy,
17 | MlpQf,
18 | MlpQfWithObsProcessor,
19 | ConcatMultiHeadedMlp,
20 | )
21 | from maple.torch.networks.pretrained_cnn import PretrainedCNN
22 | from maple.torch.networks.two_headed_mlp import TwoHeadMlp
23 |
24 | __all__ = [
25 | 'Clamp',
26 | 'ConcatMlp',
27 | 'ConcatMultiHeadedMlp',
28 | 'ConcatTuple',
29 | 'BasicCNN',
30 | 'CNN',
31 | 'CNNPolicy',
32 | 'DCNN',
33 | 'Detach',
34 | 'FeatPointMlp',
35 | 'Flatten',
36 | 'FlattenEach',
37 | 'LayerNorm',
38 | 'LinearTransform',
39 | 'ImageStatePolicy',
40 | 'ImageStateQ',
41 | 'MergedCNN',
42 | 'Mlp',
43 | 'PretrainedCNN',
44 | 'Reshape',
45 | 'Split',
46 | 'TwoHeadDCNN',
47 | 'TwoHeadMlp',
48 | ]
49 |
50 |
--------------------------------------------------------------------------------
/maple/torch/networks/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class Clamp(nn.Module):
6 | def __init__(self, **kwargs):
7 | super().__init__()
8 | self.kwargs = kwargs
9 | self.__name__ = "Clamp"
10 |
11 | def forward(self, x):
12 | return torch.clamp(x, **self.kwargs)
13 |
14 |
15 | class Split(nn.Module):
16 | """
17 | Split input and process each chunk with a separate module.
18 | """
19 | def __init__(self, module1, module2, split_idx):
20 | super().__init__()
21 | self.module1 = module1
22 | self.module2 = module2
23 | self.split_idx = split_idx
24 |
25 | def forward(self, x):
26 | in1 = x[:, :self.split_idx]
27 | out1 = self.module1(in1)
28 |
29 | in2 = x[:, self.split_idx:]
30 | out2 = self.module2(in2)
31 |
32 | return out1, out2
33 |
34 |
35 | class FlattenEach(nn.Module):
36 | def forward(self, inputs):
37 | return tuple(x.view(x.size(0), -1) for x in inputs)
38 |
39 |
40 | class FlattenEachParallel(nn.Module):
41 | def forward(self, *inputs):
42 | return tuple(x.view(x.size(0), -1) for x in inputs)
43 |
44 |
45 | class Flatten(nn.Module):
46 | def forward(self, inputs):
47 | return inputs.view(inputs.size(0), -1)
48 |
49 |
50 | class Map(nn.Module):
51 | """Apply a module to each input."""
52 | def __init__(self, module):
53 | super().__init__()
54 | self.module = module
55 |
56 | def forward(self, inputs):
57 | return tuple(self.module(x) for x in inputs)
58 |
59 |
60 | class ApplyMany(nn.Module):
61 | """Apply many modules to one input."""
62 | def __init__(self, *modules):
63 | super().__init__()
64 | self.modules_to_apply = nn.ModuleList(modules)
65 |
66 | def forward(self, inputs):
67 | return tuple(m(inputs) for m in self.modules_to_apply)
68 |
69 |
70 | class LearnedPositiveConstant(nn.Module):
71 | def __init__(self, init_value):
72 | super().__init__()
73 | self._constant = nn.Parameter(init_value)
74 |
75 | def forward(self, _):
76 | return self._constant
77 |
78 |
79 | class Reshape(nn.Module):
80 | def __init__(self, *output_shape):
81 | super().__init__()
82 | self._output_shape_with_batch_size = (-1, *output_shape)
83 |
84 | def forward(self, inputs):
85 | return inputs.view(self._output_shape_with_batch_size)
86 |
87 |
88 | class ConcatTuple(nn.Module):
89 | def __init__(self, dim=1):
90 | super().__init__()
91 | self.dim = dim
92 |
93 | def forward(self, inputs):
94 | return torch.cat(inputs, dim=self.dim)
95 |
96 |
97 | class Concat(nn.Module):
98 | def __init__(self, dim=1):
99 | super().__init__()
100 | self.dim = dim
101 |
102 | def forward(self, *inputs):
103 | return torch.cat(inputs, dim=self.dim)
104 |
105 |
106 | class MultiInputSequential(nn.Sequential):
107 | def forward(self, *input):
108 | for module in self._modules.values():
109 | if isinstance(input, tuple):
110 | input = module(*input)
111 | else:
112 | input = module(input)
113 | return input
114 |
115 |
116 | class Detach(nn.Module):
117 | def __init__(self, wrapped_mlp):
118 | super().__init__()
119 | self.wrapped_mlp = wrapped_mlp
120 |
121 | def forward(self, inputs):
122 | return self.wrapped_mlp.forward(inputs).detach()
123 |
124 | def __getattr__(self, attr_name):
125 | try:
126 | return super().__getattr__(attr_name)
127 | except AttributeError:
128 | return getattr(self.wrapped_mlp, attr_name)
129 |
--------------------------------------------------------------------------------
/maple/torch/networks/custom.py:
--------------------------------------------------------------------------------
1 | """
2 | Random networks
3 | """
4 |
--------------------------------------------------------------------------------
/maple/torch/networks/feat_point_mlp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 | from torch.nn import functional as F
5 |
6 | from maple.pythonplusplus import identity
7 | from maple.torch import pytorch_util as ptu
8 | from maple.torch.core import PyTorchModule
9 |
10 |
11 | class FeatPointMlp(PyTorchModule):
12 | def __init__(
13 | self,
14 | downsample_size,
15 | input_channels,
16 | num_feat_points,
17 | temperature=1.0,
18 | init_w=1e-3,
19 | input_size=32,
20 | hidden_init=ptu.fanin_init,
21 | output_activation=identity,
22 | ):
23 | super().__init__()
24 |
25 | self.downsample_size = downsample_size
26 | self.temperature = temperature
27 | self.num_feat_points = num_feat_points
28 | self.hidden_init = hidden_init
29 | self.output_activation = output_activation
30 | self.input_channels = input_channels
31 | self.input_size = input_size
32 |
33 | # self.bn1 = nn.BatchNorm2d(1)
34 | self.conv1 = nn.Conv2d(input_channels, 48, kernel_size=5, stride=2)
35 | # self.bn1 = nn.BatchNorm2d(16)
36 | self.conv2 = nn.Conv2d(48, 48, kernel_size=5, stride=1)
37 | self.conv3 = nn.Conv2d(48, self.num_feat_points, kernel_size=5, stride=1)
38 |
39 | test_mat = ptu.zeros(1, self.input_channels, self.input_size, self.input_size)
40 | test_mat = self.conv1(test_mat)
41 | test_mat = self.conv2(test_mat)
42 | test_mat = self.conv3(test_mat)
43 | self.out_size = int(np.prod(test_mat.shape))
44 | self.fc1 = nn.Linear(2 * self.num_feat_points, 400)
45 | self.fc2 = nn.Linear(400, 300)
46 | self.last_fc = nn.Linear(300, self.input_channels * self.downsample_size * self.downsample_size)
47 |
48 | self.init_weights(init_w)
49 | self.i = 0
50 |
51 | def init_weights(self, init_w):
52 | self.hidden_init(self.conv1.weight)
53 | self.conv1.bias.data.fill_(0)
54 | self.hidden_init(self.conv2.weight)
55 | self.conv2.bias.data.fill_(0)
56 |
57 | def forward(self, input):
58 | h = self.encoder(input)
59 | out = self.decoder(h)
60 | return out
61 |
62 | def encoder(self, input):
63 | x = input.contiguous().view(-1, self.input_channels, self.input_size, self.input_size)
64 | x = F.relu(self.conv1(x))
65 | x = F.relu(self.conv2(x))
66 | x = self.conv3(x)
67 | d = int((self.out_size // self.num_feat_points) ** (1 / 2))
68 | x = x.view(-1, self.num_feat_points, d * d)
69 | x = F.softmax(x / self.temperature, 2)
70 | x = x.view(-1, self.num_feat_points, d, d)
71 |
72 | maps_x = torch.sum(x, 2)
73 | maps_y = torch.sum(x, 3)
74 |
75 | weights = ptu.from_numpy(np.arange(d) / (d + 1))
76 |
77 | fp_x = torch.sum(maps_x * weights, 2)
78 | fp_y = torch.sum(maps_y * weights, 2)
79 |
80 | x = torch.cat([fp_x, fp_y], 1)
81 | # h = x.view(-1, 2, self.num_feat_points).transpose(1, 2).contiguous().view(-1, self.num_feat_points * 2)
82 | h = x.view(-1, self.num_feat_points * 2)
83 | return h
84 |
85 | def decoder(self, input):
86 | h = input
87 | h = F.relu(self.fc1(h))
88 | h = F.relu(self.fc2(h))
89 | h = self.last_fc(h)
90 | return h
91 |
92 | def history_encoder(self, input, history_length):
93 | input = input.contiguous().view(-1,
94 | self.input_channels,
95 | self.input_size,
96 | self.input_size)
97 | latent = self.encoder(input)
98 |
99 | assert latent.shape[0] % history_length == 0
100 | n_samples = latent.shape[0] // history_length
101 | latent = latent.view(n_samples, -1)
102 | return latent
103 |
104 |
--------------------------------------------------------------------------------
/maple/torch/networks/image_state.py:
--------------------------------------------------------------------------------
1 | from maple.policies.base import Policy
2 | from maple.torch.core import PyTorchModule, eval_np
3 |
4 |
5 | class ImageStatePolicy(PyTorchModule, Policy):
6 | """Switches between image or state inputs"""
7 |
8 | def __init__(
9 | self,
10 | image_conv_net,
11 | state_fc_net,
12 | ):
13 | super().__init__()
14 |
15 | assert image_conv_net is None or state_fc_net is None
16 | self.image_conv_net = image_conv_net
17 | self.state_fc_net = state_fc_net
18 |
19 | def forward(self, input, return_preactivations=False):
20 | if self.image_conv_net is not None:
21 | image = input[:, :21168]
22 | return self.image_conv_net(image)
23 | if self.state_fc_net is not None:
24 | state = input[:, 21168:]
25 | return self.state_fc_net(state)
26 |
27 | def get_action(self, obs_np):
28 | actions = self.get_actions(obs_np[None])
29 | return actions[0, :], {}
30 |
31 | def get_actions(self, obs):
32 | return eval_np(self, obs)
33 |
34 |
35 | class ImageStateQ(PyTorchModule):
36 | """Switches between image or state inputs"""
37 |
38 | def __init__(
39 | self,
40 | # obs_dim,
41 | # action_dim,
42 | # goal_dim,
43 | image_conv_net, # assumed to be a MergedCNN
44 | state_fc_net,
45 | ):
46 | super().__init__()
47 |
48 | assert image_conv_net is None or state_fc_net is None
49 | # self.obs_dim = obs_dim
50 | # self.action_dim = action_dim
51 | # self.goal_dim = goal_dim
52 | self.image_conv_net = image_conv_net
53 | self.state_fc_net = state_fc_net
54 |
55 | def forward(self, input, action, return_preactivations=False):
56 | if self.image_conv_net is not None:
57 | image = input[:, :21168]
58 | return self.image_conv_net(image, action)
59 | if self.state_fc_net is not None:
60 | state = input[:, 21168:] # action + state
61 | return self.state_fc_net(state, action)
62 |
63 |
64 |
--------------------------------------------------------------------------------
/maple/torch/networks/linear_transform.py:
--------------------------------------------------------------------------------
1 | from maple.torch.core import PyTorchModule
2 |
3 |
4 | class LinearTransform(PyTorchModule):
5 | def __init__(self, m, b):
6 | super().__init__()
7 | self.m = m
8 | self.b = b
9 |
10 | def __call__(self, t):
11 | return self.m * t + self.b
12 |
--------------------------------------------------------------------------------
/maple/torch/networks/normalization.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains some self-contained modules. May depend on pytorch_util.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from maple.torch import pytorch_util as ptu
7 |
8 |
9 | class LayerNorm(nn.Module):
10 | """
11 | Simple 1D LayerNorm.
12 | """
13 | def __init__(self, features, center=True, scale=False, eps=1e-6):
14 | super().__init__()
15 | self.center = center
16 | self.scale = scale
17 | self.eps = eps
18 | if self.scale:
19 | self.scale_param = nn.Parameter(torch.ones(features))
20 | else:
21 | self.scale_param = None
22 | if self.center:
23 | self.center_param = nn.Parameter(torch.zeros(features))
24 | else:
25 | self.center_param = None
26 |
27 | def forward(self, x):
28 | mean = x.mean(-1, keepdim=True)
29 | std = x.std(-1, keepdim=True)
30 | output = (x - mean) / (std + self.eps)
31 | if self.scale:
32 | output = output * self.scale_param
33 | if self.center:
34 | output = output + self.center_param
35 | return output
36 |
--------------------------------------------------------------------------------
/maple/torch/networks/pretrained_cnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision.models as models
4 | from torch import nn as nn
5 |
6 | from maple.pythonplusplus import identity
7 | from maple.torch.core import PyTorchModule
8 |
9 |
10 | class PretrainedCNN(PyTorchModule):
11 | # Uses a pretrained CNN architecture from torchvision
12 | def __init__(
13 | self,
14 | input_width,
15 | input_height,
16 | input_channels,
17 | output_size,
18 | hidden_sizes=None,
19 | added_fc_input_size=0,
20 | batch_norm_fc=False,
21 | init_w=1e-4,
22 | hidden_init=nn.init.xavier_uniform_,
23 | hidden_activation=nn.ReLU(),
24 | output_activation=identity,
25 | output_conv_channels=False,
26 | model_architecture=models.resnet18,
27 | model_pretrained=True,
28 | model_freeze=False,
29 | ):
30 | if hidden_sizes is None:
31 | hidden_sizes = []
32 | super().__init__()
33 |
34 | self.hidden_sizes = hidden_sizes
35 | self.input_width = input_width
36 | self.input_height = input_height
37 | self.input_channels = input_channels
38 | self.output_size = output_size
39 | self.output_activation = output_activation
40 | self.hidden_activation = hidden_activation
41 | self.batch_norm_fc = batch_norm_fc
42 | self.added_fc_input_size = added_fc_input_size
43 | self.conv_input_length = self.input_width * self.input_height * self.input_channels
44 | self.output_conv_channels = output_conv_channels
45 |
46 | self.pretrained_model = nn.Sequential(*list(model_architecture(
47 | pretrained=model_pretrained).children())[:-1])
48 | if model_freeze:
49 | for child in self.pretrained_model.children():
50 | for param in child.parameters():
51 | param.requires_grad = False
52 | self.fc_layers = nn.ModuleList()
53 | self.fc_norm_layers = nn.ModuleList()
54 |
55 | # use torch rather than ptu because initially the model is on CPU
56 | test_mat = torch.zeros(
57 | 1,
58 | self.input_channels,
59 | self.input_width,
60 | self.input_height,
61 | )
62 |         # find the output dim of the pretrained conv layers by passing a test matrix through them
63 | test_mat = self.pretrained_model(test_mat)
64 |
65 | self.conv_output_flat_size = int(np.prod(test_mat.shape))
66 | if self.output_conv_channels:
67 | self.last_fc = None
68 | else:
69 | fc_input_size = self.conv_output_flat_size
70 | # used only for injecting input directly into fc layers
71 | fc_input_size += added_fc_input_size
72 | for idx, hidden_size in enumerate(hidden_sizes):
73 | fc_layer = nn.Linear(fc_input_size, hidden_size)
74 | fc_input_size = hidden_size
75 |
76 | fc_layer.weight.data.uniform_(-init_w, init_w)
77 | fc_layer.bias.data.uniform_(-init_w, init_w)
78 |
79 | self.fc_layers.append(fc_layer)
80 |
81 | if self.batch_norm_fc:
82 | norm_layer = nn.BatchNorm1d(hidden_size)
83 | self.fc_norm_layers.append(norm_layer)
84 |
85 | self.last_fc = nn.Linear(fc_input_size, output_size)
86 | self.last_fc.weight.data.uniform_(-init_w, init_w)
87 | self.last_fc.bias.data.uniform_(-init_w, init_w)
88 |
89 | def forward(self, input, return_last_activations=False):
90 | conv_input = input.narrow(start=0,
91 | length=self.conv_input_length,
92 | dim=1).contiguous()
93 | # reshape from batch of flattened images into (channels, w, h)
94 | h = conv_input.view(conv_input.shape[0],
95 | self.input_channels,
96 | self.input_height,
97 | self.input_width)
98 |
99 | h = self.apply_forward_conv(h)
100 |
101 | if self.output_conv_channels:
102 | return h
103 |
104 | # flatten channels for fc layers
105 | h = h.view(h.size(0), -1)
106 | if self.added_fc_input_size != 0:
107 | extra_fc_input = input.narrow(
108 | start=self.conv_input_length,
109 | length=self.added_fc_input_size,
110 | dim=1,
111 | )
112 | h = torch.cat((h, extra_fc_input), dim=1)
113 | h = self.apply_forward_fc(h)
114 |
115 | if return_last_activations:
116 | return h
117 | return self.output_activation(self.last_fc(h))
118 |
119 | def apply_forward_conv(self, h):
120 | return self.pretrained_model(h)
121 |
122 | def apply_forward_fc(self, h):
123 | for i, layer in enumerate(self.fc_layers):
124 | h = layer(h)
125 | if self.batch_norm_fc:
126 | h = self.fc_norm_layers[i](h)
127 | h = self.hidden_activation(h)
128 | return h
129 |
130 |
131 |
--------------------------------------------------------------------------------
/maple/torch/networks/stochastic/distribution_generator.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from torch import nn
4 |
5 | from maple.torch.distributions import (
6 | Bernoulli,
7 |     Beta as BetaDistribution,
8 | Distribution,
9 | Independent,
10 | GaussianMixture as GaussianMixtureDistribution,
11 | GaussianMixtureFull as GaussianMixtureFullDistribution,
12 | MultivariateDiagonalNormal,
13 | TanhNormal,
14 | )
15 | from maple.torch.networks.basic import MultiInputSequential
16 |
17 |
18 | class DistributionGenerator(nn.Module, metaclass=abc.ABCMeta):
19 | def forward(self, *input, **kwarg) -> Distribution:
20 | raise NotImplementedError
21 |
22 |
23 | class ModuleToDistributionGenerator(
24 | MultiInputSequential,
25 | DistributionGenerator,
26 | metaclass=abc.ABCMeta
27 | ):
28 | pass
29 |
30 |
31 | class Beta(ModuleToDistributionGenerator):
32 | def forward(self, *input):
33 | alpha, beta = super().forward(*input)
34 |         return BetaDistribution(alpha, beta)
35 |
36 |
37 | class Gaussian(ModuleToDistributionGenerator):
38 | def __init__(self, module, std=None, reinterpreted_batch_ndims=1):
39 | super().__init__(module)
40 | self.std = std
41 | self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
42 |
43 | def forward(self, *input):
44 | if self.std:
45 | mean = super().forward(*input)
46 | std = self.std
47 | else:
48 | mean, log_std = super().forward(*input)
49 | std = log_std.exp()
50 | return MultivariateDiagonalNormal(
51 | mean, std, reinterpreted_batch_ndims=self.reinterpreted_batch_ndims)
52 |
53 |
54 | class BernoulliGenerator(ModuleToDistributionGenerator):
55 | def forward(self, *input):
56 | probs = super().forward(*input)
57 | return Bernoulli(probs)
58 |
59 |
60 | class IndependentGenerator(ModuleToDistributionGenerator):
61 | def __init__(self, *args, reinterpreted_batch_ndims=1):
62 | super().__init__(*args)
63 | self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
64 |
65 | def forward(self, *input):
66 | distribution = super().forward(*input)
67 | return Independent(
68 | distribution,
69 | reinterpreted_batch_ndims=self.reinterpreted_batch_ndims,
70 | )
71 |
72 |
73 | class GaussianMixture(ModuleToDistributionGenerator):
74 | def forward(self, *input):
75 | mixture_means, mixture_stds, weights = super().forward(*input)
76 | return GaussianMixtureDistribution(mixture_means, mixture_stds, weights)
77 |
78 |
79 | class GaussianMixtureFull(ModuleToDistributionGenerator):
80 | def forward(self, *input):
81 | mixture_means, mixture_stds, weights = super().forward(*input)
82 | return GaussianMixtureFullDistribution(mixture_means, mixture_stds, weights)
83 |
84 |
85 | class TanhGaussian(ModuleToDistributionGenerator):
86 | def forward(self, *input):
87 | mean, log_std = super().forward(*input)
88 | std = log_std.exp()
89 | return TanhNormal(mean, std)
90 |
--------------------------------------------------------------------------------
/maple/torch/networks/two_headed_mlp.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 | from maple.pythonplusplus import identity
5 | from maple.torch import pytorch_util as ptu
6 | from maple.torch.core import PyTorchModule
7 | from maple.torch.networks import LayerNorm
8 |
9 |
10 | class TwoHeadMlp(PyTorchModule):
11 | def __init__(
12 | self,
13 | hidden_sizes,
14 | first_head_size,
15 | second_head_size,
16 | input_size,
17 | init_w=3e-3,
18 | hidden_activation=F.relu,
19 | output_activation=identity,
20 | hidden_init=ptu.fanin_init,
21 | b_init_value=0.,
22 | layer_norm=False,
23 | layer_norm_kwargs=None,
24 | ):
25 | super().__init__()
26 |
27 | if layer_norm_kwargs is None:
28 | layer_norm_kwargs = dict()
29 |
30 | self.input_size = input_size
31 | self.first_head_size = first_head_size
32 | self.second_head_size = second_head_size
33 | self.hidden_activation = hidden_activation
34 | self.output_activation = output_activation
35 | self.layer_norm = layer_norm
36 | self.fcs = []
37 | self.layer_norms = []
38 | in_size = input_size
39 |
40 | for i, next_size in enumerate(hidden_sizes):
41 | fc = nn.Linear(in_size, next_size)
42 | in_size = next_size
43 | hidden_init(fc.weight)
44 | fc.bias.data.fill_(b_init_value)
45 | self.__setattr__("fc{}".format(i), fc)
46 | self.fcs.append(fc)
47 |
48 | if self.layer_norm:
49 | ln = LayerNorm(next_size)
50 | self.__setattr__("layer_norm{}".format(i), ln)
51 | self.layer_norms.append(ln)
52 |
53 | self.first_head = nn.Linear(in_size, self.first_head_size)
54 | self.first_head.weight.data.uniform_(-init_w, init_w)
55 | self.first_head.bias.data.fill_(0)
56 |
57 | self.second_head = nn.Linear(in_size, self.second_head_size)
58 | self.second_head.weight.data.uniform_(-init_w, init_w)
59 | self.second_head.bias.data.fill_(0)
60 |
61 | def forward(self, input, return_preactivations=False):
62 | h = input
63 | for i, fc in enumerate(self.fcs):
64 | h = fc(h)
65 | if self.layer_norm and i < len(self.fcs) - 1:
66 | h = self.layer_norms[i](h)
67 | h = self.hidden_activation(h)
68 | preactivation = self.first_head(h)
69 | first_output = self.output_activation(preactivation)
70 | preactivation = self.second_head(h)
71 | second_output = self.output_activation(preactivation)
72 |
73 | return first_output, second_output
74 |
--------------------------------------------------------------------------------
/maple/torch/sac/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/sac/__init__.py
--------------------------------------------------------------------------------
/maple/torch/sac/policies/__init__.py:
--------------------------------------------------------------------------------
1 | from maple.torch.sac.policies.base import (
2 | TorchStochasticPolicy,
3 | PolicyFromDistributionGenerator,
4 | MakeDeterministic,
5 | )
6 | from maple.torch.sac.policies.gaussian_policy import (
7 | TanhGaussianPolicyAdapter,
8 | TanhGaussianPolicy,
9 | GaussianPolicy,
10 | GaussianCNNPolicy,
11 | GaussianMixturePolicy,
12 | BinnedGMMPolicy,
13 | TanhGaussianObsProcessorPolicy,
14 | TanhCNNGaussianPolicy,
15 | )
16 | from maple.torch.sac.policies.lvm_policy import LVMPolicy
17 | from maple.torch.sac.policies.policy_from_q import PolicyFromQ
18 | from maple.torch.sac.policies.pamdp_policy import PAMDPPolicy
19 |
20 | __all__ = [
21 | 'TorchStochasticPolicy',
22 | 'PolicyFromDistributionGenerator',
23 | 'MakeDeterministic',
24 | 'TanhGaussianPolicyAdapter',
25 | 'TanhGaussianPolicy',
26 | 'GaussianPolicy',
27 | 'GaussianCNNPolicy',
28 | 'GaussianMixturePolicy',
29 | 'BinnedGMMPolicy',
30 | 'TanhGaussianObsProcessorPolicy',
31 | 'TanhCNNGaussianPolicy',
32 | 'LVMPolicy',
33 | 'PolicyFromQ',
34 | 'PAMDPPolicy',
35 | ]
36 |
--------------------------------------------------------------------------------
/maple/torch/sac/policies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn as nn
8 |
9 | import maple.torch.pytorch_util as ptu
10 | from maple.policies.base import ExplorationPolicy
11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy
12 | from maple.torch.distributions import (
13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull,
14 | )
15 | from maple.torch.networks import Mlp, CNN
16 | from maple.torch.networks.basic import MultiInputSequential
17 | from maple.torch.networks.stochastic.distribution_generator import (
18 | DistributionGenerator
19 | )
20 |
21 |
22 | class TorchStochasticPolicy(
23 | DistributionGenerator,
24 | ExplorationPolicy, metaclass=abc.ABCMeta
25 | ):
26 | def get_action(self, obs_np, return_dist=False):
27 | info = {}
28 | if return_dist:
29 | actions, dist = self.get_actions(obs_np[None], return_dist=return_dist)
30 | info['dist'] = dist
31 | else:
32 | actions = self.get_actions(obs_np[None], return_dist=return_dist)
33 | return actions[0, :], info
34 |
35 | def get_actions(self, obs_np, return_dist=False):
36 | dist = self._get_dist_from_np(obs_np)
37 | actions = dist.sample()
38 | if return_dist:
39 | return elem_or_tuple_to_numpy(actions), dist
40 | else:
41 | return elem_or_tuple_to_numpy(actions)
42 |
43 | def _get_dist_from_np(self, *args, **kwargs):
44 | torch_args = tuple(torch_ify(x) for x in args)
45 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
46 | dist = self(*torch_args, **torch_kwargs)
47 | return dist
48 |
49 |
50 | class PolicyFromDistributionGenerator(
51 | MultiInputSequential,
52 | TorchStochasticPolicy,
53 | ):
54 | """
55 | Usage:
56 | ```
57 | distribution_generator = FancyGenerativeModel()
58 | policy = PolicyFromBatchDistributionModule(distribution_generator)
59 | ```
60 | """
61 | pass
62 |
63 |
64 | class MakeDeterministic(TorchStochasticPolicy):
65 | def __init__(
66 | self,
67 | action_distribution_generator: DistributionGenerator,
68 | ):
69 | super().__init__()
70 | self._action_distribution_generator = action_distribution_generator
71 |
72 | def forward(self, *args, **kwargs):
73 | dist = self._action_distribution_generator.forward(*args, **kwargs)
74 | return Delta(dist.mle_estimate())
75 |
--------------------------------------------------------------------------------
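
A sketch of wrapping a stochastic policy with `MakeDeterministic` for evaluation; the `TanhGaussianPolicy` constructor arguments are assumed to follow the usual rlkit-style `(hidden_sizes, obs_dim, action_dim)` signature, and `get_action` comes from `TorchStochasticPolicy` above:

```python
import numpy as np
from maple.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic

policy = TanhGaussianPolicy(hidden_sizes=[256, 256], obs_dim=10, action_dim=4)
eval_policy = MakeDeterministic(policy)

obs = np.zeros(10, dtype=np.float32)
stochastic_action, _ = policy.get_action(obs)           # samples from the TanhNormal
deterministic_action, _ = eval_policy.get_action(obs)   # tanh of the distribution mean
```
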
/maple/torch/sac/policies/lvm_policy.py:
--------------------------------------------------------------------------------
1 | from maple.torch.networks.stochastic.distribution_generator import (
2 | DistributionGenerator
3 | )
4 | from maple.torch.sac.policies.base import (
5 | TorchStochasticPolicy,
6 | PolicyFromDistributionGenerator,
7 | MakeDeterministic,
8 | )
9 |
10 | from maple.torch.lvm.latent_variable_model import LatentVariableModel
11 |
12 |
13 | class LVMPolicy(LatentVariableModel, TorchStochasticPolicy):
14 | """Expects encoder p(z|s) and decoder p(u|s,z)"""
15 |
16 | def forward(self, obs):
17 | z_dist = self.encoder(obs)
18 | z = z_dist.sample()
19 | return self.decoder(obs, z)
20 |
--------------------------------------------------------------------------------
/maple/torch/sac/policies/pamdp_policy.py:
--------------------------------------------------------------------------------
1 | from maple.torch.sac.policies.base import TorchStochasticPolicy
2 | from maple.torch.core import PyTorchModule
3 | from maple.torch.networks import Mlp
4 |
5 | from maple.torch.distributions import (
6 | Softmax, TanhNormal,
7 | ConcatDistribution,
8 | HierarchicalDistribution,
9 | )
10 |
11 | import torch
12 | from torch import nn as nn
13 |
14 | LOGITS_SCALE = 10
15 | LOG_SIG_MAX = 2
16 | LOG_SIG_MIN = -20
17 |
18 | class PAMDPPolicy(PyTorchModule, TorchStochasticPolicy):
19 | def __init__(
20 | self,
21 | hidden_sizes,
22 | obs_dim,
23 | action_dim_s,
24 | action_dim_p,
25 | one_hot_s,
26 | ):
27 | super().__init__()
28 |
29 | task_policy = CategoricalPolicy(
30 | obs_dim=obs_dim,
31 | hidden_sizes=hidden_sizes,
32 | action_dim=action_dim_s,
33 | one_hot=one_hot_s,
34 | prefix='task',
35 | )
36 |
37 | param_policy = ParallelHybridPolicy(
38 | obs_dim=obs_dim,
39 | hidden_sizes=hidden_sizes,
40 | num_networks=action_dim_s,
41 | action_dim_c=action_dim_p,
42 | prefix_c='param',
43 | )
44 |
45 | self.policy = HierarchicalPolicy(
46 | task_policy, param_policy,
47 | policy1_obs_dim=obs_dim
48 | )
49 |
50 | self.obs_dim = obs_dim
51 | self.action_dim_s = action_dim_s
52 | self.action_dim_p = action_dim_p
53 |
54 | self.one_hot_s = one_hot_s
55 |
56 | def forward(self, obs):
57 | return self.policy(obs)
58 |
59 | class ConcatPolicy(PyTorchModule, TorchStochasticPolicy):
60 | def __init__(
61 | self,
62 | policy1,
63 | policy2,
64 | policy1_obs_dim
65 | ):
66 | super().__init__()
67 |
68 | self.policy1 = policy1
69 | self.policy2 = policy2
70 |
71 | self.policy1_obs_dim = policy1_obs_dim
72 |
73 | def forward(self, obs):
74 | return ConcatDistribution(
75 | distr1=self.policy1(obs[:,-self.policy1_obs_dim:]),
76 | distr2=self.policy2(obs),
77 | )
78 |
79 | class HierarchicalPolicy(PyTorchModule, TorchStochasticPolicy):
80 | def __init__(self, policy1, policy2, policy1_obs_dim):
81 | super().__init__()
82 |
83 | self.policy1 = policy1
84 | self.policy2 = policy2
85 |
86 | self.policy1_obs_dim = policy1_obs_dim
87 |
88 | def forward(self, obs):
89 | assert obs.dim() == 2
90 |
91 | def distr2_cond_fn(inputs):
92 | obs_for_p = obs
93 | if inputs.dim() == 3:
94 | tile_dim = inputs.shape[1]
95 | obs_for_p = obs.unsqueeze(1).repeat((1, tile_dim, 1))
96 | if isinstance(self.policy2, ParallelHybridPolicy):
97 | id = torch.argmax(inputs, dim=-1)
98 | return self.policy2(obs_for_p, id)
99 | else:
100 | return self.policy2(torch.cat([obs_for_p, inputs], dim=-1))
101 |
102 | return HierarchicalDistribution(
103 | distr1=self.policy1(obs[:,-self.policy1_obs_dim:]),
104 | distr2_cond_fn=distr2_cond_fn,
105 | )
106 |
107 | class CategoricalPolicy(Mlp, TorchStochasticPolicy):
108 | def __init__(
109 | self,
110 | hidden_sizes,
111 | obs_dim,
112 | action_dim,
113 | prefix='',
114 | init_w=1e-3,
115 | one_hot=False,
116 | **kwargs
117 | ):
118 | super().__init__(
119 | hidden_sizes,
120 | input_size=obs_dim,
121 | output_size=action_dim,
122 | init_w=init_w,
123 | **kwargs
124 | )
125 |
126 | self.prefix = prefix
127 | self.one_hot = one_hot
128 |
129 | def forward(self, obs):
130 | h = obs
131 | for i, fc in enumerate(self.fcs):
132 | h = self.hidden_activation(fc(h))
133 | logits = self.last_fc(h)
134 | logits = torch.clamp(logits, -LOGITS_SCALE, LOGITS_SCALE)
135 | return Softmax(logits, one_hot=self.one_hot, prefix=self.prefix)
136 |
137 |
138 | class ParallelHybridPolicy(PyTorchModule, TorchStochasticPolicy):
139 | """
140 |     Usage:
141 | 
142 |         policy = ParallelHybridPolicy(...)
143 | 
144 | """
145 |
146 | def __init__(
147 | self,
148 | hidden_sizes,
149 | obs_dim,
150 | num_networks,
151 | action_dim_c,
152 | prefix_c='',
153 | init_w=1e-3,
154 | ):
155 | super().__init__()
156 |
157 | self.log_std = None
158 | self.num_networks = num_networks
159 |
160 | mlp_list = []
161 | output_size = 2*action_dim_c
162 | for i in range(num_networks):
163 | mlp = Mlp(
164 | hidden_sizes,
165 | input_size=obs_dim,
166 | output_size=output_size,
167 | init_w=init_w,
168 | )
169 | mlp_list.append(mlp)
170 | self.mlp_list = nn.ModuleList(mlp_list)
171 |
172 | self.action_dim_c = action_dim_c
173 | self.prefix_c = prefix_c
174 |
175 | def forward(self, obs, id):
176 | if torch.numel(id) == 1:
177 | mean, std = self.get_mean_std(obs, id)
178 | else:
179 | input_dims = id.shape
180 | obs = obs.reshape((-1, obs.shape[-1]))
181 | id = id.reshape(-1)
182 |
183 | means = []
184 | stds = []
185 | for i in range(self.num_networks):
186 | mean, std = self.get_mean_std(obs, i)
187 | means.append(mean)
188 | stds.append(std)
189 |
190 | means = torch.stack(means, dim=1)
191 | stds = torch.stack(stds, dim=1)
192 |
193 | mean = means[torch.arange(obs.size(0)), id]
194 | std = stds[torch.arange(obs.size(0)), id]
195 |
196 | ### reshape to original input dims
197 | mean = mean.reshape((*input_dims, -1))
198 | std = std.reshape((*input_dims, -1))
199 |
200 | distr_c = TanhNormal(mean, std, prefix=self.prefix_c)
201 | return distr_c
202 |
203 | def get_mean_std(self, obs, id):
204 | c_dim = self.action_dim_c
205 | nn_output = self.mlp_list[id](obs)
206 | mean = nn_output[..., -2 * c_dim:-c_dim]
207 | log_std = nn_output[..., -c_dim:]
208 |
209 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
210 | std = torch.exp(log_std)
211 |
212 | return mean, std
--------------------------------------------------------------------------------
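
A hedged construction sketch for `PAMDPPolicy`: a categorical "task" head selects one of `action_dim_s` skills and a per-skill network produces `action_dim_p` continuous parameters; how the resulting `HierarchicalDistribution` lays out the sampled action depends on `maple.torch.distributions`, which is not shown here:

```python
import numpy as np
from maple.torch.sac.policies import PAMDPPolicy

policy = PAMDPPolicy(
    hidden_sizes=[256, 256],
    obs_dim=20,
    action_dim_s=5,    # number of discrete skills
    action_dim_p=8,    # continuous parameter dimension per skill
    one_hot_s=True,
)
# get_action comes from TorchStochasticPolicy: the skill choice and its
# parameters are sampled jointly from the hierarchical distribution.
action, _ = policy.get_action(np.zeros(20, dtype=np.float32))
```
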
/maple/torch/sac/policies/policy_from_q.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn as nn
8 |
9 | import maple.torch.pytorch_util as ptu
10 | from maple.policies.base import ExplorationPolicy
11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy
12 | from maple.torch.distributions import (
13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull,
14 | )
15 | from maple.torch.networks import Mlp, CNN
16 | from maple.torch.networks.basic import MultiInputSequential
17 | from maple.torch.networks.stochastic.distribution_generator import (
18 | DistributionGenerator
19 | )
20 | from maple.torch.sac.policies.base import (
21 | TorchStochasticPolicy,
22 | PolicyFromDistributionGenerator,
23 | MakeDeterministic,
24 | )
25 |
26 |
27 | class PolicyFromQ(TorchStochasticPolicy):
28 | def __init__(
29 | self,
30 | qf,
31 | policy,
32 | num_samples=10,
33 | **kwargs
34 | ):
35 | super().__init__()
36 | self.qf = qf
37 | self.policy = policy
38 | self.num_samples = num_samples
39 |
40 | def forward(self, obs):
41 | with torch.no_grad():
42 | state = obs.repeat(self.num_samples, 1)
43 | action = self.policy(state).sample()
44 | q_values = self.qf(state, action)
45 | ind = q_values.max(0)[1]
46 | return Delta(action[ind])
47 |
--------------------------------------------------------------------------------
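`PolicyFromQ` above implements a simple re-ranking rule: tile the observation `num_samples` times, sample candidate actions from the base policy, score each candidate with the Q-function, and return a `Delta` distribution at the best one. A self-contained toy version of the same idea, using stand-in callables rather than the maple classes:

```
import torch

def pick_best_action(obs, sample_actions, qf, num_samples=10):
    # obs: a single observation of shape (1, obs_dim)
    state = obs.repeat(num_samples, 1)         # (num_samples, obs_dim)
    actions = sample_actions(state)            # (num_samples, action_dim)
    q_values = qf(state, actions)              # (num_samples, 1)
    best = q_values.max(0)[1]                  # index of the highest-Q candidate
    return actions[best]

obs = torch.zeros(1, 3)
sample_actions = lambda s: torch.randn(s.shape[0], 2)
qf = lambda s, a: -a.pow(2).sum(dim=1, keepdim=True)    # toy Q: prefers small actions
print(pick_best_action(obs, sample_actions, qf).shape)  # torch.Size([1, 2])
```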
/maple/torch/torch_rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from collections import OrderedDict
3 |
4 | from typing import Iterable
5 | from torch import nn as nn
6 |
7 | from maple.core.batch_rl_algorithm import BatchRLAlgorithm
8 | from maple.core.online_rl_algorithm import OnlineRLAlgorithm
9 | from maple.core.trainer import Trainer
10 | from maple.torch.core import np_to_pytorch_batch
11 |
12 |
13 | class TorchOnlineRLAlgorithm(OnlineRLAlgorithm):
14 | def to(self, device):
15 | for net in self.trainer.networks:
16 | net.to(device)
17 |
18 | def training_mode(self, mode):
19 | for net in self.trainer.networks:
20 | net.train(mode)
21 |
22 |
23 | class TorchBatchRLAlgorithm(BatchRLAlgorithm):
24 | def to(self, device):
25 | for net in self.trainer.networks:
26 | net.to(device)
27 |
28 | def training_mode(self, mode):
29 | for net in self.trainer.networks:
30 | net.train(mode)
31 |
32 |
33 | class TorchTrainer(Trainer, metaclass=abc.ABCMeta):
34 | def __init__(self):
35 | self._num_train_steps = 0
36 |
37 | def train(self, np_batch):
38 | self._num_train_steps += 1
39 | batch = np_to_pytorch_batch(np_batch)
40 | self.train_from_torch(batch)
41 |
42 | def get_diagnostics(self):
43 | return OrderedDict([
44 | ('num train calls', self._num_train_steps),
45 | ])
46 |
47 | @abc.abstractmethod
48 | def train_from_torch(self, batch):
49 | pass
50 |
51 | @property
52 | @abc.abstractmethod
53 | def networks(self) -> Iterable[nn.Module]:
54 | pass
55 |
--------------------------------------------------------------------------------
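`TorchTrainer` handles the numpy-to-torch conversion and step counting; a concrete trainer only implements `train_from_torch` and exposes its modules via `networks`, which the algorithm classes above use to move networks between devices and toggle train/eval mode. A minimal sketch of a subclass, assuming the batch dict carries the usual `observations`/`actions`/`rewards` keys:

```
import torch
from torch import nn, optim
from maple.torch.torch_rl_algorithm import TorchTrainer

class ToyQTrainer(TorchTrainer):
    def __init__(self, obs_dim, action_dim, lr=3e-4):
        super().__init__()
        self.qf = nn.Linear(obs_dim + action_dim, 1)
        self.optimizer = optim.Adam(self.qf.parameters(), lr=lr)

    def train_from_torch(self, batch):
        # `batch` already holds torch tensors (converted in TorchTrainer.train).
        q_pred = self.qf(torch.cat([batch['observations'], batch['actions']], dim=1))
        loss = ((q_pred - batch['rewards']) ** 2).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    @property
    def networks(self):
        return [self.qf]
```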
/maple/util/data_processing.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for writing and loading data.
3 | """
4 | import json
5 | import numpy as np
6 | import os
7 | import os.path as osp
8 | from collections import defaultdict, namedtuple
9 |
10 | from maple.pythonplusplus import nested_dict_to_dot_map_dict
11 |
12 |
13 | Trial = namedtuple("Trial", ["data", "variant", "directory"])
14 |
15 |
16 | def matches_dict(criteria_dict, test_dict):
17 | for k, v in criteria_dict.items():
18 | if k not in test_dict:
19 | return False
20 | else:
21 | if test_dict[k] != v:
22 | return False
23 | return True
24 |
25 |
26 | class Experiment(object):
27 | """
28 | Represents an experiment, which consists of many Trials.
29 | """
30 | def __init__(self, base_dir, criteria=None):
31 | """
32 | :param base_dir: A path. Directory structure should be something like:
33 | ```
34 | base_dir/
35 | foo/
36 | bar/
37 | arbtrarily_deep/
38 | trial_one/
39 | variant.json
40 | progress.csv
41 | trial_two/
42 | variant.json
43 | progress.csv
44 | trial_three/
45 | variant.json
46 | progress.csv
47 | ...
48 | variant.json # <-- base_dir/foo/bar has its own Trial
49 | progress.csv
50 | variant.json # <-- base_dir/foo has its own Trial
51 | progress.csv
52 | variant.json # <-- base_dir has its own Trial
53 | progress.csv
54 | ```
55 |
56 | The important thing is that `variant.json` and `progress.csv` are
57 | in the same sub-directory for each Trial.
58 | :param criteria: A dictionary of allowable values for the given keys.
59 | """
60 | if criteria is None:
61 | criteria = {}
62 | self.trials = get_trials(base_dir, criteria=criteria)
63 | assert len(self.trials) > 0, "Nothing loaded."
64 | self.label = 'AverageReturn'
65 |
66 | def get_trials(self, criteria=None):
67 | """
68 | Return a list of Trials that match a criteria.
69 |         :param criteria: A dictionary from key to value that must be matched
70 | in the trial's variant. e.g.
71 | ```
72 | >>> print(exp.trials)
73 | [
74 | (X, {'a': True, ...})
75 | (Y, {'a': False, ...})
76 | (Z, {'a': True, ...})
77 | ]
78 | >>> print(exp.get_trials({'a': True}))
79 | [
80 | (X, {'a': True, ...})
81 | (Z, {'a': True, ...})
82 | ]
83 | ```
84 | If a trial does not have the key, the trial is filtered out.
85 | :return:
86 | """
87 | if criteria is None:
88 | criteria = {}
89 | return [trial for trial in self.trials
90 | if matches_dict(criteria, trial.variant)]
91 |
92 |
93 | def get_dirs(root):
94 | """
95 | Get a list of all the directories under this directory.
96 | """
97 | yield root
98 | for root, directories, filenames in os.walk(root):
99 | for directory in directories:
100 | yield os.path.join(root, directory)
101 |
102 |
103 | def get_trials(base_dir, verbose=False, criteria=None, excluded_seeds=()):
104 | """
105 | Get a list of (data, variant, directory) tuples, loaded from
106 |     - progress.csv
107 | - variant.json
108 | files under this directory.
109 | :param base_dir: root directory
110 |     :param criteria: dictionary of keys and values. Only load experiments
111 |         that match these criteria.
112 |     :return: List of Trial namedtuples. Each one has:
113 |         1. Progress data (np.ndarray)
114 |         2. Variant dictionary, and 3. the directory it was loaded from
115 | """
116 | if criteria is None:
117 | criteria = {}
118 |
119 | trials = []
120 |     # default delimiter for progress.csv; switched to '\t' below when falling back to log.txt
121 |     delimiter = ','
122 | for dir_name in get_dirs(base_dir):
123 | variant_file_name = osp.join(dir_name, 'variant.json')
124 | if not os.path.exists(variant_file_name):
125 | continue
126 |
127 | with open(variant_file_name) as variant_file:
128 | variant = json.load(variant_file)
129 | variant = nested_dict_to_dot_map_dict(variant)
130 |
131 | if 'seed' in variant and int(variant['seed']) in excluded_seeds:
132 | continue
133 |
134 | if not matches_dict(criteria, variant):
135 | continue
136 |
137 | data_file_name = osp.join(dir_name, 'progress.csv')
138 | # Hack for iclr 2018 deadline
139 | if not os.path.exists(data_file_name) or os.stat(
140 | data_file_name).st_size == 0:
141 | data_file_name = osp.join(dir_name, 'log.txt')
142 | if not os.path.exists(data_file_name):
143 | continue
144 | delimiter = '\t'
145 | if verbose:
146 | print("Reading {}".format(data_file_name))
147 | num_lines = sum(1 for _ in open(data_file_name))
148 | if num_lines < 2:
149 | continue
150 | # print(delimiter)
151 | data = np.genfromtxt(
152 | data_file_name,
153 | delimiter=delimiter,
154 | dtype=None,
155 | names=True,
156 | )
157 | trials.append(Trial(data, variant, dir_name))
158 | return trials
159 |
160 |
161 | def get_all_csv(base_dir, verbose=False):
162 | """
163 | Get a list of all csv data under a directory.
164 | :param base_dir: root directory
165 | """
166 | data = []
167 | delimiter = ','
168 | for dir_name in get_dirs(base_dir):
169 | for data_file_name in os.listdir(dir_name):
170 | if data_file_name.endswith(".csv"):
171 | full_path = os.path.join(dir_name, data_file_name)
172 | if verbose:
173 | print("Reading {}".format(full_path))
174 | data.append(np.genfromtxt(
175 | full_path, delimiter=delimiter, dtype=None, names=True
176 | ))
177 | return data
178 |
179 |
180 | def get_unique_param_to_values(all_variants):
181 | variant_key_to_values = defaultdict(set)
182 | for variant in all_variants:
183 | for k, v in variant.items():
184 | if type(v) == list:
185 | v = str(v)
186 | variant_key_to_values[k].add(v)
187 | unique_key_to_values = {
188 | k: variant_key_to_values[k]
189 | for k in variant_key_to_values
190 | if len(variant_key_to_values[k]) > 1
191 | }
192 | return unique_key_to_values
--------------------------------------------------------------------------------
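A hedged usage sketch of the loading utilities above (the base directory and criteria key are placeholders). `trial.data` is the structured array produced by `np.genfromtxt(..., names=True)`, so columns of `progress.csv` are accessed by name, assuming the column exists:

```
from maple.util.data_processing import Experiment

exp = Experiment('data/my_experiment')        # loads every trial found under this directory
subset = exp.get_trials({'seed': 0})          # keep trials whose variant has seed == 0
for trial in subset:
    returns = trial.data['AverageReturn']     # assumes progress.csv has this column
    print(trial.directory, returns[-1])
```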
/maple/util/hyperparameter.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom hyperparameter functions.
3 | """
4 | import abc
5 | import copy
6 | import math
7 | import random
8 | import itertools
9 | from typing import List
10 |
11 | import maple.pythonplusplus as ppp
12 |
13 |
14 | class Hyperparameter(metaclass=abc.ABCMeta):
15 | def __init__(self, name):
16 | self._name = name
17 |
18 | @property
19 | def name(self):
20 | return self._name
21 |
22 |
23 | class RandomHyperparameter(Hyperparameter):
24 | def __init__(self, name):
25 | super().__init__(name)
26 | self._last_value = None
27 |
28 | @abc.abstractmethod
29 | def generate_next_value(self):
30 | """Return a value for the hyperparameter"""
31 | return
32 |
33 | def generate(self):
34 | self._last_value = self.generate_next_value()
35 | return self._last_value
36 |
37 |
38 | class EnumParam(RandomHyperparameter):
39 | def __init__(self, name, possible_values):
40 | super().__init__(name)
41 | self.possible_values = possible_values
42 |
43 | def generate_next_value(self):
44 | return random.choice(self.possible_values)
45 |
46 |
47 | class LogFloatParam(RandomHyperparameter):
48 | """
49 |     Return a value in [min_value + offset, max_value + offset],
50 |     sampled log-uniformly (uniform in log space before the offset).
51 | """
52 | def __init__(self, name, min_value, max_value, *, offset=0):
53 | super(LogFloatParam, self).__init__(name)
54 | self._linear_float_param = LinearFloatParam("log_" + name,
55 | math.log(min_value),
56 | math.log(max_value))
57 | self.offset = offset
58 |
59 | def generate_next_value(self):
60 | return math.e ** (self._linear_float_param.generate()) + self.offset
61 |
62 |
63 | class LinearFloatParam(RandomHyperparameter):
64 | def __init__(self, name, min_value, max_value):
65 | super(LinearFloatParam, self).__init__(name)
66 | self._min = min_value
67 | self._delta = max_value - min_value
68 |
69 | def generate_next_value(self):
70 | return random.random() * self._delta + self._min
71 |
72 |
73 | class LogIntParam(RandomHyperparameter):
74 | def __init__(self, name, min_value, max_value, *, offset=0):
75 | super().__init__(name)
76 | self._linear_float_param = LinearFloatParam("log_" + name,
77 | math.log(min_value),
78 | math.log(max_value))
79 | self.offset = offset
80 |
81 | def generate_next_value(self):
82 | return int(
83 | math.e ** (self._linear_float_param.generate()) + self.offset
84 | )
85 |
86 |
87 | class LinearIntParam(RandomHyperparameter):
88 | def __init__(self, name, min_value, max_value):
89 | super(LinearIntParam, self).__init__(name)
90 | self._min = min_value
91 | self._max = max_value
92 |
93 | def generate_next_value(self):
94 | return random.randint(self._min, self._max)
95 |
96 |
97 | class FixedParam(RandomHyperparameter):
98 | def __init__(self, name, value):
99 | super().__init__(name)
100 | self._value = value
101 |
102 | def generate_next_value(self):
103 | return self._value
104 |
105 |
106 | class Sweeper(object):
107 | pass
108 |
109 |
110 | class RandomHyperparameterSweeper(Sweeper):
111 | def __init__(self, hyperparameters=None, default_kwargs=None):
112 | if default_kwargs is None:
113 | default_kwargs = {}
114 | self._hyperparameters = hyperparameters or []
115 | self._validate_hyperparameters()
116 | self._default_kwargs = default_kwargs
117 |
118 | def _validate_hyperparameters(self):
119 | names = set()
120 | for hp in self._hyperparameters:
121 | name = hp.name
122 | if name in names:
123 | raise Exception("Hyperparameter '{0}' already added.".format(
124 | name))
125 | names.add(name)
126 |
127 | def set_default_parameters(self, default_kwargs):
128 | self._default_kwargs = default_kwargs
129 |
130 | def generate_random_hyperparameters(self):
131 | hyperparameters = {}
132 | for hp in self._hyperparameters:
133 | hyperparameters[hp.name] = hp.generate()
134 | hyperparameters = ppp.dot_map_dict_to_nested_dict(hyperparameters)
135 | return ppp.merge_recursive_dicts(
136 | hyperparameters,
137 | copy.deepcopy(self._default_kwargs),
138 | ignore_duplicate_keys_in_second_dict=True,
139 | )
140 |
141 | def sweep_hyperparameters(self, function, num_configs):
142 | returned_value_and_params = []
143 | for _ in range(num_configs):
144 | kwargs = self.generate_random_hyperparameters()
145 | score = function(**kwargs)
146 | returned_value_and_params.append((score, kwargs))
147 |
148 | return returned_value_and_params
149 |
150 |
151 | class DeterministicHyperparameterSweeper(Sweeper):
152 | """
153 | Do a grid search over hyperparameters based on a predefined set of
154 | hyperparameters.
155 | """
156 | def __init__(self, hyperparameters, default_parameters=None):
157 | """
158 |
159 | :param hyperparameters: A dictionary of the form
160 | ```
161 | {
162 | 'hp_1': [value1, value2, value3],
163 | 'hp_2': [value1, value2, value3],
164 | ...
165 | }
166 | ```
167 | This format is like the param_grid in SciKit-Learn:
168 | http://scikit-learn.org/stable/modules/grid_search.html#exhaustive-grid-search
169 | :param default_parameters: Default key-value pairs to add to the
170 | dictionary.
171 | """
172 | self._hyperparameters = hyperparameters
173 | self._default_kwargs = default_parameters or {}
174 | named_hyperparameters = []
175 | for name, values in self._hyperparameters.items():
176 | named_hyperparameters.append(
177 | [(name, v) for v in values]
178 | )
179 | self._hyperparameters_dicts = [
180 | ppp.dot_map_dict_to_nested_dict(dict(tuple_list))
181 | for tuple_list in itertools.product(*named_hyperparameters)
182 | ]
183 |
184 | def iterate_hyperparameters(self):
185 | """
186 | Iterate over the hyperparameters in a grid-manner.
187 |
188 | :return: List of dictionaries. Each dictionary is a map from name to
189 |             hyperparameter.
190 | """
191 | return [
192 | ppp.merge_recursive_dicts(
193 | hyperparameters,
194 | copy.deepcopy(self._default_kwargs),
195 | ignore_duplicate_keys_in_second_dict=True,
196 | )
197 | for hyperparameters in self._hyperparameters_dicts
198 | ]
199 |
200 |
201 | # TODO(vpong): Test this
202 | class DeterministicSweeperCombiner(object):
203 | """
204 |     A simple wrapper to combine multiple DeterministicHyperparameterSweeper objects.
205 | """
206 | def __init__(self, sweepers: List[DeterministicHyperparameterSweeper]):
207 | self._sweepers = sweepers
208 |
209 | def iterate_list_of_hyperparameters(self):
210 | """
211 | Usage:
212 |
213 | ```
214 | sweeper1 = DeterministicHyperparameterSweeper(...)
215 | sweeper2 = DeterministicHyperparameterSweeper(...)
216 | combiner = DeterministicSweeperCombiner([sweeper1, sweeper2])
217 |
218 | for params_1, params_2 in combiner.iterate_list_of_hyperparameters():
219 | # param_1 = {...}
220 | # param_2 = {...}
221 | ```
222 | :return: Generator of hyperparameters, in the same order as provided
223 | sweepers.
224 | """
225 |         return itertools.product(*[
226 |             sweeper.iterate_hyperparameters()
227 |             for sweeper in self._sweepers
228 |         ])
--------------------------------------------------------------------------------
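A sketch of a grid sweep with `DeterministicHyperparameterSweeper` (keys and values are illustrative). Dotted keys are expanded into nested dicts and merged with the defaults, mirroring how `scripts/eval.py` later in this listing builds its variants:

```
from maple.util.hyperparameter import DeterministicHyperparameterSweeper

sweeper = DeterministicHyperparameterSweeper(
    {
        'trainer_kwargs.discount': [0.95, 0.99],
        'seed': [0, 1, 2],
    },
    default_parameters={'algorithm_kwargs': {'num_epochs': 100}},
)
variants = list(sweeper.iterate_hyperparameters())
print(len(variants))   # 6: one nested variant dict per point of the 2 x 3 grid
```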
/maple/util/io.py:
--------------------------------------------------------------------------------
1 | import io
2 | import joblib
3 | import numpy as np
4 | import os
5 | import pickle
6 | import boto3
7 | import torch  # used by CPU_Unpickler below
8 | from maple.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH
9 |
10 | PICKLE = 'pickle'
11 | NUMPY = 'numpy'
12 | JOBLIB = 'joblib'
13 |
14 |
15 | def local_path_from_s3_or_local_path(filename):
16 | relative_filename = os.path.join(LOCAL_LOG_DIR, filename)
17 | if os.path.isfile(filename):
18 | return filename
19 | elif os.path.isfile(relative_filename):
20 | return relative_filename
21 | else:
22 | return sync_down(filename)
23 |
24 |
25 | def sync_down(path, check_exists=True):
26 | is_docker = os.path.isfile("/.dockerenv")
27 | if is_docker:
28 | local_path = "/tmp/%s" % (path)
29 | else:
30 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path)
31 |
32 | if check_exists and os.path.isfile(local_path):
33 | return local_path
34 |
35 | local_dir = os.path.dirname(local_path)
36 | os.makedirs(local_dir, exist_ok=True)
37 |
38 | if is_docker:
39 | from doodad.ec2.autoconfig import AUTOCONFIG
40 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key()
41 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret()
42 |
43 | full_s3_path = os.path.join(AWS_S3_PATH, path)
44 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path)
45 | try:
46 | bucket = boto3.resource('s3').Bucket(bucket_name)
47 | bucket.download_file(bucket_relative_path, local_path)
48 | except Exception as e:
49 | local_path = None
50 | print("Failed to sync! path: ", path)
51 | print("Exception: ", e)
52 | return local_path
53 |
54 |
55 | def sync_down_folder(path):
56 | is_docker = os.path.isfile("/.dockerenv")
57 | if is_docker:
58 | local_path = "/tmp/%s" % (path)
59 | else:
60 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path)
61 |
62 | local_dir = os.path.dirname(local_path)
63 | os.makedirs(local_dir, exist_ok=True)
64 |
65 | if is_docker:
66 | from doodad.ec2.autoconfig import AUTOCONFIG
67 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key()
68 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret()
69 |
70 | full_s3_path = os.path.join(AWS_S3_PATH, path)
71 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path)
72 | command = "aws s3 sync s3://%s/%s %s" % (bucket_name, bucket_relative_path, local_path)
73 | print(command)
74 | stream = os.popen(command)
75 | output = stream.read()
76 | print(output)
77 | return local_path
78 |
79 |
80 | def split_s3_full_path(s3_path):
81 | """
82 | Split "s3://foo/bar/baz" into "foo" and "bar/baz"
83 | """
84 | bucket_name_and_directories = s3_path.split('//')[1]
85 | bucket_name, *directories = bucket_name_and_directories.split('/')
86 | directory_path = '/'.join(directories)
87 | return bucket_name, directory_path
88 |
89 |
90 | class CPU_Unpickler(pickle.Unpickler):
91 |     """Utility for loading, on a CPU-only machine, a pickled model saved from a GPU."""
92 | def find_class(self, module, name):
93 | if module == 'torch.storage' and name == '_load_from_bytes':
94 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
95 | else: return super().find_class(module, name)
96 |
97 |
98 | def load_local_or_remote_file(filepath, file_type=None):
99 | local_path = local_path_from_s3_or_local_path(filepath)
100 |     if file_type is None:
101 |         # Infer the file type from the extension when none is given;
102 |         # an explicitly passed file_type (e.g. JOBLIB) is kept as-is.
103 |         extension = local_path.split('.')[-1]
104 |         if extension == 'npy':
105 |             file_type = NUMPY
106 |         else:
107 |             file_type = PICKLE
108 | if file_type == NUMPY:
109 | object = np.load(open(local_path, "rb"), allow_pickle=True)
110 | elif file_type == JOBLIB:
111 | object = joblib.load(local_path)
112 | else:
113 | #f = open(local_path, 'rb')
114 | #object = CPU_Unpickler(f).load()
115 | object = pickle.load(open(local_path, "rb"))
116 | print("loaded", local_path)
117 | return object
118 |
119 |
120 | def get_absolute_path(path):
121 | if path[0] == "/":
122 | return path
123 | else:
124 | is_docker = os.path.isfile("/.dockerenv")
125 | if is_docker:
126 | local_path = "/tmp/%s" % (path)
127 | else:
128 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path)
129 | return local_path
130 |
131 |
132 | if __name__ == "__main__":
133 | p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl")
134 | print("got", p)
135 |
--------------------------------------------------------------------------------
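A hedged sketch of loading an artifact with the helpers above (the paths are placeholders). Resolution tries the path as given, then relative to `LOCAL_LOG_DIR`, and finally an S3 download, which requires `AWS_S3_PATH` and credentials to be configured:

```
from maple.util.io import load_local_or_remote_file, JOBLIB

# File type is inferred from the extension when not specified (npy -> numpy, else pickle).
params = load_local_or_remote_file('my_exp/run0/id0/params.pkl')

# An explicitly passed file_type is honored, e.g. for joblib archives.
buffer = load_local_or_remote_file('my_exp/run0/id0/buffer.joblib', file_type=JOBLIB)
```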
/maple/util/ml_util.py:
--------------------------------------------------------------------------------
1 | """
2 | General utility functions for machine learning.
3 | """
4 | import abc
5 | import math
6 | import numpy as np
7 |
8 |
9 | class ScalarSchedule(object, metaclass=abc.ABCMeta):
10 | @abc.abstractmethod
11 | def get_value(self, t):
12 | pass
13 |
14 |
15 | class ConstantSchedule(ScalarSchedule):
16 | def __init__(self, value):
17 | self._value = value
18 |
19 | def get_value(self, t):
20 | return self._value
21 |
22 |
23 | class LinearSchedule(ScalarSchedule):
24 | """
25 | Linearly interpolate and then stop at a final value.
26 | """
27 | def __init__(
28 | self,
29 | init_value,
30 | final_value,
31 | ramp_duration,
32 | ):
33 | self._init_value = init_value
34 | self._final_value = final_value
35 | self._ramp_duration = ramp_duration
36 |
37 | def get_value(self, t):
38 | return (
39 | self._init_value
40 | + (self._final_value - self._init_value)
41 | * min(1.0, t * 1.0 / self._ramp_duration)
42 | )
43 |
44 |
45 | class IntLinearSchedule(LinearSchedule):
46 | """
47 |     Same as LinearSchedule, but rounds the output to an int.
48 | """
49 | def get_value(self, t):
50 | return int(super().get_value(t))
51 |
52 |
53 | class PiecewiseLinearSchedule(ScalarSchedule):
54 | """
55 |     Given lists of times (x_values) and corresponding values (y_values),
56 |     return the value at time t, linearly interpolating between the points.
57 | """
58 | def __init__(
59 | self,
60 | x_values,
61 | y_values,
62 | ):
63 | self._x_values = x_values
64 | self._y_values = y_values
65 |
66 | def get_value(self, t):
67 | return np.interp(t, self._x_values, self._y_values)
68 |
69 |
70 | class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule):
71 | def get_value(self, t):
72 | return int(super().get_value(t))
73 |
74 |
75 | def none_to_infty(bounds):
76 | if bounds is None:
77 | bounds = -math.inf, math.inf
78 | lb, ub = bounds
79 | if lb is None:
80 | lb = -math.inf
81 | if ub is None:
82 | ub = math.inf
83 | return lb, ub
84 |
--------------------------------------------------------------------------------
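The schedules above expose a single `get_value(t)` method; a quick illustration:

```
from maple.util.ml_util import LinearSchedule, PiecewiseLinearSchedule

ramp = LinearSchedule(init_value=0.0, final_value=1.0, ramp_duration=100)
print(ramp.get_value(0), ramp.get_value(50), ramp.get_value(1000))   # 0.0 0.5 1.0

piecewise = PiecewiseLinearSchedule(x_values=[0, 100, 200], y_values=[1e-3, 1e-3, 1e-4])
print(piecewise.get_value(150))   # 0.00055, halfway between 1e-3 and 1e-4
```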
/maple/util/slurm_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pathlib
4 |
5 | def create_sbatch_script(args, use_variants=True):
6 | # Create a new directory path if it doesn't exist and create a new filename that we will write to
7 | exp_dir = args.exp_dir
8 | sbatch_dir = os.path.join(exp_dir, "sbatch")
9 | new_sbatch_fpath = os.path.join(sbatch_dir, "{}.sbatch".format(args.job_name))
10 | if not os.path.isdir(sbatch_dir):
11 | os.mkdir(sbatch_dir)
12 |
13 | if use_variants:
14 | base_variant = os.path.join(exp_dir, "variants", args.env, "base.json")
15 | variant_update = os.path.join(exp_dir, "variants", args.env, "{}.json".format(args.config))
16 | assert os.path.exists(base_variant) and os.path.exists(variant_update)
17 |
18 | command = ""
19 | for i in range(args.num_seeds):
20 | # Compose main command to be executed in script
21 | python_script = args.python_script
22 | line_command = "{{\nsleep {}\n".format(30*i)
23 | line_command += "python {python_script} --env {env} --label {label}".format(
24 | python_script=python_script,
25 | env=args.env,
26 | label=args.label,
27 | )
28 | if use_variants:
29 | line_command = "{line_command} --base_variant {base_variant} --variant_update {variant_update}".format(
30 | line_command=line_command,
31 | base_variant=base_variant,
32 | variant_update=variant_update,
33 | )
34 | if args.no_gpu:
35 | line_command += " --no_gpu"
36 |
37 | command += "{line_command}\n}} & \n".format(line_command=line_command)
38 | command += "wait"
39 |
40 | if args.partition in ["titans", "dgx"]:
41 | log_dir = "/scratch/cluster/soroush/logs"
42 | elif args.partition in ["svl", "tibet", "napoli-gpu"]:
43 | log_dir = "/cvgl2/u/soroush/logs"
44 | else:
45 | raise ValueError
46 |
47 | if args.exclude is None:
48 | if args.partition == "titans":
49 | args.exclude = "titan-5,titan-12"
50 | else:
51 | args.exclude = ""
52 |
53 | # Define a dict to map expected fill-ins with replacement values
54 | fill_ins = {
55 | "{{PARTITION}}": args.partition,
56 | "{{EXCLUDE}}": args.exclude,
57 | "{{NUM_GPU}}": 0 if args.no_gpu else 1,
58 | "{{NUM_CPU}}": args.num_seeds,
59 | "{{MEM}}": args.mem * args.num_seeds,
60 | "{{JOB_NAME}}": args.job_name,
61 | "{{LOG_DIR}}": log_dir,
62 | "{{HOURS}}": args.max_hours,
63 | "{{CMD}}": command,
64 | "{{CONDA_ENV}}": args.conda_env,
65 | }
66 |
67 | # Open the template file
68 | with open(args.slurm_template) as template:
69 | # Open the new sbatch file
70 | print(new_sbatch_fpath)
71 | with open(new_sbatch_fpath, 'w+') as new_file:
72 | # Loop through template and write to this new file
73 | for line in template:
74 | wrote = False
75 | # Check for various cases
76 | for k, v in fill_ins.items():
77 | # If the key is found in the line, replace it with its value and pop it from the dict
78 | if k in line:
79 | new_file.write(line.replace(k, str(v)))
80 | wrote = True
81 | break
82 | # Otherwise, we just write the line from the template directly
83 | if not wrote:
84 | new_file.write(line)
85 |
86 | # Execute this file!
87 | # TODO: Fix! (Permission denied error)
88 | #os.system(new_sbatch_fpath)
89 |
90 | if __name__ == "__main__":
91 | # noinspection PyTypeChecker
92 | parser = argparse.ArgumentParser()
93 |
94 | parser.add_argument('--env', type=str, default='can')
95 | parser.add_argument('--config', type=str, default='base')
96 | parser.add_argument('--label', type=str, default=None)
97 | parser.add_argument('--job_name', type=str, default=None)
98 |
99 | parser.add_argument('--num_seeds', type=int, default=4)
100 | parser.add_argument('--no_video', action='store_true')
101 | parser.add_argument('--no_gpu', action='store_true')
102 |
103 | parser.add_argument('--mem', type=int, default=9)
104 | parser.add_argument('--max_hours', type=int, default=504) #168
105 | parser.add_argument('--partition', type=str, default="titans")
106 | parser.add_argument('--exclude', type=str, default=None)
107 |
108 | args = parser.parse_args()
109 |
110 | if args.label is None:
111 | args.label = args.config
112 |
113 | if args.job_name is None:
114 | args.job_name = "rl_{}_{}".format(args.env, args.label)
115 |
116 | if args.exclude is None:
117 | if args.partition == "titans":
118 | args.exclude = "titan-5,titan-12"
119 | else:
120 | args.exclude = ""
121 |
122 | create_sbatch_script(args)
123 |
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | from maple.launchers.launcher_util import run_experiment
2 | from maple.launchers.robosuite_launcher import experiment
3 | import maple.util.hyperparameter as hyp
4 | import os.path as osp
5 | import argparse
6 | import json
7 | import collections.abc
8 | import copy
9 |
10 | from maple.launchers.conf import LOCAL_LOG_DIR
11 |
12 | base_variant = dict(
13 | algorithm_kwargs=dict(
14 | eval_only=True,
15 | num_epochs=5000,
16 | eval_epoch_freq=100,
17 | ),
18 | replay_buffer_size=int(1E2),
19 | vis_expl=False,
20 | dump_video_kwargs=dict(
21 | rows=1,
22 | columns=6,
23 | pad_length=5,
24 | pad_color=0,
25 | ),
26 | num_eval_rollouts=50,
27 |
28 |     # ckpt_epoch=100, #### uncomment if you want to evaluate a specific epoch checkpoint only ###
29 | )
30 |
31 | env_params = dict(
32 | lift={
33 | 'ckpt_path': [
34 | ### Add paths here ###
35 | ],
36 | },
37 | door={
38 | 'ckpt_path': [
39 | ### Add paths here ###
40 | ],
41 | },
42 | pnp={
43 | 'ckpt_path': [
44 | ### Add paths here ###
45 | ],
46 | },
47 | wipe={
48 | 'ckpt_path': [
49 | ### Add paths here ###
50 | ],
51 | },
52 | stack={
53 | 'ckpt_path': [
54 | ### Add paths here ###
55 | ],
56 | },
57 | nut_round={
58 | 'ckpt_path': [
59 | ### Add paths here ###
60 | ],
61 | },
62 | cleanup={
63 | 'ckpt_path': [
64 | ### Add paths here ###
65 | ],
66 | },
67 | peg_ins={
68 | 'ckpt_path': [
69 | ### Add paths here ###
70 | ],
71 | },
72 | )
73 |
74 | def process_variant(eval_variant):
75 | ckpt_path = eval_variant['ckpt_path']
76 | json_path = osp.join(LOCAL_LOG_DIR, ckpt_path, 'variant.json')
77 | with open(json_path) as f:
78 | ckpt_variant = json.load(f)
79 | deep_update(ckpt_variant, eval_variant)
80 | variant = copy.deepcopy(ckpt_variant)
81 |
82 | if args.debug:
83 | mpl = variant['algorithm_kwargs']['max_path_length']
84 | variant['algorithm_kwargs']['num_eval_steps_per_epoch'] = mpl * 3
85 | variant['dump_video_kwargs']['rows'] = 1
86 | variant['dump_video_kwargs']['columns'] = 2
87 | else:
88 | mpl = variant['algorithm_kwargs']['max_path_length']
89 | variant['algorithm_kwargs']['num_eval_steps_per_epoch'] = mpl * variant['num_eval_rollouts']
90 |
91 | variant['save_video_period'] = variant['algorithm_kwargs']['eval_epoch_freq']
92 |
93 | if args.no_video:
94 | variant['save_video'] = False
95 |
96 | variant['exp_label'] = args.label
97 | return variant
98 |
99 | def deep_update(source, overrides):
100 | '''
101 | Update a nested dictionary or similar mapping.
102 | Modify ``source`` in place.
103 | Copied from: https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth
104 | '''
105 | for key, value in overrides.items():
106 |         if isinstance(value, collections.abc.Mapping) and value:
107 | returned = deep_update(source.get(key, {}), value)
108 | source[key] = returned
109 | else:
110 | source[key] = overrides[key]
111 | return source
112 |
113 | if __name__ == "__main__":
114 | # noinspection PyTypeChecker
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument('--env', type=str)
117 | parser.add_argument('--label', type=str, default='test')
118 | parser.add_argument('--num_seeds', type=int, default=1)
119 | parser.add_argument('--no_video', action='store_true')
120 | parser.add_argument('--no_gpu', action='store_true')
121 | parser.add_argument('--gpu_id', type=int, default=0)
122 | parser.add_argument('--debug', action='store_true')
123 | parser.add_argument('--first_variant', action='store_true')
124 | args = parser.parse_args()
125 |
126 | search_space = env_params[args.env]
127 | sweeper = hyp.DeterministicHyperparameterSweeper(
128 | search_space, default_parameters=base_variant,
129 | )
130 | for exp_id, eval_variant in enumerate(sweeper.iterate_hyperparameters()):
131 | variant = process_variant(eval_variant)
132 |
133 | run_experiment(
134 | experiment,
135 | exp_folder=args.env,
136 | exp_prefix=args.label,
137 | variant=variant,
138 | snapshot_mode='gap_and_last',
139 | snapshot_gap=200,
140 | exp_id=exp_id,
141 | use_gpu=(not args.no_gpu),
142 | gpu_id=args.gpu_id,
143 | )
144 |
145 | if args.first_variant:
146 | exit()
--------------------------------------------------------------------------------
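For reference, `deep_update` (defined in the script above) merges nested overrides into the checkpoint variant without clobbering sibling keys:

```
source = {'algorithm_kwargs': {'num_epochs': 5000, 'batch_size': 256}, 'seed': 0}
overrides = {'algorithm_kwargs': {'num_epochs': 100}}
deep_update(source, overrides)
# source is now:
# {'algorithm_kwargs': {'num_epochs': 100, 'batch_size': 256}, 'seed': 0}
```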
/scripts/run_experiment_from_doodad.py:
--------------------------------------------------------------------------------
1 | import doodad as dd
2 | # import torch.multiprocessing as mp
3 |
4 | from maple.launchers.launcher_util import run_experiment_here
5 |
6 | if __name__ == "__main__":
7 | import matplotlib
8 | matplotlib.use('agg')
9 |
10 | # mp.set_start_method('forkserver')
11 | args_dict = dd.get_args()
12 | method_call = args_dict['method_call']
13 | run_experiment_kwargs = args_dict['run_experiment_kwargs']
14 | output_dir = args_dict['output_dir']
15 | run_mode = args_dict.get('mode', None)
16 | if run_mode and run_mode in ['slurm_singularity', 'sss']:
17 | import os
18 | run_experiment_kwargs['variant']['slurm-job-id'] = os.environ.get(
19 | 'SLURM_JOB_ID', None
20 | )
21 | if run_mode and (run_mode == 'ec2' or run_mode == 'gcp'):
22 | if run_mode == 'ec2':
23 | try:
24 | import urllib.request
25 | instance_id = urllib.request.urlopen(
26 | 'http://169.254.169.254/latest/meta-data/instance-id'
27 | ).read().decode()
28 | run_experiment_kwargs['variant']['EC2_instance_id'] = instance_id
29 | except Exception as e:
30 | print("Could not get AWS instance ID. Error was...")
31 | print(e)
32 | if run_mode == 'gcp':
33 | try:
34 | import urllib.request
35 | request = urllib.request.Request(
36 | "http://metadata/computeMetadata/v1/instance/name",
37 | )
38 | # See this URL for why we need this header:
39 | # https://cloud.google.com/compute/docs/storing-retrieving-metadata
40 | request.add_header("Metadata-Flavor", "Google")
41 | instance_name = urllib.request.urlopen(request).read().decode()
42 | run_experiment_kwargs['variant']['GCP_instance_name'] = (
43 | instance_name
44 | )
45 | except Exception as e:
46 | print("Could not get GCP instance name. Error was...")
47 | print(e)
48 | # Do this in case base_log_dir was already set
49 | run_experiment_kwargs['base_log_dir'] = output_dir
50 | run_experiment_here(
51 | method_call,
52 | include_exp_prefix_sub_dir=False,
53 | **run_experiment_kwargs
54 | )
55 | else:
56 | run_experiment_here(
57 | method_call,
58 | log_dir=output_dir,
59 | **run_experiment_kwargs
60 | )
61 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 |
4 | setup(
5 | name='maple',
6 | version='0.2.1dev',
7 | packages=find_packages(),
8 | license='MIT License',
9 | long_description=open('README.md').read(),
10 | )
11 |
--------------------------------------------------------------------------------