├── .gitignore ├── LICENSE ├── README.md ├── maple.yml ├── maple ├── __init__.py ├── core │ ├── __init__.py │ ├── batch_rl_algorithm.py │ ├── eval_util.py │ ├── logging.py │ ├── loss.py │ ├── online_rl_algorithm.py │ ├── rl_algorithm.py │ ├── serializable.py │ ├── tabulate.py │ └── trainer.py ├── data_management │ ├── __init__.py │ ├── env_replay_buffer.py │ ├── normalizer.py │ ├── obs_dict_replay_buffer.py │ ├── online_vae_replay_buffer.py │ ├── path_builder.py │ ├── replay_buffer.py │ ├── shared_obs_dict_replay_buffer.py │ ├── simple_replay_buffer.py │ └── split_buffer.py ├── envs │ ├── __init__.py │ ├── env_utils.py │ ├── make_env.py │ ├── mujoco_env.py │ ├── mujoco_image_env.py │ ├── proxy_env.py │ ├── vae_wrapper.py │ ├── wrappers.py │ └── wrappers │ │ ├── __init__.py │ │ ├── discretize_env.py │ │ ├── history_env.py │ │ ├── image_mujoco_env.py │ │ ├── image_mujoco_env_with_obs.py │ │ ├── normalized_box_env.py │ │ ├── reward_wrapper_env.py │ │ └── stack_observation_env.py ├── exploration_strategies │ ├── __init__.py │ ├── base.py │ ├── epsilon_greedy.py │ ├── gaussian_and_epsilon_strategy.py │ ├── gaussian_strategy.py │ └── ou_strategy.py ├── launchers │ ├── __init__.py │ ├── conf.py │ ├── launcher_util.py │ ├── robosuite_launcher.py │ └── visualization.py ├── policies │ ├── __init__.py │ ├── argmax.py │ ├── base.py │ └── simple.py ├── pythonplusplus.py ├── samplers │ ├── __init__.py │ ├── data_collector │ │ ├── __init__.py │ │ ├── base.py │ │ ├── contextual_path_collector.py │ │ ├── joint_path_collector.py │ │ ├── path_collector.py │ │ ├── step_collector.py │ │ └── vae_env.py │ ├── in_place.py │ ├── rollout_functions.py │ └── util.py ├── torch │ ├── __init__.py │ ├── conv_networks.py │ ├── core.py │ ├── data.py │ ├── data_management │ │ ├── __init__.py │ │ └── normalizer.py │ ├── distributions.py │ ├── lvm │ │ ├── __init__.py │ │ ├── bear_vae.py │ │ └── latent_variable_model.py │ ├── modules.py │ ├── networks │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── cnn.py │ │ ├── custom.py │ │ ├── dcnn.py │ │ ├── feat_point_mlp.py │ │ ├── image_state.py │ │ ├── linear_transform.py │ │ ├── mlp.py │ │ ├── normalization.py │ │ ├── pretrained_cnn.py │ │ ├── stochastic │ │ │ └── distribution_generator.py │ │ └── two_headed_mlp.py │ ├── pytorch_util.py │ ├── sac │ │ ├── __init__.py │ │ ├── policies │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gaussian_policy.py │ │ │ ├── lvm_policy.py │ │ │ ├── pamdp_policy.py │ │ │ └── policy_from_q.py │ │ ├── sac.py │ │ └── sac_hybrid.py │ └── torch_rl_algorithm.py └── util │ ├── data_processing.py │ ├── hyperparameter.py │ ├── io.py │ ├── ml_util.py │ └── slurm_util.py ├── scripts ├── eval.py ├── run_experiment_from_doodad.py └── train.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | mjkey.txt 2 | **/.DS_STORE 3 | **/*.pyc 4 | **/*.swp 5 | **/*.pdf 6 | maple/launchers/conf_private.py 7 | MANIFEST 8 | *.egg-info 9 | \.idea/ 10 | data 11 | MUJOCO_LOG.TXT 12 | **/sbatch/* 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 UT Robot Perception and Learning Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAPLE: Augmenting Reinforcement Learning with Behavior Primitives for Diverse Manipulation Tasks 2 | 3 | This is the official codebase for **Ma**nipulation **P**rimitive-augmented reinforcement **Le**arning (MAPLE), from the following paper: 4 | 5 | **Augmenting Reinforcement Learning with Behavior Primitives for Diverse Manipulation Tasks** 6 |
[Soroush Nasiriany](http://snasiriany.me/), [Huihan Liu](https://huihanl.github.io/), [Yuke Zhu](https://www.cs.utexas.edu/~yukez/) 7 |
[UT Austin Robot Perception and Learning Lab](https://rpl.cs.utexas.edu/) 8 |
IEEE International Conference on Robotics and Automation (ICRA), 2022 9 |
**[[Paper]](https://arxiv.org/abs/2110.03655)** **[[Project Website]](https://ut-austin-rpl.github.io/maple/)** 10 | 11 | 12 | 13 | 14 | This guide contains information about (1) [Installation](#installation), (2) [Running Experiments](#running-experiments), (3) [Setting Up Your Own Environments](#setting-up-your-own-environments), (4) [Acknowledgement](#acknowledgement), and (5) [Citation](#citation). 15 | 16 | ## Installation 17 | ### Download code 18 | - Current codebase: ```git clone https://github.com/UT-Austin-RPL/maple``` 19 | - (for environments) the `maple` branch in robosuite: ```git clone -b maple https://github.com/ARISE-Initiative/robosuite``` 20 | 21 | ### Setup robosuite 22 | 1. Download MuJoCo 2.0 (Linux and Mac OS X) and unzip its contents into `~/.mujoco/mujoco200`, and copy your MuJoCo license key to `~/.mujoco/mjkey.txt`. You can obtain a license key from [here](https://www.roboti.us/license.html). 23 | 2. (Linux) Set up additional dependencies: ```sudo apt install libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev software-properties-common net-tools xpra xserver-xorg-dev libglfw3-dev patchelf``` 24 | 3. Add MuJoCo to library paths: `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin` 25 | 26 | ### Setup conda environment 27 | 1. Create the conda environment: `conda env create --name maple --file=maple.yml` 28 | 2. (if the above fails) edit `maple.yml` to modify dependencies and then resume setup: `conda env update --name maple --file=maple.yml` 29 | 3. Activate the conda environment: `conda activate maple` 30 | 4. Finish maple setup: (in your maple repo path, run) `pip install -e .` 31 | 5. Finish robosuite setup: (in your robosuite repo path, run) `pip install -e .` 32 | 33 | ## Running Experiments 34 | Scripts for training policies and replaying policy checkpoints are located in `scripts/train.py` and `scripts/eval.py`, respectively. 35 | 36 | These experiment scripts use the following structure: 37 | ``` 38 | base_variant = dict( 39 | # default hyperparam settings for all envs 40 | ) 41 | 42 | env_params = { 43 | '' : { 44 | # add/override default hyperparam settings for specific env 45 | # each setting is specified as a dictionary address (key), 46 | # followed by list of possible options (value). 47 | # Example in following line: 48 | # 'env_variant.controller_type': ['OSC_POSITION'], 49 | }, 50 | '' : { 51 | ... 52 | }, 53 | } 54 | ``` 55 | 56 | ### Command Line Options 57 | See `parser` in `scripts/train.py` for a complete list of options. Some notable options: 58 | - `env`: the env to run (e.g., `stack`) 59 | - `label`: name for the experiment 60 | - `debug`: run with lite options for debugging 61 | 62 | ### Plotting Experiment Results 63 | During training, the results will be saved to a folder under `LOCAL_LOG_DIR///`. 64 | Inside this folder, the experiment results are stored in `progress.csv`. We recommend using [viskit](https://github.com/vitchyr/viskit) to plot the results. 65 | 66 | ## Setting Up Your Own Environments 67 | Note that this codebase is designed to work with robosuite environments only. For setting up your own environments, please follow [these examples](https://github.com/ARISE-Initiative/robosuite/tree/maple/robosuite/environments/manipulation) for reference. Notably, you will need to add the `skill_config` variable to the constructor, and define the keypoints for the affordance score by implementing the `_get_skill_info` function.
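As a rough reference, the sketch below shows what such an environment class might look like. This is a minimal, hypothetical example: the class name `MyLiftTask`, the `cube_body_id` attribute, and the exact keys returned by `_get_skill_info` are illustrative assumptions rather than the codebase's actual API, so follow the linked manipulation environments on the `maple` branch for the real conventions.

```python
# Illustrative sketch only -- class/attribute names and the keys in the returned
# dict are assumptions; see the manipulation environments on the robosuite
# `maple` branch for the actual conventions.
import numpy as np
from robosuite.environments.manipulation.single_arm_env import SingleArmEnv


class MyLiftTask(SingleArmEnv):
    def __init__(self, robots, skill_config=None, **kwargs):
        # Forward skill_config so the base environment (on the maple branch)
        # can configure the behavior-primitive ("skill") controller.
        super().__init__(robots=robots, skill_config=skill_config, **kwargs)

    def _get_skill_info(self):
        # Keypoints used for the affordance score: locations where reaching,
        # grasping, or pushing is considered useful for this task.
        cube_pos = np.array(self.sim.data.body_xpos[self.cube_body_id])
        return {
            'grasp_pos': [cube_pos],  # candidate grasp keypoints (assumed key name)
            'reach_pos': [cube_pos],  # candidate reach keypoints (assumed key name)
            'push_pos': [],           # no push keypoints for this task
        }
```

The two essential points are that `skill_config` is accepted by the constructor and forwarded to the base class, and that `_get_skill_info` exposes task keypoints that the affordance score can compare against where each primitive reaches, grasps, or pushes.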
68 | 69 | If you would like to know the inner workings of the primitives, refer to [`skill_controller.py`](https://github.com/ARISE-Initiative/robosuite/tree/maple/robosuite/controllers/skill_controller.py) and [`skills.py`](https://github.com/ARISE-Initiative/robosuite/tree/maple/robosuite/controllers/skills.py). Note that we use the term "skill" to refer to behavior primitives in the code. 70 | 71 | ## Acknowledgement 72 | Much of this codebase is directly based on [RLkit](https://github.com/vitchyr/rlkit), which itself is based on [rllab](https://github.com/rll/rllab). 73 | In addition, the environments were developed as a forked branch of [robosuite](https://github.com/ARISE-Initiative/robosuite) `v1.1.0`. 74 | 75 | ## Citation 76 | ```bibtex 77 | @inproceedings{nasiriany2022maple, 78 | title={Augmenting Reinforcement Learning with Behavior Primitives for Diverse Manipulation Tasks}, 79 | author={Soroush Nasiriany and Huihan Liu and Yuke Zhu}, 80 | booktitle={IEEE International Conference on Robotics and Automation (ICRA)}, 81 | year={2022} 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /maple.yml: -------------------------------------------------------------------------------- 1 | name: maple 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - argon2-cffi=20.1.0=py38h27cfd23_1 7 | - async_generator=1.10=pyhd3eb1b0_0 8 | - backcall=0.2.0=pyhd3eb1b0_0 9 | - bleach=3.3.0=pyhd3eb1b0_0 10 | - ca-certificates=2021.1.19=h06a4308_1 11 | - certifi=2020.12.5=py38h06a4308_0 12 | - dbus=1.13.18=hb2f20db_0 13 | - decorator=4.4.2=pyhd3eb1b0_0 14 | - defusedxml=0.7.1=pyhd3eb1b0_0 15 | - entrypoints=0.3=py38_0 16 | - expat=2.3.0=h2531618_2 17 | - fontconfig=2.13.1=h6c09931_0 18 | - freetype=2.10.4=h5ab3b9f_0 19 | - glib=2.68.0=h36276a3_0 20 | - gst-plugins-base=1.14.0=h8213a91_2 21 | - gstreamer=1.14.0=h28cd5cc_2 22 | - icu=58.2=he6710b0_3 23 | - importlib-metadata=3.7.3=py38h06a4308_1 24 | - importlib_metadata=3.7.3=hd3eb1b0_1 25 | - ipykernel=5.3.4=py38h5ca1d4c_0 26 | - ipython=7.22.0=py38hb070fc8_0 27 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 28 | - ipywidgets=7.6.3=pyhd3eb1b0_1 29 | - jedi=0.17.0=py38_0 30 | - jpeg=9b=h024ee3a_2 31 | - jsonschema=3.2.0=py_2 32 | - jupyter=1.0.0=py38_7 33 | - jupyter_client=6.1.12=pyhd3eb1b0_0 34 | - jupyter_console=6.4.0=pyhd3eb1b0_0 35 | - jupyter_core=4.7.1=py38h06a4308_0 36 | - jupyterlab_pygments=0.1.2=py_0 37 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 38 | - ld_impl_linux-64=2.33.1=h53a641e_7 39 | - libedit=3.1.20191231=h14c3975_1 40 | - libffi=3.3=he6710b0_2 41 | - libgcc-ng=9.1.0=hdf63c60_0 42 | - libpng=1.6.37=hbc83047_0 43 | - libsodium=1.0.18=h7b6447c_0 44 | - libstdcxx-ng=9.1.0=hdf63c60_0 45 | - libuuid=1.0.3=h1bed415_2 46 | - libxcb=1.14=h7b6447c_0 47 | - libxml2=2.9.10=hb55368b_3 48 | - markupsafe=1.1.1=py38h7b6447c_0 49 | - mistune=0.8.4=py38h7b6447c_1000 50 | - nbclient=0.5.3=pyhd3eb1b0_0 51 | - nbconvert=6.0.7=py38_0 52 | - ncurses=6.2=he6710b0_1 53 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 54 | - notebook=6.3.0=py38h06a4308_0 55 | - openssl=1.1.1k=h27cfd23_0 56 | - packaging=20.9=pyhd3eb1b0_0 57 | - pandoc=2.12=h06a4308_0 58 | - pandocfilters=1.4.3=py38h06a4308_1 59 | - parso=0.8.1=pyhd3eb1b0_0 60 | - pcre=8.44=he6710b0_0 61 | - pexpect=4.8.0=pyhd3eb1b0_3 62 | - pickleshare=0.7.5=pyhd3eb1b0_1003 63 | - pip=20.2.2=py38_0 64 | - prometheus_client=0.9.0=pyhd3eb1b0_0 65 | - prompt-toolkit=3.0.17=pyh06a4308_0 66 | - prompt_toolkit=3.0.17=hd3eb1b0_0 67 | - 
ptyprocess=0.7.0=pyhd3eb1b0_2 68 | - pycparser=2.20=py_2 69 | - pygments=2.8.1=pyhd3eb1b0_0 70 | - pyparsing=2.4.7=pyhd3eb1b0_0 71 | - pyqt=5.9.2=py38h05f1152_4 72 | - pyrsistent=0.17.3=py38h7b6447c_0 73 | - python=3.8.5=h7579374_1 74 | - python-dateutil=2.8.1=pyhd3eb1b0_0 75 | - pyzmq=20.0.0=py38h2531618_1 76 | - qt=5.9.7=h5867ecd_1 77 | - qtconsole=5.0.3=pyhd3eb1b0_0 78 | - qtpy=1.9.0=py_0 79 | - readline=8.0=h7b6447c_0 80 | - send2trash=1.5.0=pyhd3eb1b0_1 81 | - setuptools=49.6.0=py38_0 82 | - sip=4.19.13=py38he6710b0_0 83 | - six=1.15.0=py38h06a4308_0 84 | - sqlite=3.33.0=h62c20be_0 85 | - terminado=0.9.4=py38h06a4308_0 86 | - testpath=0.4.4=pyhd3eb1b0_0 87 | - tk=8.6.10=hbc83047_0 88 | - tornado=6.1=py38h27cfd23_0 89 | - traitlets=5.0.5=pyhd3eb1b0_0 90 | - wcwidth=0.2.5=py_0 91 | - webencodings=0.5.1=py38_1 92 | - widgetsnbextension=3.5.1=py38_0 93 | - xz=5.2.5=h7b6447c_0 94 | - zeromq=4.3.4=h2531618_0 95 | - zipp=3.4.1=pyhd3eb1b0_0 96 | - zlib=1.2.11=h7b6447c_3 97 | - pip: 98 | - absl-py==0.12.0 99 | - addict==2.4.0 100 | - anyio==3.3.1 101 | - astunparse==1.6.3 102 | - attrs==20.2.0 103 | - babel==2.9.1 104 | - cachetools==4.2.1 105 | - cffi==1.14.2 106 | - chardet==3.0.4 107 | - clang==5.0 108 | - click==7.1.2 109 | - cloudpickle==1.3.0 110 | - cycler==0.10.0 111 | - cython==0.29.21 112 | - deprecation==2.1.0 113 | - dill==0.3.4 114 | - fasteners==0.15 115 | - flask==1.0.2 116 | - flatbuffers==1.12 117 | - future==0.18.2 118 | - gast==0.4.0 119 | - gitdb==4.0.2 120 | - gitdb2==2.0.6 121 | - gitpython==2.1.7 122 | - glfw==1.12.0 123 | - google-api-core==1.26.0 124 | - google-api-python-client==1.12.8 125 | - google-auth==1.26.1 126 | - google-auth-httplib2==0.0.4 127 | - google-auth-oauthlib==0.4.4 128 | - google-cloud-core==1.6.0 129 | - google-cloud-storage==1.36.0 130 | - google-crc32c==1.1.2 131 | - google-pasta==0.2.0 132 | - google-resumable-media==1.2.0 133 | - googleapis-common-protos==1.52.0 134 | - grpcio==1.39.0 135 | - gtimer==1.0.0b5 136 | - gym==0.17.2 137 | - h5py==3.1.0 138 | - httplib2==0.19.0 139 | - idna==2.10 140 | - imageio==2.9.0 141 | - imageio-ffmpeg==0.4.4 142 | - itsdangerous==1.1.0 143 | - jinja2==2.11.2 144 | - joblib==1.0.1 145 | - json5==0.9.6 146 | - jupyter-core==4.6.3 147 | - jupyter-packaging==0.10.4 148 | - jupyter-server==1.11.0 149 | - jupyterlab==3.1.12 150 | - jupyterlab-server==2.8.1 151 | - keras==2.6.0 152 | - keras-preprocessing==1.1.2 153 | - kiwisolver==1.3.2 154 | - llvmlite==0.32.1 155 | - lockfile==0.12.2 156 | - markdown==3.3.4 157 | - matplotlib==3.4.3 158 | - monotonic==1.5 159 | - moviepy==1.0.3 160 | - mujoco-py==2.0.2.9 161 | - nbclassic==0.3.1 162 | - nbformat==5.0.8 163 | - networkx==2.6.2 164 | - numba==0.49.1 165 | - numpy==1.19.5 166 | - nvisii==1.1.72 167 | - oauthlib==3.1.0 168 | - open3d==0.13.0 169 | - opencv-python==4.4.0.44 170 | - opt-einsum==3.3.0 171 | - pandas==1.3.3 172 | - pillow==8.3.2 173 | - plotly==3.4.2 174 | - proglog==0.1.9 175 | - progressbar2==3.53.1 176 | - protobuf==3.14.0 177 | - psutil==5.8.0 178 | - pyasn1==0.4.8 179 | - pyasn1-modules==0.2.8 180 | - pybullet==2.6.9 181 | - pyglet==1.5.0 182 | - python-utils==2.5.6 183 | - pytz==2020.4 184 | - pywavelets==1.1.1 185 | - pyyaml==5.4.1 186 | - requests==2.24.0 187 | - requests-oauthlib==1.3.0 188 | - requests-unixsocket==0.2.0 189 | - retrying==1.3.3 190 | - rsa==4.7 191 | - scikit-image==0.18.2 192 | - scikit-learn==0.24.2 193 | - scikit-video==1.1.11 194 | - scipy==1.5.2 195 | - smmap==3.0.5 196 | - smmap2==3.0.1 197 | - sniffio==1.2.0 198 | - 
tensorboard==2.6.0 199 | - tensorboard-data-server==0.6.1 200 | - tensorboard-plugin-wit==1.8.0 201 | - tensorboardx==2.4 202 | - tensorflow-estimator==2.6.0 203 | - tensorflow-gpu==2.6.0 204 | - termcolor==1.1.0 205 | - threadpoolctl==2.2.0 206 | - tifffile==2021.8.8 207 | - tomlkit==0.7.2 208 | - torch==1.6.0 209 | - torchvision==0.7.0 210 | - tqdm==4.61.1 211 | - typing-extensions==3.7.4.3 212 | - uritemplate==3.0.1 213 | - urllib3==1.25.10 214 | - websocket-client==1.2.1 215 | - werkzeug==1.0.1 216 | - wheel==0.37.0 217 | - wrapt==1.12.1 218 | 219 | -------------------------------------------------------------------------------- /maple/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/__init__.py -------------------------------------------------------------------------------- /maple/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout maple. 3 | """ 4 | from maple.core.logging import logger 5 | 6 | __all__ = ['logger'] 7 | 8 | -------------------------------------------------------------------------------- /maple/core/batch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from maple.core.rl_algorithm import BaseRLAlgorithm 5 | from maple.data_management.replay_buffer import ReplayBuffer 6 | from maple.samplers.data_collector import PathCollector 7 | 8 | 9 | class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 10 | def __init__( 11 | self, 12 | trainer, 13 | exploration_env, 14 | evaluation_env, 15 | exploration_data_collector: PathCollector, 16 | evaluation_data_collector: PathCollector, 17 | replay_buffer: ReplayBuffer, 18 | batch_size, 19 | max_path_length, 20 | num_epochs, 21 | num_eval_steps_per_epoch, 22 | num_expl_steps_per_train_loop, 23 | num_trains_per_train_loop, 24 | num_train_loops_per_epoch=1, 25 | min_num_steps_before_training=0, 26 | eval_epoch_freq=1, 27 | expl_epoch_freq=1, 28 | eval_only=False, 29 | no_training=False, 30 | ): 31 | super().__init__( 32 | trainer, 33 | exploration_env, 34 | evaluation_env, 35 | exploration_data_collector, 36 | evaluation_data_collector, 37 | replay_buffer, 38 | eval_epoch_freq=eval_epoch_freq, 39 | expl_epoch_freq=expl_epoch_freq, 40 | eval_only=eval_only, 41 | no_training=no_training, 42 | ) 43 | self.batch_size = batch_size 44 | self.max_path_length = max_path_length 45 | self.num_epochs = num_epochs 46 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch 47 | self.num_trains_per_train_loop = num_trains_per_train_loop 48 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 49 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 50 | self.min_num_steps_before_training = min_num_steps_before_training 51 | gt.reset_root() 52 | 53 | def _train(self): 54 | if self.min_num_steps_before_training > 0 and not self._eval_only: 55 | init_expl_paths = self.expl_data_collector.collect_new_paths( 56 | self.max_path_length, 57 | self.min_num_steps_before_training, 58 | discard_incomplete_paths=True, #False, 59 | ) 60 | self.replay_buffer.add_paths(init_expl_paths) 61 | self.expl_data_collector.end_epoch(-1) 62 | 63 | for epoch in gt.timed_for( 64 | range(self._start_epoch, self.num_epochs + 1), 65 | save_itrs=True, 66 | ): 67 | for pre_epoch_func in 
self.pre_epoch_funcs: 68 | pre_epoch_func(self, epoch) 69 | 70 | if epoch % self._eval_epoch_freq == 0: 71 | self.eval_data_collector.collect_new_paths( 72 | self.max_path_length, 73 | self.num_eval_steps_per_epoch, 74 | discard_incomplete_paths=True, 75 | ) 76 | gt.stamp('evaluation sampling') 77 | 78 | if not self._eval_only: 79 | for _ in range(self.num_train_loops_per_epoch): 80 | if epoch % self._expl_epoch_freq == 0: 81 | new_expl_paths = self.expl_data_collector.collect_new_paths( 82 | self.max_path_length, 83 | self.num_expl_steps_per_train_loop, 84 | discard_incomplete_paths=True, #False, 85 | ) 86 | gt.stamp('exploration sampling', unique=False) 87 | 88 | self.replay_buffer.add_paths(new_expl_paths) 89 | gt.stamp('data storing', unique=False) 90 | 91 | if not self._no_training: 92 | self.training_mode(True) 93 | for _ in range(self.num_trains_per_train_loop): 94 | train_data = self.replay_buffer.random_batch( 95 | self.batch_size) 96 | self.trainer.train(train_data) 97 | gt.stamp('training', unique=False) 98 | self.training_mode(False) 99 | 100 | self._end_epoch(epoch) 101 | -------------------------------------------------------------------------------- /maple/core/eval_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common evaluation utilities. 3 | """ 4 | 5 | from collections import OrderedDict 6 | from numbers import Number 7 | 8 | import numpy as np 9 | import itertools 10 | 11 | import maple.pythonplusplus as ppp 12 | 13 | 14 | def get_generic_path_information(paths, stat_prefix=''): 15 | """ 16 | Get an OrderedDict with a bunch of statistic names and values. 17 | """ 18 | statistics = OrderedDict() 19 | returns = [sum(path["rewards"]) for path in paths] 20 | 21 | rewards = np.vstack([path["rewards"] for path in paths]) 22 | statistics.update(create_stats_ordered_dict('Rewards', rewards, 23 | stat_prefix=stat_prefix)) 24 | statistics.update(create_stats_ordered_dict('Returns', returns, 25 | stat_prefix=stat_prefix)) 26 | actions = [path["actions"] for path in paths] 27 | if len(actions[0].shape) == 1: 28 | actions = np.hstack([path["actions"] for path in paths]) 29 | else: 30 | actions = np.vstack([path["actions"] for path in paths]) 31 | statistics.update(create_stats_ordered_dict( 32 | 'Actions', actions, stat_prefix=stat_prefix 33 | )) 34 | statistics['Num Paths'] = len(paths) 35 | statistics[stat_prefix + 'Average Returns'] = get_average_returns(paths) 36 | 37 | statistics[stat_prefix + 'Task Returns Sum'] = np.mean([path['reward_actions_sum'] for path in paths]) 38 | statistics[stat_prefix + 'Task Returns Avg'] = np.mean([path['reward_actions_sum'] / path['path_length_actions'] for path in paths]) 39 | 40 | statistics[stat_prefix + 'Num Rollout Success'] = get_num_rollout_success(paths) 41 | 42 | for info_key in ['env_infos', 'agent_infos']: 43 | if info_key in paths[0]: 44 | all_env_infos = [ 45 | ppp.list_of_dicts__to__dict_of_lists(p[info_key]) 46 | for p in paths 47 | ] 48 | for k in all_env_infos[0].keys(): 49 | final_ks = np.array([info[k][-1] for info in all_env_infos]) 50 | first_ks = np.array([info[k][0] for info in all_env_infos]) 51 | all_ks = np.concatenate([info[k] for info in all_env_infos]) 52 | statistics.update(create_stats_ordered_dict( 53 | stat_prefix + k, 54 | final_ks, 55 | stat_prefix='{}/final/'.format(info_key), 56 | )) 57 | statistics.update(create_stats_ordered_dict( 58 | stat_prefix + k, 59 | first_ks, 60 | stat_prefix='{}/initial/'.format(info_key), 61 | )) 62 | 
statistics.update(create_stats_ordered_dict( 63 | stat_prefix + k, 64 | all_ks, 65 | stat_prefix='{}/'.format(info_key), 66 | )) 67 | 68 | return statistics 69 | 70 | 71 | def get_average_returns(paths): 72 | returns = [sum(path["rewards"]) for path in paths] 73 | return np.mean(returns) 74 | 75 | def get_num_rollout_success(paths): 76 | num_success = 0 77 | for path in paths: 78 | if any([info.get('success', False) for info in path['env_infos']]): 79 | num_success += 1 80 | return num_success 81 | 82 | 83 | def create_stats_ordered_dict( 84 | name, 85 | data, 86 | stat_prefix=None, 87 | always_show_all_stats=True, 88 | exclude_max_min=True, 89 | ): 90 | if stat_prefix is not None: 91 | name = "{}{}".format(stat_prefix, name) 92 | if isinstance(data, Number): 93 | return OrderedDict({name: data}) 94 | 95 | if len(data) == 0: 96 | return OrderedDict() 97 | 98 | if isinstance(data, tuple): 99 | ordered_dict = OrderedDict() 100 | for number, d in enumerate(data): 101 | sub_dict = create_stats_ordered_dict( 102 | "{0}_{1}".format(name, number), 103 | d, 104 | ) 105 | ordered_dict.update(sub_dict) 106 | return ordered_dict 107 | 108 | if isinstance(data, list): 109 | try: 110 | iter(data[0]) 111 | except TypeError: 112 | pass 113 | else: 114 | data = np.concatenate(data) 115 | 116 | if (isinstance(data, np.ndarray) and data.size == 1 117 | and not always_show_all_stats): 118 | return OrderedDict({name: float(data)}) 119 | 120 | stats = OrderedDict([ 121 | (name + ' Mean', np.mean(data)), 122 | (name + ' Std', np.std(data)), 123 | ]) 124 | if not exclude_max_min: 125 | stats[name + ' Max'] = np.max(data) 126 | stats[name + ' Min'] = np.min(data) 127 | return stats 128 | -------------------------------------------------------------------------------- /maple/core/loss.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | 5 | LossStatistics = OrderedDict 6 | 7 | 8 | class LossFunction(object, metaclass=abc.ABCMeta): 9 | @abc.abstractmethod 10 | def compute_loss(self, batch, skip_statistics=False, **kwargs): 11 | """Returns loss and statistics given a batch of data. 12 | batch : Data to compute loss of 13 | skip_statistics: Whether statistics should be calculated. If True, then 14 | an empty dict is returned for the statistics. 15 | 16 | Returns: (loss, stats) tuple. 
17 | """ 18 | pass 19 | -------------------------------------------------------------------------------- /maple/core/online_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from maple.core.rl_algorithm import BaseRLAlgorithm 5 | from maple.data_management.replay_buffer import ReplayBuffer 6 | from maple.samplers.data_collector import ( 7 | PathCollector, 8 | StepCollector, 9 | ) 10 | 11 | 12 | class OnlineRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 13 | def __init__( 14 | self, 15 | trainer, 16 | exploration_env, 17 | evaluation_env, 18 | exploration_data_collector: StepCollector, 19 | evaluation_data_collector: PathCollector, 20 | replay_buffer: ReplayBuffer, 21 | batch_size, 22 | max_path_length, 23 | num_epochs, 24 | num_eval_steps_per_epoch, 25 | num_expl_steps_per_train_loop, 26 | num_trains_per_train_loop, 27 | num_train_loops_per_epoch=1, 28 | min_num_steps_before_training=0, 29 | ): 30 | super().__init__( 31 | trainer, 32 | exploration_env, 33 | evaluation_env, 34 | exploration_data_collector, 35 | evaluation_data_collector, 36 | replay_buffer, 37 | ) 38 | self.batch_size = batch_size 39 | self.max_path_length = max_path_length 40 | self.num_epochs = num_epochs 41 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch 42 | self.num_trains_per_train_loop = num_trains_per_train_loop 43 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 44 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 45 | self.min_num_steps_before_training = min_num_steps_before_training 46 | 47 | assert self.num_trains_per_train_loop >= self.num_expl_steps_per_train_loop, \ 48 | 'Online training presumes num_trains_per_train_loop >= num_expl_steps_per_train_loop' 49 | 50 | def _train(self): 51 | self.training_mode(False) 52 | if self.min_num_steps_before_training > 0: 53 | self.expl_data_collector.collect_new_steps( 54 | self.max_path_length, 55 | self.min_num_steps_before_training, 56 | discard_incomplete_paths=False, 57 | ) 58 | init_expl_paths = self.expl_data_collector.get_epoch_paths() 59 | self.replay_buffer.add_paths(init_expl_paths) 60 | self.expl_data_collector.end_epoch(-1) 61 | 62 | gt.stamp('initial exploration', unique=True) 63 | 64 | num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop 65 | for epoch in gt.timed_for( 66 | range(self._start_epoch, self.num_epochs), 67 | save_itrs=True, 68 | ): 69 | self.eval_data_collector.collect_new_paths( 70 | self.max_path_length, 71 | self.num_eval_steps_per_epoch, 72 | discard_incomplete_paths=True, 73 | ) 74 | gt.stamp('evaluation sampling') 75 | 76 | for _ in range(self.num_train_loops_per_epoch): 77 | for _ in range(self.num_expl_steps_per_train_loop): 78 | self.expl_data_collector.collect_new_steps( 79 | self.max_path_length, 80 | 1, # num steps 81 | discard_incomplete_paths=False, 82 | ) 83 | gt.stamp('exploration sampling', unique=False) 84 | 85 | self.training_mode(True) 86 | for _ in range(num_trains_per_expl_step): 87 | train_data = self.replay_buffer.random_batch( 88 | self.batch_size) 89 | self.trainer.train(train_data) 90 | gt.stamp('training', unique=False) 91 | self.training_mode(False) 92 | 93 | new_expl_paths = self.expl_data_collector.get_epoch_paths() 94 | self.replay_buffer.add_paths(new_expl_paths) 95 | gt.stamp('data storing', unique=False) 96 | 97 | self._end_epoch(epoch) 98 | -------------------------------------------------------------------------------- 
/maple/core/rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | import gtimer as gt 5 | 6 | from maple.core import logger, eval_util 7 | from maple.core.logging import append_log 8 | from maple.data_management.replay_buffer import ReplayBuffer 9 | from maple.samplers.data_collector import DataCollector 10 | 11 | 12 | def _get_epoch_timings(): 13 | times_itrs = gt.get_times().stamps.itrs 14 | times = OrderedDict() 15 | epoch_time = 0 16 | for key in sorted(times_itrs): 17 | time = times_itrs[key][-1] 18 | epoch_time += time 19 | times['time/{} (s)'.format(key)] = time 20 | times['time/epoch (s)'] = epoch_time 21 | times['time/total (s)'] = gt.get_times().total 22 | return times 23 | 24 | 25 | class BaseRLAlgorithm(object, metaclass=abc.ABCMeta): 26 | def __init__( 27 | self, 28 | trainer, 29 | exploration_env, 30 | evaluation_env, 31 | exploration_data_collector: DataCollector, 32 | evaluation_data_collector: DataCollector, 33 | replay_buffer: ReplayBuffer, 34 | eval_epoch_freq=1, 35 | expl_epoch_freq=1, 36 | eval_only=False, 37 | no_training=False, 38 | ): 39 | self.trainer = trainer 40 | self.expl_env = exploration_env 41 | self.eval_env = evaluation_env 42 | self.expl_data_collector = exploration_data_collector 43 | self.eval_data_collector = evaluation_data_collector 44 | self.replay_buffer = replay_buffer 45 | self._start_epoch = 0 46 | 47 | self.pre_epoch_funcs = [] 48 | self.post_epoch_funcs = [] 49 | 50 | self._eval_epoch_freq = eval_epoch_freq 51 | self._expl_epoch_freq = expl_epoch_freq 52 | self._eval_only = eval_only 53 | 54 | self._no_training = no_training 55 | 56 | def train(self, start_epoch=0): 57 | self._start_epoch = start_epoch 58 | self._train() 59 | 60 | def _train(self): 61 | """ 62 | Train model. 
63 | """ 64 | raise NotImplementedError('_train must implemented by inherited class') 65 | 66 | def _end_epoch(self, epoch): 67 | snapshot = self._get_snapshot() 68 | logger.save_itr_params(epoch, snapshot) 69 | gt.stamp('saving') 70 | self._log_stats(epoch) 71 | 72 | self.expl_data_collector.end_epoch(epoch) 73 | self.eval_data_collector.end_epoch(epoch) 74 | self.replay_buffer.end_epoch(epoch) 75 | self.trainer.end_epoch(epoch) 76 | 77 | for post_epoch_func in self.post_epoch_funcs: 78 | post_epoch_func(self, epoch) 79 | 80 | def _get_snapshot(self): 81 | snapshot = {} 82 | for k, v in self.trainer.get_snapshot().items(): 83 | snapshot['trainer/' + k] = v 84 | for k, v in self.expl_data_collector.get_snapshot().items(): 85 | snapshot['exploration/' + k] = v 86 | for k, v in self.eval_data_collector.get_snapshot().items(): 87 | snapshot['evaluation/' + k] = v 88 | for k, v in self.replay_buffer.get_snapshot().items(): 89 | snapshot['replay_buffer/' + k] = v 90 | return snapshot 91 | 92 | def _log_stats(self, epoch): 93 | dump_logs = False 94 | if not self._eval_only and epoch % self._expl_epoch_freq == 0: 95 | dump_logs = True 96 | if epoch % self._eval_epoch_freq == 0: 97 | dump_logs = True 98 | if not dump_logs: 99 | return 100 | 101 | logger.log("Epoch {} finished".format(epoch), with_timestamp=True) 102 | 103 | """ 104 | Replay Buffer 105 | """ 106 | if not self._eval_only: 107 | logger.record_dict( 108 | self.replay_buffer.get_diagnostics(), 109 | prefix='replay_buffer/' 110 | ) 111 | 112 | """ 113 | Trainer 114 | """ 115 | if not self._eval_only: 116 | logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/') 117 | 118 | """ 119 | Exploration 120 | """ 121 | if not self._eval_only: 122 | if epoch % self._expl_epoch_freq == 0: 123 | self._cur_expl_log = OrderedDict() 124 | self._cur_expl_log.update( 125 | self.expl_data_collector.get_diagnostics() 126 | ) 127 | expl_paths = self.expl_data_collector.get_epoch_paths() 128 | if hasattr(self.expl_env, 'get_diagnostics'): 129 | self._cur_expl_log.update( 130 | self.expl_env.get_diagnostics(expl_paths), 131 | ) 132 | self._cur_expl_log.update( 133 | eval_util.get_generic_path_information(expl_paths), 134 | ) 135 | logger.record_dict(self._cur_expl_log, prefix='expl/') 136 | 137 | """ 138 | Evaluation 139 | """ 140 | if epoch % self._eval_epoch_freq == 0: 141 | self._cur_eval_log = OrderedDict() 142 | self._cur_eval_log.update( 143 | self.eval_data_collector.get_diagnostics() 144 | ) 145 | eval_paths = self.eval_data_collector.get_epoch_paths() 146 | if hasattr(self.eval_env, 'get_diagnostics'): 147 | self._cur_eval_log.update( 148 | self.eval_env.get_diagnostics(eval_paths), 149 | ) 150 | self._cur_eval_log.update( 151 | eval_util.get_generic_path_information(eval_paths), 152 | ) 153 | logger.record_dict(self._cur_eval_log, prefix='eval/') 154 | 155 | """ 156 | Misc 157 | """ 158 | try: 159 | import os 160 | import psutil 161 | process = psutil.Process(os.getpid()) 162 | k = 'process/RAM Usage (Mb)' 163 | v = int(process.memory_info().rss / 1000000) 164 | logger.record_tabular(k, v) 165 | logger.record_tabular('process/Num Threads', process.num_threads()) 166 | except ImportError: 167 | pass 168 | 169 | gt.stamp('logging') 170 | logger.record_dict(_get_epoch_timings()) 171 | logger.record_tabular('Dummy', 0) 172 | logger.record_tabular('Epoch', epoch) 173 | logger.dump_tabular(with_prefix=False, with_timestamp=False) 174 | 175 | @abc.abstractmethod 176 | def training_mode(self, mode): 177 | """ 178 | Set training mode to 
`mode`. 179 | :param mode: If True, training will happen (e.g. set the dropout 180 | probabilities to not all ones). 181 | """ 182 | pass 183 | -------------------------------------------------------------------------------- /maple/core/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | class Serializable(object): 12 | 13 | def __init__(self, *args, **kwargs): 14 | self.__args = args 15 | self.__kwargs = kwargs 16 | 17 | def quick_init(self, locals_): 18 | if getattr(self, "_serializable_initialized", False): 19 | return 20 | if sys.version_info >= (3, 0): 21 | spec = inspect.getfullargspec(self.__init__) 22 | # Exclude the first "self" parameter 23 | if spec.varkw: 24 | kwargs = locals_[spec.varkw].copy() 25 | else: 26 | kwargs = dict() 27 | if spec.kwonlyargs: 28 | for key in spec.kwonlyargs: 29 | kwargs[key] = locals_[key] 30 | else: 31 | spec = inspect.getargspec(self.__init__) 32 | if spec.keywords: 33 | kwargs = locals_[spec.keywords] 34 | else: 35 | kwargs = dict() 36 | if spec.varargs: 37 | varargs = locals_[spec.varargs] 38 | else: 39 | varargs = tuple() 40 | in_order_args = [locals_[arg] for arg in spec.args][1:] 41 | self.__args = tuple(in_order_args) + varargs 42 | self.__kwargs = kwargs 43 | setattr(self, "_serializable_initialized", True) 44 | 45 | def __getstate__(self): 46 | return {"__args": self.__args, "__kwargs": self.__kwargs} 47 | 48 | def __setstate__(self, d): 49 | # convert all __args to keyword-based arguments 50 | if sys.version_info >= (3, 0): 51 | spec = inspect.getfullargspec(self.__init__) 52 | else: 53 | spec = inspect.getargspec(self.__init__) 54 | in_order_args = spec.args[1:] 55 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 56 | self.__dict__.update(out.__dict__) 57 | 58 | @classmethod 59 | def clone(cls, obj, **kwargs): 60 | assert isinstance(obj, Serializable) 61 | d = obj.__getstate__() 62 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 63 | out = type(obj).__new__(type(obj)) 64 | out.__setstate__(d) 65 | return out 66 | -------------------------------------------------------------------------------- /maple/core/trainer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Trainer(object, metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def train(self, data): 7 | pass 8 | 9 | def end_epoch(self, epoch): 10 | pass 11 | 12 | def get_snapshot(self): 13 | return {} 14 | 15 | def get_diagnostics(self): 16 | return {} 17 | -------------------------------------------------------------------------------- /maple/data_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/data_management/__init__.py -------------------------------------------------------------------------------- /maple/data_management/env_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from gym.spaces import Discrete 2 | 3 | from maple.data_management.simple_replay_buffer import SimpleReplayBuffer 4 | from maple.envs.env_utils import get_dim 5 | import numpy as np 6 | 7 | 8 | class EnvReplayBuffer(SimpleReplayBuffer): 9 | def __init__( 10 | self, 11 | max_replay_buffer_size, 12 | env, 13 | 
env_info_sizes=None 14 | ): 15 | """ 16 | :param max_replay_buffer_size: 17 | :param env: 18 | """ 19 | self.env = env 20 | self._ob_space = env.observation_space 21 | self._action_space = env.action_space 22 | 23 | if env_info_sizes is None: 24 | if hasattr(env, 'info_sizes'): 25 | env_info_sizes = env.info_sizes 26 | else: 27 | env_info_sizes = dict() 28 | 29 | super().__init__( 30 | max_replay_buffer_size=max_replay_buffer_size, 31 | observation_dim=get_dim(self._ob_space), 32 | action_dim=get_dim(self._action_space), 33 | env_info_sizes=env_info_sizes 34 | ) 35 | 36 | def add_sample(self, observation, action, reward, terminal, 37 | next_observation, **kwargs): 38 | if isinstance(self._action_space, Discrete): 39 | new_action = np.zeros(self._action_dim) 40 | new_action[action] = 1 41 | else: 42 | new_action = action 43 | return super().add_sample( 44 | observation=observation, 45 | action=new_action, 46 | reward=reward, 47 | next_observation=next_observation, 48 | terminal=terminal, 49 | **kwargs 50 | ) 51 | -------------------------------------------------------------------------------- /maple/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on code from Marcin Andrychowicz 3 | """ 4 | import numpy as np 5 | 6 | 7 | class Normalizer(object): 8 | def __init__( 9 | self, 10 | size, 11 | eps=1e-8, 12 | default_clip_range=np.inf, 13 | mean=0, 14 | std=1, 15 | ): 16 | self.size = size 17 | self.eps = eps 18 | self.default_clip_range = default_clip_range 19 | self.sum = np.zeros(self.size, np.float32) 20 | self.sumsq = np.zeros(self.size, np.float32) 21 | self.count = np.ones(1, np.float32) 22 | self.mean = mean + np.zeros(self.size, np.float32) 23 | self.std = std * np.ones(self.size, np.float32) 24 | self.synchronized = True 25 | 26 | def update(self, v): 27 | if v.ndim == 1: 28 | v = np.expand_dims(v, 0) 29 | assert v.ndim == 2 30 | assert v.shape[1] == self.size 31 | self.sum += v.sum(axis=0) 32 | self.sumsq += (np.square(v)).sum(axis=0) 33 | self.count[0] += v.shape[0] 34 | self.synchronized = False 35 | 36 | def normalize(self, v, clip_range=None): 37 | if not self.synchronized: 38 | self.synchronize() 39 | if clip_range is None: 40 | clip_range = self.default_clip_range 41 | mean, std = self.mean, self.std 42 | if v.ndim == 2: 43 | mean = mean.reshape(1, -1) 44 | std = std.reshape(1, -1) 45 | return np.clip((v - mean) / std, -clip_range, clip_range) 46 | 47 | def denormalize(self, v): 48 | if not self.synchronized: 49 | self.synchronize() 50 | mean, std = self.mean, self.std 51 | if v.ndim == 2: 52 | mean = mean.reshape(1, -1) 53 | std = std.reshape(1, -1) 54 | return mean + v * std 55 | 56 | def synchronize(self): 57 | self.mean[...] = self.sum / self.count[0] 58 | self.std[...] 
= np.sqrt( 59 | np.maximum( 60 | np.square(self.eps), 61 | self.sumsq / self.count[0] - np.square(self.mean) 62 | ) 63 | ) 64 | self.synchronized = True 65 | 66 | 67 | class IdentityNormalizer(object): 68 | def __init__(self, *args, **kwargs): 69 | pass 70 | 71 | def update(self, v): 72 | pass 73 | 74 | def normalize(self, v, clip_range=None): 75 | return v 76 | 77 | def denormalize(self, v): 78 | return v 79 | 80 | 81 | class FixedNormalizer(object): 82 | def __init__( 83 | self, 84 | size, 85 | default_clip_range=np.inf, 86 | mean=0, 87 | std=1, 88 | eps=1e-8, 89 | ): 90 | assert std > 0 91 | std = std + eps 92 | self.size = size 93 | self.default_clip_range = default_clip_range 94 | self.mean = mean + np.zeros(self.size, np.float32) 95 | self.std = std + np.zeros(self.size, np.float32) 96 | self.eps = eps 97 | 98 | def set_mean(self, mean): 99 | self.mean = mean + np.zeros(self.size, np.float32) 100 | 101 | def set_std(self, std): 102 | std = std + self.eps 103 | self.std = std + np.zeros(self.size, np.float32) 104 | 105 | def normalize(self, v, clip_range=None): 106 | if clip_range is None: 107 | clip_range = self.default_clip_range 108 | mean, std = self.mean, self.std 109 | if v.ndim == 2: 110 | mean = mean.reshape(1, -1) 111 | std = std.reshape(1, -1) 112 | return np.clip((v - mean) / std, -clip_range, clip_range) 113 | 114 | def denormalize(self, v): 115 | mean, std = self.mean, self.std 116 | if v.ndim == 2: 117 | mean = mean.reshape(1, -1) 118 | std = std.reshape(1, -1) 119 | return mean + v * std 120 | 121 | def copy_stats(self, other): 122 | self.set_mean(other.mean) 123 | self.set_std(other.std) 124 | -------------------------------------------------------------------------------- /maple/data_management/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PathBuilder(dict): 5 | """ 6 | Usage: 7 | ``` 8 | path_builder = PathBuilder() 9 | path.add_sample( 10 | observations=1, 11 | actions=2, 12 | next_observations=3, 13 | ... 14 | ) 15 | path.add_sample( 16 | observations=4, 17 | actions=5, 18 | next_observations=6, 19 | ... 20 | ) 21 | 22 | path = path_builder.get_all_stacked() 23 | 24 | path['observations'] 25 | # output: [1, 4] 26 | path['actions'] 27 | # output: [2, 5] 28 | ``` 29 | 30 | Note that the key should be "actions" and not "action" since the 31 | resulting dictionary will have those keys. 32 | """ 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self._path_length = 0 37 | 38 | def add_all(self, **key_to_value): 39 | for k, v in key_to_value.items(): 40 | if k not in self: 41 | self[k] = [v] 42 | else: 43 | self[k].append(v) 44 | self._path_length += 1 45 | 46 | def get_all_stacked(self): 47 | output_dict = dict() 48 | for k, v in self.items(): 49 | output_dict[k] = stack_list(v) 50 | return output_dict 51 | 52 | def __len__(self): 53 | return self._path_length 54 | 55 | 56 | def stack_list(lst): 57 | if isinstance(lst[0], dict): 58 | return lst 59 | else: 60 | return np.array(lst) 61 | -------------------------------------------------------------------------------- /maple/data_management/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayBuffer(object, metaclass=abc.ABCMeta): 5 | """ 6 | A class used to save and replay data. 
7 | """ 8 | 9 | @abc.abstractmethod 10 | def add_sample(self, observation, action, reward, next_observation, 11 | terminal, **kwargs): 12 | """ 13 | Add a transition tuple. 14 | """ 15 | pass 16 | 17 | @abc.abstractmethod 18 | def terminate_episode(self): 19 | """ 20 | Let the replay buffer know that the episode has terminated in case some 21 | special book-keeping has to happen. 22 | :return: 23 | """ 24 | pass 25 | 26 | @abc.abstractmethod 27 | def num_steps_can_sample(self, **kwargs): 28 | """ 29 | :return: # of unique items that can be sampled. 30 | """ 31 | pass 32 | 33 | def add_path(self, path): 34 | """ 35 | Add a path to the replay buffer. 36 | 37 | This default implementation naively goes through every step, but you 38 | may want to optimize this. 39 | 40 | NOTE: You should NOT call "terminate_episode" after calling add_path. 41 | It's assumed that this function handles the episode termination. 42 | 43 | :param path: Dict like one outputted by maple.samplers.util.rollout 44 | """ 45 | for i, ( 46 | obs, 47 | action, 48 | reward, 49 | next_obs, 50 | terminal, 51 | agent_info, 52 | env_info 53 | ) in enumerate(zip( 54 | path["observations"], 55 | path["actions"], 56 | path["rewards"], 57 | path["next_observations"], 58 | path["terminals"], 59 | path["agent_infos"], 60 | path["env_infos"], 61 | )): 62 | self.add_sample( 63 | observation=obs, 64 | action=action, 65 | reward=reward, 66 | next_observation=next_obs, 67 | terminal=terminal, 68 | agent_info=agent_info, 69 | env_info=env_info, 70 | ) 71 | self.terminate_episode() 72 | 73 | def add_paths(self, paths): 74 | for path in paths: 75 | self.add_path(path) 76 | 77 | @abc.abstractmethod 78 | def random_batch(self, batch_size): 79 | """ 80 | Return a batch of size `batch_size`. 81 | :param batch_size: 82 | :return: 83 | """ 84 | pass 85 | 86 | def get_diagnostics(self): 87 | return {} 88 | 89 | def get_snapshot(self): 90 | return {} 91 | 92 | def end_epoch(self, epoch): 93 | return 94 | 95 | -------------------------------------------------------------------------------- /maple/data_management/shared_obs_dict_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from maple.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 4 | 5 | import torch.multiprocessing as mp 6 | import ctypes 7 | 8 | 9 | class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer): 10 | """ 11 | Same as an ObsDictRelabelingBuffer but the obs and next_obs are backed 12 | by multiprocessing arrays. The replay buffer size is also shared. The 13 | intended use case is for if one wants obs/next_obs to be shared between 14 | processes. Accesses are synchronized internally by locks (mp takes care 15 | of that). Technically, putting such large arrays in shared memory/requiring 16 | synchronized access can be extremely slow, but it seems ok empirically. 17 | 18 | This code also breaks a lot of functionality for the subprocess. For example, 19 | random_batch is incorrect as actions and _idx_to_future_obs_idx are not 20 | shared. If the subprocess needs all of the functionality, a mp.Array 21 | must be used for all numpy arrays in the replay buffer. 
22 | 23 | """ 24 | 25 | def __init__( 26 | self, 27 | *args, 28 | **kwargs 29 | ): 30 | self._shared_size = mp.Value(ctypes.c_long, 0) 31 | ObsDictRelabelingBuffer.__init__(self, *args, **kwargs) 32 | 33 | self._mp_array_info = {} 34 | self._shared_obs_info = {} 35 | self._shared_next_obs_info = {} 36 | 37 | for obs_key, obs_arr in self._obs.items(): 38 | ctype = ctypes.c_double 39 | if obs_arr.dtype == np.uint8: 40 | ctype = ctypes.c_uint8 41 | 42 | self._shared_obs_info[obs_key] = ( 43 | mp.Array(ctype, obs_arr.size), 44 | obs_arr.dtype, 45 | obs_arr.shape, 46 | ) 47 | self._shared_next_obs_info[obs_key] = ( 48 | mp.Array(ctype, obs_arr.size), 49 | obs_arr.dtype, 50 | obs_arr.shape, 51 | ) 52 | 53 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key]) 54 | self._next_obs[obs_key] = to_np( 55 | *self._shared_next_obs_info[obs_key]) 56 | self._register_mp_array("_actions") 57 | self._register_mp_array("_terminals") 58 | 59 | def _register_mp_array(self, arr_instance_var_name): 60 | """ 61 | Use this function to register an array to be shared. This will wipe arr. 62 | """ 63 | assert hasattr(self, arr_instance_var_name), arr_instance_var_name 64 | arr = getattr(self, arr_instance_var_name) 65 | 66 | ctype = ctypes.c_double 67 | if arr.dtype == np.uint8: 68 | ctype = ctypes.c_uint8 69 | 70 | self._mp_array_info[arr_instance_var_name] = ( 71 | mp.Array(ctype, arr.size), arr.dtype, arr.shape, 72 | ) 73 | setattr( 74 | self, 75 | arr_instance_var_name, 76 | to_np(*self._mp_array_info[arr_instance_var_name]) 77 | ) 78 | 79 | def init_from_mp_info( 80 | self, 81 | mp_info, 82 | ): 83 | """ 84 | The intended use is to have a subprocess serialize/copy a 85 | SharedObsDictRelabelingBuffer instance and call init_from on the 86 | instance's shared variables. This can't be done during serialization 87 | since multiprocessing shared objects can't be serialized and must be 88 | passed directly to the subprocess as an argument to the fork call. 
89 | """ 90 | shared_obs_info, shared_next_obs_info, mp_array_info, shared_size = mp_info 91 | 92 | self._shared_obs_info = shared_obs_info 93 | self._shared_next_obs_info = shared_next_obs_info 94 | self._mp_array_info = mp_array_info 95 | for obs_key in self._shared_obs_info.keys(): 96 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key]) 97 | self._next_obs[obs_key] = to_np( 98 | *self._shared_next_obs_info[obs_key]) 99 | 100 | for arr_instance_var_name in self._mp_array_info.keys(): 101 | setattr( 102 | self, 103 | arr_instance_var_name, 104 | to_np(*self._mp_array_info[arr_instance_var_name]) 105 | ) 106 | self._shared_size = shared_size 107 | 108 | def get_mp_info(self): 109 | return ( 110 | self._shared_obs_info, 111 | self._shared_next_obs_info, 112 | self._mp_array_info, 113 | self._shared_size, 114 | ) 115 | 116 | @property 117 | def _size(self): 118 | return self._shared_size.value 119 | 120 | @_size.setter 121 | def _size(self, size): 122 | self._shared_size.value = size 123 | 124 | 125 | def to_np(shared_arr, np_dtype, shape): 126 | return np.frombuffer(shared_arr.get_obj(), dtype=np_dtype).reshape(shape) 127 | -------------------------------------------------------------------------------- /maple/data_management/simple_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import warnings 5 | 6 | from maple.data_management.replay_buffer import ReplayBuffer 7 | 8 | 9 | class SimpleReplayBuffer(ReplayBuffer): 10 | 11 | def __init__( 12 | self, 13 | max_replay_buffer_size, 14 | observation_dim, 15 | action_dim, 16 | env_info_sizes, 17 | replace = True, 18 | ): 19 | self._observation_dim = observation_dim 20 | self._action_dim = action_dim 21 | self._max_replay_buffer_size = max_replay_buffer_size 22 | self._observations = np.zeros((max_replay_buffer_size, observation_dim)) 23 | # It's a bit memory inefficient to save the observations twice, 24 | # but it makes the code *much* easier since you no longer have to 25 | # worry about termination conditions. 
26 | self._next_obs = np.zeros((max_replay_buffer_size, observation_dim)) 27 | self._actions = np.zeros((max_replay_buffer_size, action_dim)) 28 | # Make everything a 2D np array to make it easier for other code to 29 | # reason about the shape of the data 30 | self._rewards = np.zeros((max_replay_buffer_size, 1)) 31 | # self._terminals[i] = a terminal was received at time i 32 | self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8') 33 | # Define self._env_infos[key][i] to be the return value of env_info[key] 34 | # at time i 35 | self._env_infos = {} 36 | for key, size in env_info_sizes.items(): 37 | self._env_infos[key] = np.zeros((max_replay_buffer_size, size)) 38 | self._env_info_keys = env_info_sizes.keys() 39 | 40 | self._replace = replace 41 | 42 | self._top = 0 43 | self._size = 0 44 | 45 | def add_sample(self, observation, action, reward, next_observation, 46 | terminal, env_info, **kwargs): 47 | self._observations[self._top] = observation 48 | self._actions[self._top] = action 49 | self._rewards[self._top] = reward 50 | self._terminals[self._top] = terminal 51 | self._next_obs[self._top] = next_observation 52 | 53 | for key in self._env_info_keys: 54 | self._env_infos[key][self._top] = env_info[key] 55 | self._advance() 56 | 57 | def terminate_episode(self): 58 | pass 59 | 60 | def _advance(self): 61 | self._top = (self._top + 1) % self._max_replay_buffer_size 62 | if self._size < self._max_replay_buffer_size: 63 | self._size += 1 64 | 65 | def random_batch(self, batch_size): 66 | indices = np.random.choice(self._size, size=batch_size, replace=self._replace or self._size < batch_size) 67 | if not self._replace and self._size < batch_size: 68 | warnings.warn('Replace was set to false, but is temporarily set to true because batch size is larger than current size of replay.') 69 | batch = dict( 70 | observations=self._observations[indices], 71 | actions=self._actions[indices], 72 | rewards=self._rewards[indices], 73 | terminals=self._terminals[indices], 74 | next_observations=self._next_obs[indices], 75 | ) 76 | for key in self._env_info_keys: 77 | assert key not in batch.keys() 78 | batch[key] = self._env_infos[key][indices] 79 | return batch 80 | 81 | def rebuild_env_info_dict(self, idx): 82 | return { 83 | key: self._env_infos[key][idx] 84 | for key in self._env_info_keys 85 | } 86 | 87 | def batch_env_info_dict(self, indices): 88 | return { 89 | key: self._env_infos[key][indices] 90 | for key in self._env_info_keys 91 | } 92 | 93 | def num_steps_can_sample(self): 94 | return self._size 95 | 96 | def get_diagnostics(self): 97 | return OrderedDict([ 98 | ('size', self._size) 99 | ]) 100 | -------------------------------------------------------------------------------- /maple/data_management/split_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from maple.data_management.replay_buffer import ReplayBuffer 4 | 5 | 6 | class SplitReplayBuffer(ReplayBuffer): 7 | """ 8 | Split the data into a training and validation set. 
9 | """ 10 | def __init__( 11 | self, 12 | train_replay_buffer: ReplayBuffer, 13 | validation_replay_buffer: ReplayBuffer, 14 | fraction_paths_in_train, 15 | ): 16 | self.train_replay_buffer = train_replay_buffer 17 | self.validation_replay_buffer = validation_replay_buffer 18 | self.fraction_paths_in_train = fraction_paths_in_train 19 | self.replay_buffer = self.train_replay_buffer 20 | 21 | def add_sample(self, *args, **kwargs): 22 | self.replay_buffer.add_sample(*args, **kwargs) 23 | 24 | def add_path(self, path): 25 | self.replay_buffer.add_path(path) 26 | self._randomly_set_replay_buffer() 27 | 28 | def num_steps_can_sample(self): 29 | return min( 30 | self.train_replay_buffer.num_steps_can_sample(), 31 | self.validation_replay_buffer.num_steps_can_sample(), 32 | ) 33 | 34 | def terminate_episode(self, *args, **kwargs): 35 | self.replay_buffer.terminate_episode(*args, **kwargs) 36 | self._randomly_set_replay_buffer() 37 | 38 | def _randomly_set_replay_buffer(self): 39 | if random.random() <= self.fraction_paths_in_train: 40 | self.replay_buffer = self.train_replay_buffer 41 | else: 42 | self.replay_buffer = self.validation_replay_buffer 43 | 44 | def get_replay_buffer(self, training=True): 45 | if training: 46 | return self.train_replay_buffer 47 | else: 48 | return self.validation_replay_buffer 49 | 50 | def random_batch(self, batch_size): 51 | return self.train_replay_buffer.random_batch(batch_size) 52 | 53 | def __getattr__(self, attrname): 54 | return getattr(self.replay_buffer, attrname) 55 | 56 | def __getstate__(self): 57 | # Do not save self.replay_buffer since it's a duplicate and seems to 58 | # cause joblib recursion issues. 59 | return dict( 60 | train_replay_buffer=self.train_replay_buffer, 61 | validation_replay_buffer=self.validation_replay_buffer, 62 | fraction_paths_in_train=self.fraction_paths_in_train, 63 | ) 64 | 65 | def __setstate__(self, d): 66 | self.train_replay_buffer = d['train_replay_buffer'] 67 | self.validation_replay_buffer = d['validation_replay_buffer'] 68 | self.fraction_paths_in_train = d['fraction_paths_in_train'] 69 | self.replay_buffer = self.train_replay_buffer 70 | -------------------------------------------------------------------------------- /maple/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/envs/__init__.py -------------------------------------------------------------------------------- /maple/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.spaces import Box, Discrete, Tuple 4 | 5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 6 | 7 | 8 | def get_asset_full_path(file_name): 9 | return os.path.join(ENV_ASSET_DIR, file_name) 10 | 11 | 12 | def get_dim(space): 13 | if isinstance(space, Box): 14 | return space.low.size 15 | elif isinstance(space, Discrete): 16 | return space.n 17 | elif isinstance(space, Tuple): 18 | return sum(get_dim(subspace) for subspace in space.spaces) 19 | elif hasattr(space, 'flat_dim'): 20 | return space.flat_dim 21 | else: 22 | raise TypeError("Unknown space: {}".format(space)) 23 | 24 | 25 | def mode(env, mode_type): 26 | try: 27 | getattr(env, mode_type)() 28 | except AttributeError: 29 | pass 30 | -------------------------------------------------------------------------------- /maple/envs/make_env.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This file provides a more uniform interface to gym.make(env_id) that handles 3 | imports and normalization 4 | """ 5 | 6 | import gym 7 | 8 | from maple.envs.wrappers import NormalizedBoxEnv 9 | 10 | DAPG_ENVS = [ 11 | 'pen-v0', 'pen-sparse-v0', 'pen-notermination-v0', 'pen-binary-v0', 'pen-binary-old-v0', 12 | 'door-v0', 'door-sparse-v0', 'door-binary-v0', 'door-binary-old-v0', 13 | 'relocate-v0', 'relocate-sparse-v0', 'relocate-binary-v0', 'relocate-binary-old-v0', 14 | 'hammer-v0', 'hammer-sparse-v0', 'hammer-binary-v0', 15 | ] 16 | 17 | D4RL_ENVS = [ 18 | "maze2d-open-v0", "maze2d-umaze-v0", "maze2d-medium-v0", "maze2d-large-v0", 19 | "maze2d-open-dense-v0", "maze2d-umaze-dense-v0", "maze2d-medium-dense-v0", "maze2d-large-dense-v0", 20 | "antmaze-umaze-v0", "antmaze-umaze-diverse-v0", "antmaze-medium-diverse-v0", 21 | "antmaze-medium-play-v0", "antmaze-large-diverse-v0", "antmaze-large-play-v0", 22 | "pen-human-v0", "pen-cloned-v0", "pen-expert-v0", "hammer-human-v0", "hammer-cloned-v0", "hammer-expert-v0", 23 | "door-human-v0", "door-cloned-v0", "door-expert-v0", "relocate-human-v0", "relocate-cloned-v0", "relocate-expert-v0", 24 | "halfcheetah-random-v0", "halfcheetah-medium-v0", "halfcheetah-expert-v0", "halfcheetah-mixed-v0", "halfcheetah-medium-expert-v0", 25 | "walker2d-random-v0", "walker2d-medium-v0", "walker2d-expert-v0", "walker2d-mixed-v0", "walker2d-medium-expert-v0", 26 | "hopper-random-v0", "hopper-medium-v0", "hopper-expert-v0", "hopper-mixed-v0", "hopper-medium-expert-v0" 27 | ] 28 | 29 | def make(env_id=None, env_class=None, env_kwargs=None, normalize_env=True): 30 | assert env_id or env_class 31 | if env_class: 32 | env = env_class(**env_kwargs) 33 | elif env_id in DAPG_ENVS: 34 | import mj_envs 35 | assert normalize_env == False 36 | env = gym.make(env_id) 37 | elif env_id in D4RL_ENVS: 38 | import d4rl 39 | assert normalize_env == False 40 | env = gym.make(env_id) 41 | elif env_id: 42 | env = gym.make(env_id) 43 | env = env.env # unwrap TimeLimit 44 | 45 | if normalize_env: 46 | env = NormalizedBoxEnv(env) 47 | 48 | return env 49 | -------------------------------------------------------------------------------- /maple/envs/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from maple.core.serializable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | def __init__( 20 | self, 21 | model_path, 22 | frame_skip=1, 23 | model_path_is_local=True, 24 | automatically_set_obs_and_action_space=False, 25 | ): 26 | if model_path_is_local: 27 | model_path = get_asset_xml(model_path) 28 | if automatically_set_obs_and_action_space: 29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 30 | else: 31 | """ 32 | Code below is copy/pasted from MujocoEnv's __init__ function. 
33 | """ 34 | if model_path.startswith("/"): 35 | fullpath = model_path 36 | else: 37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 38 | if not path.exists(fullpath): 39 | raise IOError("File %s does not exist" % fullpath) 40 | self.frame_skip = frame_skip 41 | self.model = mujoco_py.MjModel(fullpath) 42 | self.data = self.model.data 43 | self.viewer = None 44 | 45 | self.metadata = { 46 | 'render.modes': ['human', 'rgb_array'], 47 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 48 | } 49 | 50 | self.init_qpos = self.model.data.qpos.ravel().copy() 51 | self.init_qvel = self.model.data.qvel.ravel().copy() 52 | self._seed() 53 | 54 | def init_serialization(self, locals): 55 | Serializable.quick_init(self, locals) 56 | 57 | def log_diagnostics(self, paths): 58 | pass 59 | 60 | 61 | def get_asset_xml(xml_name): 62 | return os.path.join(ENV_ASSET_DIR, xml_name) 63 | -------------------------------------------------------------------------------- /maple/envs/mujoco_image_env.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from collections.__init__ import deque 6 | 7 | from gym import Env 8 | from gym.spaces import Box 9 | 10 | from maple.envs.wrappers import ProxyEnv 11 | 12 | 13 | class ImageMujocoEnv(ProxyEnv, Env): 14 | def __init__(self, 15 | wrapped_env, 16 | imsize=32, 17 | keep_prev=0, 18 | init_camera=None, 19 | camera_name=None, 20 | transpose=False, 21 | grayscale=False, 22 | normalize=False, 23 | ): 24 | import mujoco_py 25 | super().__init__(wrapped_env) 26 | 27 | self.imsize = imsize 28 | if grayscale: 29 | self.image_length = self.imsize * self.imsize 30 | else: 31 | self.image_length = 3 * self.imsize * self.imsize 32 | # This is torch format rather than PIL image 33 | self.image_shape = (self.imsize, self.imsize) 34 | # Flattened past image queue 35 | self.history_length = keep_prev + 1 36 | self.history = deque(maxlen=self.history_length) 37 | # init camera 38 | if init_camera is not None: 39 | sim = self._wrapped_env.sim 40 | viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1) 41 | init_camera(viewer.cam) 42 | sim.add_render_context(viewer) 43 | self.camera_name = camera_name # None means default camera 44 | self.transpose = transpose 45 | self.grayscale = grayscale 46 | self.normalize = normalize 47 | self._render_local = False 48 | 49 | self.observation_space = Box(low=0.0, 50 | high=1.0, 51 | shape=( 52 | self.image_length * self.history_length,)) 53 | 54 | def step(self, action): 55 | # image observation get returned as a flattened 1D array 56 | true_state, reward, done, info = super().step(action) 57 | 58 | observation = self._image_observation() 59 | self.history.append(observation) 60 | history = self._get_history().flatten() 61 | full_obs = self._get_obs(history, true_state) 62 | return full_obs, reward, done, info 63 | 64 | def reset(self, **kwargs): 65 | true_state = super().reset(**kwargs) 66 | self.history = deque(maxlen=self.history_length) 67 | 68 | observation = self._image_observation() 69 | self.history.append(observation) 70 | history = self._get_history().flatten() 71 | full_obs = self._get_obs(history, true_state) 72 | return full_obs 73 | 74 | def get_image(self): 75 | """TODO: this should probably consider history""" 76 | return self._image_observation() 77 | 78 | def _get_obs(self, history_flat, true_state): 79 | # adds extra information from true_state into to the image observation. 
80 | # Used in ImageWithObsEnv. 81 | return history_flat 82 | 83 | def _image_observation(self): 84 | # returns the image as a torch format np array 85 | image_obs = self._wrapped_env.sim.render(width=self.imsize, 86 | height=self.imsize, 87 | camera_name=self.camera_name) 88 | if self._render_local: 89 | cv2.imshow('env', image_obs) 90 | cv2.waitKey(1) 91 | if self.grayscale: 92 | image_obs = Image.fromarray(image_obs).convert('L') 93 | image_obs = np.array(image_obs) 94 | if self.normalize: 95 | image_obs = image_obs / 255.0 96 | if self.transpose: 97 | image_obs = image_obs.transpose() 98 | return image_obs 99 | 100 | def _get_history(self): 101 | observations = list(self.history) 102 | 103 | obs_count = len(observations) 104 | for _ in range(self.history_length - obs_count): 105 | dummy = np.zeros(self.image_shape) 106 | observations.append(dummy) 107 | return np.c_[observations] 108 | 109 | def retrieve_images(self): 110 | # returns images in unflattened PIL format 111 | images = [] 112 | for image_obs in self.history: 113 | pil_image = self.torch_to_pil(torch.Tensor(image_obs)) 114 | images.append(pil_image) 115 | return images 116 | 117 | def split_obs(self, obs): 118 | # splits observation into image input and true observation input 119 | imlength = self.image_length * self.history_length 120 | obs_length = self.observation_space.low.size 121 | obs = obs.view(-1, obs_length) 122 | image_obs = obs.narrow(start=0, 123 | length=imlength, 124 | dimension=1) 125 | if obs_length == imlength: 126 | return image_obs, None 127 | 128 | fc_obs = obs.narrow(start=imlength, 129 | length=obs.shape[1] - imlength, 130 | dimension=1) 131 | return image_obs, fc_obs 132 | 133 | def enable_render(self): 134 | self._render_local = True 135 | 136 | 137 | class ImageMujocoWithObsEnv(ImageMujocoEnv): 138 | def __init__(self, env, **kwargs): 139 | super().__init__(env, **kwargs) 140 | self.observation_space = Box(low=0.0, 141 | high=1.0, 142 | shape=( 143 | self.image_length * self.history_length + 144 | self.wrapped_env.obs_dim,)) 145 | 146 | def _get_obs(self, history_flat, true_state): 147 | return np.concatenate([history_flat, 148 | true_state]) -------------------------------------------------------------------------------- /maple/envs/proxy_env.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | 3 | 4 | class ProxyEnv(Env): 5 | def __init__(self, wrapped_env): 6 | self._wrapped_env = wrapped_env 7 | self.action_space = self._wrapped_env.action_space 8 | self.observation_space = self._wrapped_env.observation_space 9 | 10 | @property 11 | def wrapped_env(self): 12 | return self._wrapped_env 13 | 14 | def reset(self, **kwargs): 15 | return self._wrapped_env.reset(**kwargs) 16 | 17 | def step(self, action): 18 | return self._wrapped_env.step(action) 19 | 20 | def render(self, *args, **kwargs): 21 | return self._wrapped_env.render(*args, **kwargs) 22 | 23 | @property 24 | def horizon(self): 25 | return self._wrapped_env.horizon 26 | 27 | def terminate(self): 28 | if hasattr(self.wrapped_env, "terminate"): 29 | self.wrapped_env.terminate() 30 | 31 | def __getattr__(self, attr): 32 | if attr == '_wrapped_env': 33 | raise AttributeError() 34 | return getattr(self._wrapped_env, attr) 35 | 36 | def __getstate__(self): 37 | """ 38 | This is useful to override in case the wrapped env has some funky 39 | __getstate__ that doesn't play well with overriding __getattr__. 40 | 41 | The main problematic case is/was gym's EzPickle serialization scheme. 
42 | :return: 43 | """ 44 | return self.__dict__ 45 | 46 | def __setstate__(self, state): 47 | self.__dict__.update(state) 48 | 49 | def __str__(self): 50 | return '{}({})'.format(type(self).__name__, self.wrapped_env) -------------------------------------------------------------------------------- /maple/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | from gym import Env 4 | from gym.spaces import Box 5 | from gym.spaces import Discrete 6 | 7 | from collections import deque 8 | 9 | 10 | class ProxyEnv(Env): 11 | def __init__(self, wrapped_env): 12 | self._wrapped_env = wrapped_env 13 | self.action_space = self._wrapped_env.action_space 14 | self.observation_space = self._wrapped_env.observation_space 15 | 16 | @property 17 | def wrapped_env(self): 18 | return self._wrapped_env 19 | 20 | def reset(self, **kwargs): 21 | return self._wrapped_env.reset(**kwargs) 22 | 23 | def step(self, action): 24 | return self._wrapped_env.step(action) 25 | 26 | def render(self, *args, **kwargs): 27 | return self._wrapped_env.render(*args, **kwargs) 28 | 29 | @property 30 | def horizon(self): 31 | return self._wrapped_env.horizon 32 | 33 | def terminate(self): 34 | if hasattr(self.wrapped_env, "terminate"): 35 | self.wrapped_env.terminate() 36 | 37 | def __getattr__(self, attr): 38 | if attr == '_wrapped_env': 39 | raise AttributeError() 40 | return getattr(self._wrapped_env, attr) 41 | 42 | def __getstate__(self): 43 | """ 44 | This is useful to override in case the wrapped env has some funky 45 | __getstate__ that doesn't play well with overriding __getattr__. 46 | 47 | The main problematic case is/was gym's EzPickle serialization scheme. 48 | :return: 49 | """ 50 | return self.__dict__ 51 | 52 | def __setstate__(self, state): 53 | self.__dict__.update(state) 54 | 55 | def __str__(self): 56 | return '{}({})'.format(type(self).__name__, self.wrapped_env) 57 | 58 | 59 | class HistoryEnv(ProxyEnv, Env): 60 | def __init__(self, wrapped_env, history_len): 61 | super().__init__(wrapped_env) 62 | self.history_len = history_len 63 | 64 | high = np.inf * np.ones( 65 | self.history_len * self.observation_space.low.size) 66 | low = -high 67 | self.observation_space = Box(low=low, 68 | high=high, 69 | ) 70 | self.history = deque(maxlen=self.history_len) 71 | 72 | def step(self, action): 73 | state, reward, done, info = super().step(action) 74 | self.history.append(state) 75 | flattened_history = self._get_history().flatten() 76 | return flattened_history, reward, done, info 77 | 78 | def reset(self, **kwargs): 79 | state = super().reset() 80 | self.history = deque(maxlen=self.history_len) 81 | self.history.append(state) 82 | flattened_history = self._get_history().flatten() 83 | return flattened_history 84 | 85 | def _get_history(self): 86 | observations = list(self.history) 87 | 88 | obs_count = len(observations) 89 | for _ in range(self.history_len - obs_count): 90 | dummy = np.zeros(self._wrapped_env.observation_space.low.size) 91 | observations.append(dummy) 92 | return np.c_[observations] 93 | 94 | 95 | class DiscretizeEnv(ProxyEnv, Env): 96 | def __init__(self, wrapped_env, num_bins): 97 | super().__init__(wrapped_env) 98 | low = self.wrapped_env.action_space.low 99 | high = self.wrapped_env.action_space.high 100 | action_ranges = [ 101 | np.linspace(low[i], high[i], num_bins) 102 | for i in range(len(low)) 103 | ] 104 | self.idx_to_continuous_action = [ 105 | np.array(x) for x in itertools.product(*action_ranges) 
106 | ] 107 | self.action_space = Discrete(len(self.idx_to_continuous_action)) 108 | 109 | def step(self, action): 110 | continuous_action = self.idx_to_continuous_action[action] 111 | return super().step(continuous_action) 112 | 113 | 114 | class NormalizedBoxEnv(ProxyEnv): 115 | """ 116 | Normalize action to in [-1, 1]. 117 | 118 | Optionally normalize observations and scale reward. 119 | """ 120 | 121 | def __init__( 122 | self, 123 | env, 124 | reward_scale=1., 125 | obs_mean=None, 126 | obs_std=None, 127 | ): 128 | ProxyEnv.__init__(self, env) 129 | self._should_normalize = not (obs_mean is None and obs_std is None) 130 | if self._should_normalize: 131 | if obs_mean is None: 132 | obs_mean = np.zeros_like(env.observation_space.low) 133 | else: 134 | obs_mean = np.array(obs_mean) 135 | if obs_std is None: 136 | obs_std = np.ones_like(env.observation_space.low) 137 | else: 138 | obs_std = np.array(obs_std) 139 | self._reward_scale = reward_scale 140 | self._obs_mean = obs_mean 141 | self._obs_std = obs_std 142 | ub = np.ones(self._wrapped_env.action_space.shape) 143 | self.action_space = Box(-1 * ub, ub) 144 | 145 | def estimate_obs_stats(self, obs_batch, override_values=False): 146 | if self._obs_mean is not None and not override_values: 147 | raise Exception("Observation mean and std already set. To " 148 | "override, set override_values to True.") 149 | self._obs_mean = np.mean(obs_batch, axis=0) 150 | self._obs_std = np.std(obs_batch, axis=0) 151 | 152 | def _apply_normalize_obs(self, obs): 153 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 154 | 155 | def step(self, action): 156 | lb = self._wrapped_env.action_space.low 157 | ub = self._wrapped_env.action_space.high 158 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb) 159 | scaled_action = np.clip(scaled_action, lb, ub) 160 | 161 | wrapped_step = self._wrapped_env.step(scaled_action) 162 | next_obs, reward, done, info = wrapped_step 163 | if self._should_normalize: 164 | next_obs = self._apply_normalize_obs(next_obs) 165 | return next_obs, reward * self._reward_scale, done, info 166 | 167 | def __str__(self): 168 | return "Normalized: %s" % self._wrapped_env 169 | 170 | -------------------------------------------------------------------------------- /maple/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from maple.envs.wrappers.discretize_env import DiscretizeEnv 2 | from maple.envs.wrappers.history_env import HistoryEnv 3 | from maple.envs.wrappers.image_mujoco_env import ImageMujocoEnv 4 | from maple.envs.wrappers.image_mujoco_env_with_obs import ImageMujocoWithObsEnv 5 | from maple.envs.wrappers.normalized_box_env import NormalizedBoxEnv 6 | from maple.envs.proxy_env import ProxyEnv 7 | from maple.envs.wrappers.reward_wrapper_env import RewardWrapperEnv 8 | from maple.envs.wrappers.stack_observation_env import StackObservationEnv 9 | 10 | 11 | __all__ = [ 12 | 'DiscretizeEnv', 13 | 'HistoryEnv', 14 | 'ImageMujocoEnv', 15 | 'ImageMujocoWithObsEnv', 16 | 'NormalizedBoxEnv', 17 | 'ProxyEnv', 18 | 'RewardWrapperEnv', 19 | 'StackObservationEnv', 20 | ] -------------------------------------------------------------------------------- /maple/envs/wrappers/discretize_env.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from gym import Env 5 | from gym.spaces import Discrete 6 | 7 | from maple.envs.proxy_env import ProxyEnv 8 | 9 | 10 | class DiscretizeEnv(ProxyEnv, Env): 
11 | def __init__(self, wrapped_env, num_bins): 12 | super().__init__(wrapped_env) 13 | low = self.wrapped_env.action_space.low 14 | high = self.wrapped_env.action_space.high 15 | action_ranges = [ 16 | np.linspace(low[i], high[i], num_bins) 17 | for i in range(len(low)) 18 | ] 19 | self.idx_to_continuous_action = [ 20 | np.array(x) for x in itertools.product(*action_ranges) 21 | ] 22 | self.action_space = Discrete(len(self.idx_to_continuous_action)) 23 | 24 | def step(self, action): 25 | continuous_action = self.idx_to_continuous_action[action] 26 | return super().step(continuous_action) 27 | 28 | 29 | -------------------------------------------------------------------------------- /maple/envs/wrappers/history_env.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | from gym import Env 5 | from gym.spaces import Box 6 | 7 | from maple.envs.proxy_env import ProxyEnv 8 | 9 | 10 | class HistoryEnv(ProxyEnv, Env): 11 | def __init__(self, wrapped_env, history_len): 12 | super().__init__(wrapped_env) 13 | self.history_len = history_len 14 | 15 | high = np.inf * np.ones( 16 | self.history_len * self.observation_space.low.size) 17 | low = -high 18 | self.observation_space = Box(low=low, 19 | high=high, 20 | ) 21 | self.history = deque(maxlen=self.history_len) 22 | 23 | def step(self, action): 24 | state, reward, done, info = super().step(action) 25 | self.history.append(state) 26 | flattened_history = self._get_history().flatten() 27 | return flattened_history, reward, done, info 28 | 29 | def reset(self, **kwargs): 30 | state = super().reset() 31 | self.history = deque(maxlen=self.history_len) 32 | self.history.append(state) 33 | flattened_history = self._get_history().flatten() 34 | return flattened_history 35 | 36 | def _get_history(self): 37 | observations = list(self.history) 38 | 39 | obs_count = len(observations) 40 | for _ in range(self.history_len - obs_count): 41 | dummy = np.zeros(self._wrapped_env.observation_space.low.size) 42 | observations.append(dummy) 43 | return np.c_[observations] 44 | 45 | 46 | -------------------------------------------------------------------------------- /maple/envs/wrappers/image_mujoco_env.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import mujoco_py 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from gym import Env 8 | from gym.spaces import Box 9 | 10 | from maple.envs.proxy_env import ProxyEnv 11 | 12 | 13 | class ImageMujocoEnv(ProxyEnv, Env): 14 | def __init__(self, 15 | wrapped_env, 16 | imsize=32, 17 | keep_prev=0, 18 | init_camera=None, 19 | camera_name=None, 20 | transpose=False, 21 | grayscale=False, 22 | normalize=False, 23 | ): 24 | super().__init__(wrapped_env) 25 | 26 | self.imsize = imsize 27 | if grayscale: 28 | self.image_length = self.imsize * self.imsize 29 | else: 30 | self.image_length = 3 * self.imsize * self.imsize 31 | # This is torch format rather than PIL image 32 | self.image_shape = (self.imsize, self.imsize) 33 | # Flattened past image queue 34 | self.history_length = keep_prev + 1 35 | self.history = deque(maxlen=self.history_length) 36 | # init camera 37 | if init_camera is not None: 38 | sim = self._wrapped_env.sim 39 | viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1) 40 | init_camera(viewer.cam) 41 | sim.add_render_context(viewer) 42 | self.camera_name = camera_name # None means default camera 43 | 
self.transpose = transpose 44 | self.grayscale = grayscale 45 | self.normalize = normalize 46 | self._render_local = False 47 | 48 | self.observation_space = Box(low=0.0, 49 | high=1.0, 50 | shape=( 51 | self.image_length * self.history_length,)) 52 | 53 | def step(self, action): 54 | # image observation get returned as a flattened 1D array 55 | true_state, reward, done, info = super().step(action) 56 | 57 | observation = self._image_observation() 58 | self.history.append(observation) 59 | history = self._get_history().flatten() 60 | full_obs = self._get_obs(history, true_state) 61 | return full_obs, reward, done, info 62 | 63 | def reset(self, **kwargs): 64 | true_state = super().reset(**kwargs) 65 | self.history = deque(maxlen=self.history_length) 66 | 67 | observation = self._image_observation() 68 | self.history.append(observation) 69 | history = self._get_history().flatten() 70 | full_obs = self._get_obs(history, true_state) 71 | return full_obs 72 | 73 | def get_image(self): 74 | """TODO: this should probably consider history""" 75 | return self._image_observation() 76 | 77 | def _get_obs(self, history_flat, true_state): 78 | # adds extra information from true_state into to the image observation. 79 | # Used in ImageWithObsEnv. 80 | return history_flat 81 | 82 | def _image_observation(self): 83 | # returns the image as a torch format np array 84 | image_obs = self._wrapped_env.sim.render(width=self.imsize, 85 | height=self.imsize, 86 | camera_name=self.camera_name) 87 | if self._render_local: 88 | cv2.imshow('env', image_obs) 89 | cv2.waitKey(1) 90 | if self.grayscale: 91 | image_obs = Image.fromarray(image_obs).convert('L') 92 | image_obs = np.array(image_obs) 93 | if self.normalize: 94 | image_obs = image_obs / 255.0 95 | if self.transpose: 96 | image_obs = image_obs.transpose() 97 | return image_obs 98 | 99 | def _get_history(self): 100 | observations = list(self.history) 101 | 102 | obs_count = len(observations) 103 | for _ in range(self.history_length - obs_count): 104 | dummy = np.zeros(self.image_shape) 105 | observations.append(dummy) 106 | return np.c_[observations] 107 | 108 | def retrieve_images(self): 109 | # returns images in unflattened PIL format 110 | images = [] 111 | for image_obs in self.history: 112 | pil_image = self.torch_to_pil(torch.Tensor(image_obs)) 113 | images.append(pil_image) 114 | return images 115 | 116 | def split_obs(self, obs): 117 | # splits observation into image input and true observation input 118 | imlength = self.image_length * self.history_length 119 | obs_length = self.observation_space.low.size 120 | obs = obs.view(-1, obs_length) 121 | image_obs = obs.narrow(start=0, 122 | length=imlength, 123 | dimension=1) 124 | if obs_length == imlength: 125 | return image_obs, None 126 | 127 | fc_obs = obs.narrow(start=imlength, 128 | length=obs.shape[1] - imlength, 129 | dimension=1) 130 | return image_obs, fc_obs 131 | 132 | def enable_render(self): 133 | self._render_local = True 134 | 135 | 136 | -------------------------------------------------------------------------------- /maple/envs/wrappers/image_mujoco_env_with_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from maple.envs.wrappers.image_mujoco_env import ImageMujocoEnv 5 | 6 | 7 | class ImageMujocoWithObsEnv(ImageMujocoEnv): 8 | def __init__(self, env, **kwargs): 9 | super().__init__(env, **kwargs) 10 | self.observation_space = Box( 11 | low=0.0, 12 | high=1.0, 13 | 
shape=(self.image_length * self.history_length 14 | + self.wrapped_env.obs_dim,)) 15 | 16 | def _get_obs(self, history_flat, true_state): 17 | return np.concatenate([history_flat, true_state]) 18 | -------------------------------------------------------------------------------- /maple/envs/wrappers/normalized_box_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from maple.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class NormalizedBoxEnv(ProxyEnv): 8 | """ 9 | Normalize action to in [-1, 1]. 10 | 11 | Optionally normalize observations and scale reward. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | env, 17 | reward_scale=1., 18 | obs_mean=None, 19 | obs_std=None, 20 | ): 21 | ProxyEnv.__init__(self, env) 22 | self._should_normalize = not (obs_mean is None and obs_std is None) 23 | if self._should_normalize: 24 | if obs_mean is None: 25 | obs_mean = np.zeros_like(env.observation_space.low) 26 | else: 27 | obs_mean = np.array(obs_mean) 28 | if obs_std is None: 29 | obs_std = np.ones_like(env.observation_space.low) 30 | else: 31 | obs_std = np.array(obs_std) 32 | self._reward_scale = reward_scale 33 | self._obs_mean = obs_mean 34 | self._obs_std = obs_std 35 | ub = np.ones(self._wrapped_env.action_space.shape) 36 | self.action_space = Box(-1 * ub, ub) 37 | 38 | def estimate_obs_stats(self, obs_batch, override_values=False): 39 | if self._obs_mean is not None and not override_values: 40 | raise Exception("Observation mean and std already set. To " 41 | "override, set override_values to True.") 42 | self._obs_mean = np.mean(obs_batch, axis=0) 43 | self._obs_std = np.std(obs_batch, axis=0) 44 | 45 | def _apply_normalize_obs(self, obs): 46 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 47 | 48 | def step(self, action): 49 | lb = self._wrapped_env.action_space.low 50 | ub = self._wrapped_env.action_space.high 51 | scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) 52 | scaled_action = np.clip(scaled_action, lb, ub) 53 | 54 | wrapped_step = self._wrapped_env.step(scaled_action) 55 | next_obs, reward, done, info = wrapped_step 56 | if self._should_normalize: 57 | next_obs = self._apply_normalize_obs(next_obs) 58 | return next_obs, reward * self._reward_scale, done, info 59 | 60 | def __str__(self): 61 | return "Normalized: %s" % self._wrapped_env 62 | 63 | 64 | -------------------------------------------------------------------------------- /maple/envs/wrappers/reward_wrapper_env.py: -------------------------------------------------------------------------------- 1 | from maple.envs.proxy_env import ProxyEnv 2 | 3 | 4 | class RewardWrapperEnv(ProxyEnv): 5 | """Substitute a different reward function""" 6 | 7 | def __init__( 8 | self, 9 | env, 10 | compute_reward_fn, 11 | ): 12 | ProxyEnv.__init__(self, env) 13 | self.spec = env.spec # hack for hand envs 14 | self.compute_reward_fn = compute_reward_fn 15 | 16 | def step(self, action): 17 | next_obs, reward, done, info = self._wrapped_env.step(action) 18 | info["env_reward"] = reward 19 | reward = self.compute_reward_fn(next_obs, reward, done, info) 20 | return next_obs, reward, done, info 21 | -------------------------------------------------------------------------------- /maple/envs/wrappers/stack_observation_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from maple.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class StackObservationEnv(ProxyEnv): 8 | """ 9 | Env wrapper for passing history of observations as the new observation 10 | """ 11 | 12 | def __init__( 13 | self, 14 | env, 15 | stack_obs=1, 16 | ): 17 | ProxyEnv.__init__(self, env) 18 | self.stack_obs = stack_obs 19 | low = env.observation_space.low 20 | high = env.observation_space.high 21 | self.obs_dim = low.size 22 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim)) 23 | self.observation_space = Box( 24 | low=np.repeat(low, stack_obs), 25 | high=np.repeat(high, stack_obs), 26 | ) 27 | 28 | def reset(self): 29 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim)) 30 | next_obs = self._wrapped_env.reset() 31 | self._last_obs[-1, :] = next_obs 32 | return self._last_obs.copy().flatten() 33 | 34 | def step(self, action): 35 | next_obs, reward, done, info = self._wrapped_env.step(action) 36 | self._last_obs = np.vstack(( 37 | self._last_obs[1:, :], 38 | next_obs 39 | )) 40 | return self._last_obs.copy().flatten(), reward, done, info 41 | 42 | 43 | -------------------------------------------------------------------------------- /maple/exploration_strategies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/exploration_strategies/__init__.py -------------------------------------------------------------------------------- /maple/exploration_strategies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from maple.policies.base import ExplorationPolicy 4 | 5 | 6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta): 7 | @abc.abstractmethod 8 | def get_action(self, t, observation, policy, **kwargs): 9 | pass 10 | 11 | def reset(self): 12 | pass 13 | 14 | 15 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta): 16 | @abc.abstractmethod 17 | def get_action_from_raw_action(self, action, 
**kwargs): 18 | pass 19 | 20 | def get_action(self, t, policy, *args, **kwargs): 21 | action, agent_info = policy.get_action(*args, **kwargs) 22 | return self.get_action_from_raw_action(action, t=t), agent_info 23 | 24 | def reset(self): 25 | pass 26 | 27 | 28 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy): 29 | def __init__( 30 | self, 31 | exploration_strategy: ExplorationStrategy, 32 | policy, 33 | ): 34 | self.es = exploration_strategy 35 | self.policy = policy 36 | self.t = 0 37 | 38 | def set_num_steps_total(self, t): 39 | self.t = t 40 | 41 | def get_action(self, *args, **kwargs): 42 | return self.es.get_action(self.t, self.policy, *args, **kwargs) 43 | 44 | def reset(self): 45 | self.es.reset() 46 | self.policy.reset() 47 | -------------------------------------------------------------------------------- /maple/exploration_strategies/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from maple.exploration_strategies.base import RawExplorationStrategy 4 | 5 | 6 | class EpsilonGreedy(RawExplorationStrategy): 7 | """ 8 | Take a random discrete action with some probability. 9 | """ 10 | def __init__(self, action_space, prob_random_action=0.1): 11 | self.prob_random_action = prob_random_action 12 | self.action_space = action_space 13 | 14 | def get_action_from_raw_action(self, action, **kwargs): 15 | if random.random() <= self.prob_random_action: 16 | return self.action_space.sample() 17 | return action 18 | -------------------------------------------------------------------------------- /maple/exploration_strategies/gaussian_and_epsilon_strategy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from maple.exploration_strategies.base import RawExplorationStrategy 3 | import numpy as np 4 | 5 | 6 | class GaussianAndEpsilonStrategy(RawExplorationStrategy): 7 | """ 8 | With probability epsilon, take a completely random action. 9 | with probability 1-epsilon, add Gaussian noise to the action taken by a 10 | deterministic policy. 11 | """ 12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None, 13 | decay_period=1000000): 14 | assert len(action_space.shape) == 1 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._max_sigma = max_sigma 18 | self._epsilon = epsilon 19 | self._min_sigma = min_sigma 20 | self._decay_period = decay_period 21 | self._action_space = action_space 22 | 23 | def get_action_from_raw_action(self, action, t=None, **kwargs): 24 | if random.random() < self._epsilon: 25 | return self._action_space.sample() 26 | else: 27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period) 28 | return np.clip( 29 | action + np.random.normal(size=len(action)) * sigma, 30 | self._action_space.low, 31 | self._action_space.high, 32 | ) 33 | -------------------------------------------------------------------------------- /maple/exploration_strategies/gaussian_strategy.py: -------------------------------------------------------------------------------- 1 | from maple.exploration_strategies.base import RawExplorationStrategy 2 | import numpy as np 3 | 4 | 5 | class GaussianStrategy(RawExplorationStrategy): 6 | """ 7 | This strategy adds Gaussian noise to the action taken by the deterministic policy. 8 | 9 | Based on the rllab implementation. 
10 | """ 11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None, 12 | decay_period=1000000): 13 | assert len(action_space.shape) == 1 14 | self._max_sigma = max_sigma 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._min_sigma = min_sigma 18 | self._decay_period = decay_period 19 | self._action_space = action_space 20 | 21 | def get_action_from_raw_action(self, action, t=None, **kwargs): 22 | sigma = ( 23 | self._max_sigma - (self._max_sigma - self._min_sigma) * 24 | min(1.0, t * 1.0 / self._decay_period) 25 | ) 26 | return np.clip( 27 | action + np.random.normal(size=len(action)) * sigma, 28 | self._action_space.low, 29 | self._action_space.high, 30 | ) 31 | -------------------------------------------------------------------------------- /maple/exploration_strategies/ou_strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as nr 3 | 4 | from maple.exploration_strategies.base import RawExplorationStrategy 5 | 6 | 7 | class OUStrategy(RawExplorationStrategy): 8 | """ 9 | This strategy implements the Ornstein-Uhlenbeck process, which adds 10 | time-correlated noise to the actions taken by the deterministic policy. 11 | The OU process satisfies the following stochastic differential equation: 12 | dxt = theta*(mu - xt)*dt + sigma*dWt 13 | where Wt denotes the Wiener process 14 | 15 | Based on the rllab implementation. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | action_space, 21 | mu=0, 22 | theta=0.15, 23 | max_sigma=0.3, 24 | min_sigma=None, 25 | decay_period=100000, 26 | ): 27 | if min_sigma is None: 28 | min_sigma = max_sigma 29 | self.mu = mu 30 | self.theta = theta 31 | self.sigma = max_sigma 32 | self._max_sigma = max_sigma 33 | if min_sigma is None: 34 | min_sigma = max_sigma 35 | self._min_sigma = min_sigma 36 | self._decay_period = decay_period 37 | self.dim = np.prod(action_space.low.shape) 38 | self.low = action_space.low 39 | self.high = action_space.high 40 | self.state = np.ones(self.dim) * self.mu 41 | self.reset() 42 | 43 | def reset(self): 44 | self.state = np.ones(self.dim) * self.mu 45 | 46 | def evolve_state(self): 47 | x = self.state 48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x)) 49 | self.state = x + dx 50 | return self.state 51 | 52 | def get_action_from_raw_action(self, action, t=0, **kwargs): 53 | ou_state = self.evolve_state() 54 | self.sigma = ( 55 | self._max_sigma 56 | - (self._max_sigma - self._min_sigma) 57 | * min(1.0, t * 1.0 / self._decay_period) 58 | ) 59 | return np.clip(action + ou_state, self.low, self.high) 60 | -------------------------------------------------------------------------------- /maple/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains 'launchers', which are self-contained functions that take 3 | one dictionary and run a full experiment. The dictionary configures the 4 | experiment. 5 | 6 | It is important that the functions are completely self-contained (i.e. they 7 | import their own modules) so that they can be serialized. 8 | """ 9 | -------------------------------------------------------------------------------- /maple/launchers/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy this file to config.py and modify as needed. 
3 | """ 4 | import os 5 | from os.path import join 6 | import maple 7 | import robosuite 8 | 9 | """ 10 | `doodad.mount.MountLocal` by default ignores directories called "data" 11 | If you're going to rename this directory and use EC2, then change 12 | `doodad.mount.MountLocal.filter_dir` 13 | """ 14 | # The directory of the project, not source 15 | maple_project_dir = join(os.path.dirname(maple.__file__), os.pardir) 16 | robosuite_project_dir = join(os.path.dirname(robosuite.__file__), os.pardir) 17 | LOCAL_LOG_DIR = join(maple_project_dir, 'data') 18 | 19 | """ 20 | ******************************************************************************** 21 | ******************************************************************************** 22 | ******************************************************************************** 23 | 24 | You probably don't need to set all of the configurations below this line, 25 | unless you use AWS, GCP, Slurm, and/or Slurm on a remote server. I recommend 26 | ignoring most of these things and only using them on an as-needed basis. 27 | 28 | ******************************************************************************** 29 | ******************************************************************************** 30 | ******************************************************************************** 31 | """ 32 | 33 | """ 34 | General doodad settings. 35 | """ 36 | CODE_DIRS_TO_MOUNT = [ 37 | maple_project_dir, 38 | robosuite_project_dir, 39 | # '/home/user/python/module/one', Add more paths as needed 40 | ] 41 | 42 | HOME = os.getenv('HOME') if os.getenv('HOME') is not None else os.getenv("USERPROFILE") 43 | 44 | DIR_AND_MOUNT_POINT_MAPPINGS = [ 45 | dict( 46 | local_dir=join(HOME, '.mujoco/'), 47 | mount_point='/root/.mujoco', 48 | ), 49 | ] 50 | RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 51 | join(maple_project_dir, 'scripts', 'run_experiment_from_doodad.py') 52 | # '/home/user/path/to/maple/scripts/run_experiment_from_doodad.py' 53 | ) 54 | """ 55 | AWS Settings 56 | """ 57 | # If not set, default will be chosen by doodad 58 | # AWS_S3_PATH = 's3://bucket/directory 59 | 60 | # The docker image is looked up on dockerhub.com. 61 | DOODAD_DOCKER_IMAGE = "TODO" 62 | INSTANCE_TYPE = 'c4.large' 63 | SPOT_PRICE = 0.03 64 | 65 | GPU_DOODAD_DOCKER_IMAGE = "TODO" 66 | GPU_INSTANCE_TYPE = 'g2.2xlarge' 67 | GPU_SPOT_PRICE = 0.5 68 | 69 | # You can use AMI images with the docker images already installed. 70 | REGION_TO_GPU_AWS_IMAGE_ID = { 71 | 'us-west-1': "TODO", 72 | 'us-east-1': "TODO", 73 | } 74 | 75 | REGION_TO_GPU_AWS_AVAIL_ZONE = { 76 | 'us-east-1': "us-east-1b", 77 | } 78 | 79 | # This really shouldn't matter and in theory could be whatever 80 | OUTPUT_DIR_FOR_DOODAD_TARGET = '/tmp/doodad-output/' 81 | 82 | 83 | """ 84 | Slurm Settings 85 | """ 86 | SINGULARITY_IMAGE = '/home/PATH/TO/IMAGE.img' 87 | # This assumes you saved mujoco to $HOME/.mujoco 88 | SINGULARITY_PRE_CMDS = [ 89 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mjpro150/bin' 90 | ] 91 | SLURM_CPU_CONFIG = dict( 92 | account_name='TODO', 93 | partition='savio', 94 | nodes=1, 95 | n_tasks=1, 96 | n_gpus=1, 97 | ) 98 | SLURM_GPU_CONFIG = dict( 99 | account_name='TODO', 100 | partition='savio2_1080ti', 101 | nodes=1, 102 | n_tasks=1, 103 | n_gpus=1, 104 | ) 105 | 106 | 107 | """ 108 | Slurm Script Settings 109 | 110 | These are basically the same settings as above, but for the remote machine 111 | where you will be running the generated script. 
112 | """ 113 | SSS_CODE_DIRS_TO_MOUNT = [ 114 | ] 115 | SSS_DIR_AND_MOUNT_POINT_MAPPINGS = [ 116 | dict( 117 | local_dir='/global/home/users/USERNAME/.mujoco', 118 | mount_point='/root/.mujoco', 119 | ), 120 | ] 121 | SSS_LOG_DIR = '/global/scratch/USERNAME/doodad-log' 122 | 123 | SSS_IMAGE = '/global/scratch/USERNAME/TODO.img' 124 | SSS_RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 125 | '/global/home/users/USERNAME/path/to/maple/scripts' 126 | '/run_experiment_from_doodad.py' 127 | ) 128 | SSS_PRE_CMDS = [ 129 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/global/home/users/USERNAME' 130 | '/.mujoco/mjpro150/bin' 131 | ] 132 | 133 | """ 134 | GCP Settings 135 | """ 136 | GCP_IMAGE_NAME = 'TODO' 137 | GCP_GPU_IMAGE_NAME = 'TODO' 138 | GCP_BUCKET_NAME = 'TODO' 139 | 140 | GCP_DEFAULT_KWARGS = dict( 141 | zone='us-west2-c', 142 | instance_type='n1-standard-4', 143 | image_project='TODO', 144 | terminate=True, 145 | preemptible=True, 146 | gpu_kwargs=dict( 147 | gpu_model='nvidia-tesla-p4', 148 | num_gpu=1, 149 | ) 150 | ) 151 | 152 | try: 153 | from maple.launchers.conf_private import * 154 | except ImportError: 155 | print("No personal conf_private.py found.") 156 | -------------------------------------------------------------------------------- /maple/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/policies/__init__.py -------------------------------------------------------------------------------- /maple/policies/argmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch argmax policy 3 | """ 4 | import numpy as np 5 | from torch import nn 6 | 7 | import maple.torch.pytorch_util as ptu 8 | from maple.policies.base import Policy 9 | 10 | 11 | class ArgmaxDiscretePolicy(nn.Module, Policy): 12 | def __init__(self, qf): 13 | super().__init__() 14 | self.qf = qf 15 | 16 | def get_action(self, obs): 17 | obs = np.expand_dims(obs, axis=0) 18 | obs = ptu.from_numpy(obs).float() 19 | q_values = self.qf(obs).squeeze(0) 20 | q_values_np = ptu.get_numpy(q_values) 21 | return q_values_np.argmax(), {} 22 | -------------------------------------------------------------------------------- /maple/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(object, metaclass=abc.ABCMeta): 5 | """ 6 | General policy interface. 7 | """ 8 | @abc.abstractmethod 9 | def get_action(self, observation): 10 | """ 11 | 12 | :param observation: 13 | :return: action, debug_dictionary 14 | """ 15 | pass 16 | 17 | def reset(self): 18 | pass 19 | 20 | 21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 22 | def set_num_steps_total(self, t): 23 | pass 24 | -------------------------------------------------------------------------------- /maple/policies/simple.py: -------------------------------------------------------------------------------- 1 | from maple.policies.base import Policy 2 | 3 | 4 | class RandomPolicy(Policy): 5 | """ 6 | Policy that always outputs zero. 
7 | """ 8 | 9 | def __init__(self, action_space): 10 | self.action_space = action_space 11 | 12 | def get_action(self, obs): 13 | return self.action_space.sample(), {} 14 | -------------------------------------------------------------------------------- /maple/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/samplers/__init__.py -------------------------------------------------------------------------------- /maple/samplers/data_collector/__init__.py: -------------------------------------------------------------------------------- 1 | from maple.samplers.data_collector.base import ( 2 | DataCollector, 3 | PathCollector, 4 | StepCollector, 5 | ) 6 | from maple.samplers.data_collector.path_collector import ( 7 | MdpPathCollector, 8 | ObsDictPathCollector, 9 | GoalConditionedPathCollector, 10 | VAEWrappedEnvPathCollector, 11 | ) 12 | from maple.samplers.data_collector.step_collector import ( 13 | GoalConditionedStepCollector 14 | ) 15 | -------------------------------------------------------------------------------- /maple/samplers/data_collector/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class DataCollector(object, metaclass=abc.ABCMeta): 5 | def end_epoch(self, epoch): 6 | pass 7 | 8 | def get_diagnostics(self): 9 | return {} 10 | 11 | def get_snapshot(self): 12 | return {} 13 | 14 | @abc.abstractmethod 15 | def get_epoch_paths(self): 16 | pass 17 | 18 | 19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def collect_new_paths( 22 | self, 23 | max_path_length, 24 | num_steps, 25 | discard_incomplete_paths, 26 | ): 27 | pass 28 | 29 | 30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta): 31 | @abc.abstractmethod 32 | def collect_new_steps( 33 | self, 34 | max_path_length, 35 | num_steps, 36 | discard_incomplete_paths, 37 | ): 38 | pass 39 | -------------------------------------------------------------------------------- /maple/samplers/data_collector/contextual_path_collector.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from maple.envs.contextual import ContextualEnv 4 | from maple.policies.base import Policy 5 | from maple.samplers.data_collector import MdpPathCollector 6 | from maple.samplers.rollout_functions import contextual_rollout 7 | 8 | 9 | class ContextualPathCollector(MdpPathCollector): 10 | def __init__( 11 | self, 12 | env: ContextualEnv, 13 | policy: Policy, 14 | max_num_epoch_paths_saved=None, 15 | observation_key='observation', 16 | context_keys_for_policy='context', 17 | render=False, 18 | render_kwargs=None, 19 | **kwargs 20 | ): 21 | rollout_fn = partial( 22 | contextual_rollout, 23 | context_keys_for_policy=context_keys_for_policy, 24 | observation_key=observation_key, 25 | ) 26 | super().__init__( 27 | env, policy, max_num_epoch_paths_saved, render, render_kwargs, 28 | rollout_fn=rollout_fn, 29 | **kwargs 30 | ) 31 | self._observation_key = observation_key 32 | self._context_keys_for_policy = context_keys_for_policy 33 | 34 | def get_snapshot(self): 35 | snapshot = super().get_snapshot() 36 | snapshot.update( 37 | observation_key=self._observation_key, 38 | context_keys_for_policy=self._context_keys_for_policy, 39 | ) 40 | return snapshot 41 | 
-------------------------------------------------------------------------------- /maple/samplers/data_collector/joint_path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict 3 | 4 | from maple.core.logging import add_prefix 5 | from maple.samplers.data_collector import PathCollector 6 | 7 | 8 | class JointPathCollector(PathCollector): 9 | def __init__(self, path_collectors: Dict[str, PathCollector]): 10 | self.path_collectors = path_collectors 11 | 12 | def collect_new_paths(self, max_path_length, num_steps, 13 | discard_incomplete_paths): 14 | paths = [] 15 | for collector in self.path_collectors.values(): 16 | collector.collect_new_paths( 17 | max_path_length, num_steps, discard_incomplete_paths 18 | ) 19 | return paths 20 | 21 | def end_epoch(self, epoch): 22 | for collector in self.path_collectors.values(): 23 | collector.end_epoch(epoch) 24 | 25 | def get_diagnostics(self): 26 | diagnostics = OrderedDict() 27 | for name, collector in self.path_collectors.items(): 28 | diagnostics.update( 29 | add_prefix(collector.get_diagnostics(), name, divider='/'), 30 | ) 31 | return diagnostics 32 | 33 | def get_snapshot(self): 34 | snapshot = {} 35 | for name, collector in self.path_collectors.items(): 36 | snapshot.update( 37 | add_prefix(collector.get_snapshot(), name, divider='/'), 38 | ) 39 | return snapshot 40 | 41 | def get_epoch_paths(self): 42 | paths = {} 43 | for name, collector in self.path_collectors.items(): 44 | paths[name] = collector.get_epoch_paths() 45 | return paths 46 | 47 | -------------------------------------------------------------------------------- /maple/samplers/data_collector/path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import deque, OrderedDict 2 | from functools import partial 3 | 4 | import numpy as np 5 | 6 | from maple.core.eval_util import create_stats_ordered_dict 7 | from maple.samplers.data_collector.base import PathCollector 8 | from maple.samplers.rollout_functions import rollout 9 | 10 | 11 | class MdpPathCollector(PathCollector): 12 | def __init__( 13 | self, 14 | env, 15 | policy, 16 | max_num_epoch_paths_saved=None, 17 | render=False, 18 | render_kwargs=None, 19 | rollout_fn=rollout, 20 | save_env_in_snapshot=True, 21 | rollout_fn_kwargs=None, 22 | ): 23 | if render_kwargs is None: 24 | render_kwargs = {} 25 | self._env = env 26 | self._policy = policy 27 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 28 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 29 | self._render = render 30 | self._render_kwargs = render_kwargs 31 | self._rollout_fn = rollout_fn 32 | if rollout_fn_kwargs is None: 33 | rollout_fn_kwargs = {} 34 | self._rollout_fn_kwargs = rollout_fn_kwargs 35 | 36 | self._num_steps_total = 0 37 | self._num_paths_total = 0 38 | 39 | self._num_actions_total = 0 40 | 41 | self._save_env_in_snapshot = save_env_in_snapshot 42 | 43 | def collect_new_paths( 44 | self, 45 | max_path_length, 46 | num_steps, 47 | discard_incomplete_paths, 48 | ): 49 | paths = [] 50 | num_steps_collected = 0 51 | num_actions_collected = 0 52 | while num_steps_collected < num_steps: 53 | max_path_length_this_loop = min( # Do not go over num_steps 54 | max_path_length, 55 | num_steps - num_steps_collected, 56 | ) 57 | if discard_incomplete_paths and (max_path_length_this_loop < max_path_length): 58 | break 59 | path = self._rollout_fn( 60 | self._env, 61 | 
self._policy, 62 | max_path_length=max_path_length_this_loop, 63 | render=self._render, 64 | render_kwargs=self._render_kwargs, 65 | **self._rollout_fn_kwargs 66 | ) 67 | num_steps_collected += path['path_length'] 68 | num_actions_collected += path['path_length_actions'] 69 | paths.append(path) 70 | self._num_paths_total += len(paths) 71 | self._num_steps_total += num_steps_collected 72 | self._num_actions_total += num_actions_collected 73 | self._epoch_paths.extend(paths) 74 | return paths 75 | 76 | def get_epoch_paths(self): 77 | return self._epoch_paths 78 | 79 | def end_epoch(self, epoch): 80 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 81 | 82 | def get_diagnostics(self): 83 | # path_lens = [len(path['actions']) for path in self._epoch_paths] 84 | path_lens = [path['path_length'] for path in self._epoch_paths] 85 | stats = OrderedDict([ 86 | ('num steps total', self._num_steps_total), 87 | ('num paths total', self._num_paths_total), 88 | ('num actions total', self._num_actions_total), 89 | ]) 90 | stats.update(create_stats_ordered_dict( 91 | "path length", 92 | path_lens, 93 | always_show_all_stats=True, 94 | )) 95 | return stats 96 | 97 | def get_snapshot(self): 98 | snapshot_dict = dict( 99 | policy=self._policy, 100 | ) 101 | if self._save_env_in_snapshot: 102 | snapshot_dict['env'] = self._env 103 | return snapshot_dict 104 | 105 | 106 | class GoalConditionedPathCollector(MdpPathCollector): 107 | def __init__( 108 | self, 109 | *args, 110 | observation_key='observation', 111 | desired_goal_key='desired_goal', 112 | goal_sampling_mode=None, 113 | **kwargs 114 | ): 115 | def obs_processor(o): 116 | return np.hstack((o[observation_key], o[desired_goal_key])) 117 | 118 | rollout_fn = partial( 119 | rollout, 120 | preprocess_obs_for_policy_fn=obs_processor, 121 | ) 122 | super().__init__(*args, rollout_fn=rollout_fn, **kwargs) 123 | self._observation_key = observation_key 124 | self._desired_goal_key = desired_goal_key 125 | self._goal_sampling_mode = goal_sampling_mode 126 | 127 | def collect_new_paths(self, *args, **kwargs): 128 | self._env.goal_sampling_mode = self._goal_sampling_mode 129 | return super().collect_new_paths(*args, **kwargs) 130 | 131 | def get_snapshot(self): 132 | snapshot = super().get_snapshot() 133 | snapshot.update( 134 | observation_key=self._observation_key, 135 | desired_goal_key=self._desired_goal_key, 136 | ) 137 | return snapshot 138 | 139 | 140 | class ObsDictPathCollector(MdpPathCollector): 141 | def __init__( 142 | self, 143 | *args, 144 | observation_key='observation', 145 | **kwargs 146 | ): 147 | def obs_processor(obs): 148 | return obs[observation_key] 149 | 150 | rollout_fn = partial( 151 | rollout, 152 | preprocess_obs_for_policy_fn=obs_processor, 153 | ) 154 | super().__init__(*args, rollout_fn=rollout_fn, **kwargs) 155 | self._observation_key = observation_key 156 | 157 | def get_snapshot(self): 158 | snapshot = super().get_snapshot() 159 | snapshot.update( 160 | observation_key=self._observation_key, 161 | ) 162 | return snapshot 163 | 164 | 165 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector): 166 | def __init__( 167 | self, 168 | env, 169 | policy, 170 | decode_goals=False, 171 | **kwargs 172 | ): 173 | """Expects env is VAEWrappedEnv""" 174 | super().__init__(env, policy, **kwargs) 175 | self._decode_goals = decode_goals 176 | 177 | def collect_new_paths(self, *args, **kwargs): 178 | self._env.decode_goals = self._decode_goals 179 | return super().collect_new_paths(*args, **kwargs) 180 | 
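# ---- Illustrative usage sketch (not part of the repository) ----
# Typical lifecycle of an MdpPathCollector inside a training loop: collect a
# batch of paths each epoch, read diagnostics, then reset the per-epoch path
# cache. `skill_env` and `policy` are placeholder names: the default rollout
# function (see rollout_functions.py below) expects a maple skill-based
# environment whose step() accepts an image_obs_in_info flag and whose
# wrapped env exposes a skill_controller.
from maple.samplers.data_collector.path_collector import MdpPathCollector

collector = MdpPathCollector(skill_env, policy)
for epoch in range(5):
    paths = collector.collect_new_paths(
        max_path_length=150,
        num_steps=600,
        discard_incomplete_paths=False,
    )
    # each path is the dict returned by rollout(): observations, actions,
    # rewards, terminals, env_infos, path_length, skill_names, ...
    diagnostics = collector.get_diagnostics()
    print(epoch, diagnostics['num steps total'], diagnostics['num paths total'])
    collector.end_epoch(epoch)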
-------------------------------------------------------------------------------- /maple/samplers/data_collector/vae_env.py: -------------------------------------------------------------------------------- 1 | from maple.envs.vae_wrapper import VAEWrappedEnv 2 | from maple.samplers.data_collector import GoalConditionedPathCollector 3 | 4 | 5 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector): 6 | def __init__( 7 | self, 8 | goal_sampling_mode, 9 | env: VAEWrappedEnv, 10 | policy, 11 | decode_goals=False, 12 | **kwargs 13 | ): 14 | super().__init__(env, policy, **kwargs) 15 | self._goal_sampling_mode = goal_sampling_mode 16 | self._decode_goals = decode_goals 17 | 18 | def collect_new_paths(self, *args, **kwargs): 19 | self._env.goal_sampling_mode = self._goal_sampling_mode 20 | self._env.decode_goals = self._decode_goals 21 | return super().collect_new_paths(*args, **kwargs) -------------------------------------------------------------------------------- /maple/samplers/in_place.py: -------------------------------------------------------------------------------- 1 | from maple.samplers.util import rollout 2 | 3 | 4 | class InPlacePathSampler(object): 5 | """ 6 | A sampler that does not serialization for sampling. Instead, it just uses 7 | the current policy and environment as-is. 8 | 9 | WARNING: This will affect the environment! So 10 | ``` 11 | sampler = InPlacePathSampler(env, ...) 12 | sampler.obtain_samples # this has side-effects: env will change! 13 | ``` 14 | """ 15 | def __init__(self, env, policy, max_samples, max_path_length, render=False): 16 | self.env = env 17 | self.policy = policy 18 | self.max_path_length = max_path_length 19 | self.max_samples = max_samples 20 | self.render = render 21 | assert max_samples >= max_path_length, "Need max_samples >= max_path_length" 22 | 23 | def start_worker(self): 24 | pass 25 | 26 | def shutdown_worker(self): 27 | pass 28 | 29 | def obtain_samples(self): 30 | paths = [] 31 | n_steps_total = 0 32 | while n_steps_total + self.max_path_length <= self.max_samples: 33 | path = rollout( 34 | self.env, self.policy, max_path_length=self.max_path_length, 35 | animated=self.render 36 | ) 37 | paths.append(path) 38 | n_steps_total += len(path['observations']) 39 | return paths 40 | -------------------------------------------------------------------------------- /maple/samplers/rollout_functions.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import copy 5 | 6 | create_rollout_function = partial 7 | 8 | 9 | def multitask_rollout( 10 | env, 11 | agent, 12 | max_path_length=np.inf, 13 | render=False, 14 | render_kwargs=None, 15 | observation_key=None, 16 | desired_goal_key=None, 17 | get_action_kwargs=None, 18 | return_dict_obs=False, 19 | full_o_postprocess_func=None, 20 | ): 21 | if full_o_postprocess_func: 22 | def wrapped_fun(env, agent, o): 23 | full_o_postprocess_func(env, agent, observation_key, o) 24 | else: 25 | wrapped_fun = None 26 | 27 | def obs_processor(o): 28 | return np.hstack((o[observation_key], o[desired_goal_key])) 29 | 30 | paths = rollout( 31 | env, 32 | agent, 33 | max_path_length=max_path_length, 34 | render=render, 35 | render_kwargs=render_kwargs, 36 | get_action_kwargs=get_action_kwargs, 37 | preprocess_obs_for_policy_fn=obs_processor, 38 | full_o_postprocess_func=wrapped_fun, 39 | ) 40 | if not return_dict_obs: 41 | paths['observations'] = paths['observations'][observation_key] 42 | return paths 43 | 44 | 45 | 
def contextual_rollout( 46 | env, 47 | agent, 48 | observation_key=None, 49 | context_keys_for_policy=None, 50 | obs_processor=None, 51 | **kwargs 52 | ): 53 | if context_keys_for_policy is None: 54 | context_keys_for_policy = ['context'] 55 | 56 | if not obs_processor: 57 | def obs_processor(o): 58 | combined_obs = [o[observation_key]] 59 | for k in context_keys_for_policy: 60 | combined_obs.append(o[k]) 61 | return np.concatenate(combined_obs, axis=0) 62 | paths = rollout( 63 | env, 64 | agent, 65 | preprocess_obs_for_policy_fn=obs_processor, 66 | **kwargs 67 | ) 68 | return paths 69 | 70 | 71 | def rollout( 72 | env, 73 | agent, 74 | max_path_length=np.inf, 75 | render=False, 76 | render_kwargs=None, 77 | preprocess_obs_for_policy_fn=None, 78 | get_action_kwargs=None, 79 | return_dict_obs=False, 80 | full_o_postprocess_func=None, 81 | reset_callback=None, 82 | addl_info_func=None, 83 | image_obs_in_info=False, 84 | last_step_is_terminal=False, 85 | terminals_all_false=False, 86 | ): 87 | if render_kwargs is None: 88 | render_kwargs = {} 89 | if get_action_kwargs is None: 90 | get_action_kwargs = {} 91 | if preprocess_obs_for_policy_fn is None: 92 | preprocess_obs_for_policy_fn = lambda env, agent, o: o 93 | raw_obs = [] 94 | raw_next_obs = [] 95 | observations = [] 96 | actions = [] 97 | rewards = [] 98 | terminals = [] 99 | agent_infos = [] 100 | env_infos = [] 101 | next_observations = [] 102 | addl_infos = [] 103 | path_length = 0 104 | agent.reset() 105 | o = env.reset() 106 | if reset_callback: 107 | reset_callback(env, agent, o) 108 | if render: 109 | env.render(**render_kwargs) 110 | while path_length < max_path_length: 111 | raw_obs.append(o) 112 | o_for_agent = preprocess_obs_for_policy_fn(env, agent, o) 113 | a, agent_info = agent.get_action(o_for_agent, **get_action_kwargs) 114 | 115 | if full_o_postprocess_func: 116 | full_o_postprocess_func(env, agent, o) 117 | 118 | if addl_info_func: 119 | addl_infos.append(addl_info_func(env, agent, o, a)) 120 | 121 | next_o, r, d, env_info = env.step(copy.deepcopy(a), image_obs_in_info=image_obs_in_info) 122 | 123 | new_path_length = path_length + env_info.get('num_ac_calls', 1) 124 | 125 | if new_path_length > max_path_length: 126 | break 127 | path_length = new_path_length 128 | 129 | if render: 130 | env.render(**render_kwargs) 131 | observations.append(o) 132 | rewards.append(r) 133 | if terminals_all_false: 134 | terminals.append(False) 135 | else: 136 | terminals.append(d) 137 | actions.append(a) 138 | next_observations.append(next_o) 139 | raw_next_obs.append(next_o) 140 | agent_infos.append(agent_info) 141 | env_infos.append(env_info) 142 | if d: 143 | break 144 | o = next_o 145 | actions = np.array(actions) 146 | if len(actions.shape) == 1: 147 | actions = np.expand_dims(actions, 1) 148 | observations = np.array(observations) 149 | next_observations = np.array(next_observations) 150 | if return_dict_obs: 151 | observations = raw_obs 152 | next_observations = raw_next_obs 153 | rewards = np.array(rewards) 154 | if len(rewards.shape) == 1: 155 | rewards = rewards.reshape(-1, 1) 156 | 157 | path_length_actions = np.sum( 158 | [info.get('num_ac_calls', 1) for info in env_infos] 159 | ) 160 | 161 | reward_actions_sum = np.sum( 162 | [info.get('reward_actions', 0) for info in env_infos] 163 | ) 164 | 165 | if last_step_is_terminal: 166 | terminals[-1] = True 167 | 168 | skill_names = [] 169 | sc = env.env.skill_controller 170 | for i in range(len(actions)): 171 | ac = actions[i] 172 | skill_name = sc.get_skill_name_from_action(ac) 
173 | skill_names.append(skill_name) 174 | success = env_infos[i].get('success', False) 175 | if success: 176 | break 177 | 178 | return dict( 179 | observations=observations, 180 | actions=actions, 181 | rewards=rewards, 182 | next_observations=next_observations, 183 | terminals=np.array(terminals).reshape(-1, 1), 184 | agent_infos=agent_infos, 185 | env_infos=env_infos, 186 | addl_infos=addl_infos, 187 | full_observations=raw_obs, 188 | full_next_observations=raw_next_obs, 189 | path_length=path_length, 190 | path_length_actions=path_length_actions, 191 | reward_actions_sum=reward_actions_sum, 192 | skill_names=skill_names, 193 | max_path_length=max_path_length, 194 | ) 195 | 196 | 197 | def deprecated_rollout( 198 | env, 199 | agent, 200 | max_path_length=np.inf, 201 | render=False, 202 | render_kwargs=None, 203 | ): 204 | """ 205 | The value for each of the following keys will be a 2D array, with the 206 | first dimension corresponding to the time dimension. 207 | - observations 208 | - actions 209 | - rewards 210 | - next_observations 211 | - terminals 212 | 213 | The next two elements will be lists of dictionaries, with the index into 214 | the list being the index into the time dimension. 215 | - agent_infos 216 | - env_infos 217 | """ 218 | if render_kwargs is None: 219 | render_kwargs = {} 220 | observations = [] 221 | actions = [] 222 | rewards = [] 223 | terminals = [] 224 | agent_infos = [] 225 | env_infos = [] 226 | o = env.reset() 227 | agent.reset() 228 | next_o = None 229 | path_length = 0 230 | if render: 231 | env.render(**render_kwargs) 232 | while path_length < max_path_length: 233 | a, agent_info = agent.get_action(o) 234 | next_o, r, d, env_info = env.step(a) 235 | observations.append(o) 236 | rewards.append(r) 237 | terminals.append(d) 238 | actions.append(a) 239 | agent_infos.append(agent_info) 240 | env_infos.append(env_info) 241 | path_length += 1 242 | if d: 243 | break 244 | o = next_o 245 | if render: 246 | env.render(**render_kwargs) 247 | 248 | actions = np.array(actions) 249 | if len(actions.shape) == 1: 250 | actions = np.expand_dims(actions, 1) 251 | observations = np.array(observations) 252 | if len(observations.shape) == 1: 253 | observations = np.expand_dims(observations, 1) 254 | next_o = np.array([next_o]) 255 | next_observations = np.vstack( 256 | ( 257 | observations[1:, :], 258 | np.expand_dims(next_o, 0) 259 | ) 260 | ) 261 | return dict( 262 | observations=observations, 263 | actions=actions, 264 | rewards=np.array(rewards).reshape(-1, 1), 265 | next_observations=next_observations, 266 | terminals=np.array(terminals).reshape(-1, 1), 267 | agent_infos=agent_infos, 268 | env_infos=env_infos, 269 | ) 270 | -------------------------------------------------------------------------------- /maple/samplers/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def split_paths(paths): 4 | """ 5 | Stack multiple obs/actions/etc. from different paths 6 | :param paths: List of paths, where one path is something returned from 7 | a rollout function in maple/samplers/rollout_functions.py. 8 | :return: Tuple. Every element will have shape batch_size X DIM, including 9 | the rewards and terminal flags. 
10 | """ 11 | rewards = [path["rewards"].reshape(-1, 1) for path in paths] 12 | terminals = [path["terminals"].reshape(-1, 1) for path in paths] 13 | actions = [path["actions"] for path in paths] 14 | obs = [path["observations"] for path in paths] 15 | next_obs = [path["next_observations"] for path in paths] 16 | rewards = np.vstack(rewards) 17 | terminals = np.vstack(terminals) 18 | obs = np.vstack(obs) 19 | actions = np.vstack(actions) 20 | next_obs = np.vstack(next_obs) 21 | assert len(rewards.shape) == 2 22 | assert len(terminals.shape) == 2 23 | assert len(obs.shape) == 2 24 | assert len(actions.shape) == 2 25 | assert len(next_obs.shape) == 2 26 | return rewards, terminals, obs, actions, next_obs 27 | 28 | 29 | def split_paths_to_dict(paths): 30 | rewards, terminals, obs, actions, next_obs = split_paths(paths) 31 | return dict( 32 | rewards=rewards, 33 | terminals=terminals, 34 | observations=obs, 35 | actions=actions, 36 | next_observations=next_obs, 37 | ) 38 | -------------------------------------------------------------------------------- /maple/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/__init__.py -------------------------------------------------------------------------------- /maple/torch/core.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | 7 | from maple.torch import pytorch_util as ptu 8 | 9 | 10 | class PyTorchModule(nn.Module, metaclass=abc.ABCMeta): 11 | """ 12 | Keeping wrapper around to be a bit more future-proof. 13 | """ 14 | pass 15 | 16 | 17 | def eval_np(module, *args, **kwargs): 18 | """ 19 | Eval this module with a numpy interface 20 | 21 | Same as a call to __call__ except all Variable input/outputs are 22 | replaced with numpy equivalents. 23 | 24 | Assumes the output is either a single object or a tuple of objects. 25 | """ 26 | torch_args = tuple(torch_ify(x) for x in args) 27 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 28 | outputs = module(*torch_args, **torch_kwargs) 29 | return elem_or_tuple_to_numpy(outputs) 30 | 31 | 32 | def torch_ify(np_array_or_other): 33 | if isinstance(np_array_or_other, np.ndarray): 34 | return ptu.from_numpy(np_array_or_other) 35 | else: 36 | return np_array_or_other 37 | 38 | 39 | def np_ify(tensor_or_other): 40 | if isinstance(tensor_or_other, torch.autograd.Variable): 41 | return ptu.get_numpy(tensor_or_other) 42 | else: 43 | return tensor_or_other 44 | 45 | 46 | def _elem_or_tuple_to_variable(elem_or_tuple): 47 | if isinstance(elem_or_tuple, tuple): 48 | return tuple( 49 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple 50 | ) 51 | return ptu.from_numpy(elem_or_tuple).float() 52 | 53 | 54 | def elem_or_tuple_to_numpy(elem_or_tuple): 55 | if isinstance(elem_or_tuple, tuple): 56 | return tuple(np_ify(x) for x in elem_or_tuple) 57 | else: 58 | return np_ify(elem_or_tuple) 59 | 60 | 61 | def _filter_batch(np_batch): 62 | for k, v in np_batch.items(): 63 | if v.dtype == np.bool: 64 | yield k, v.astype(int) 65 | else: 66 | yield k, v 67 | 68 | 69 | def np_to_pytorch_batch(np_batch): 70 | if isinstance(np_batch, dict): 71 | return { 72 | k: _elem_or_tuple_to_variable(x) 73 | for k, x in _filter_batch(np_batch) 74 | if x.dtype != np.dtype('O') # ignore object (e.g. 
dictionaries) 75 | } 76 | else: 77 | return _elem_or_tuple_to_variable(np_batch) 78 | -------------------------------------------------------------------------------- /maple/torch/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, Sampler 4 | 5 | # TODO: move this to more reasonable place 6 | from maple.data_management.obs_dict_replay_buffer import normalize_image 7 | 8 | 9 | class ImageDataset(Dataset): 10 | 11 | def __init__(self, images, should_normalize=True): 12 | super().__init__() 13 | self.dataset = images 14 | self.dataset_len = len(self.dataset) 15 | assert should_normalize == (images.dtype == np.uint8) 16 | self.should_normalize = should_normalize 17 | 18 | def __len__(self): 19 | return self.dataset_len 20 | 21 | def __getitem__(self, idxs): 22 | samples = self.dataset[idxs, :] 23 | if self.should_normalize: 24 | samples = normalize_image(samples) 25 | return np.float32(samples) 26 | 27 | 28 | class InfiniteRandomSampler(Sampler): 29 | 30 | def __init__(self, data_source): 31 | self.data_source = data_source 32 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __next__(self): 38 | try: 39 | idx = next(self.iter) 40 | except StopIteration: 41 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 42 | idx = next(self.iter) 43 | return idx 44 | 45 | def __len__(self): 46 | return 2 ** 62 47 | 48 | 49 | class InfiniteWeightedRandomSampler(Sampler): 50 | 51 | def __init__(self, data_source, weights): 52 | assert len(data_source) == len(weights) 53 | assert len(weights.shape) == 1 54 | self.data_source = data_source 55 | # Always use CPU 56 | self._weights = torch.from_numpy(weights) 57 | self.iter = self._create_iterator() 58 | 59 | def update_weights(self, weights): 60 | self._weights = weights 61 | self.iter = self._create_iterator() 62 | 63 | def _create_iterator(self): 64 | return iter( 65 | torch.multinomial( 66 | self._weights, len(self._weights), replacement=True 67 | ).tolist() 68 | ) 69 | 70 | def __iter__(self): 71 | return self 72 | 73 | def __next__(self): 74 | try: 75 | idx = next(self.iter) 76 | except StopIteration: 77 | self.iter = self._create_iterator() 78 | idx = next(self.iter) 79 | return idx 80 | 81 | def __len__(self): 82 | return 2 ** 62 83 | -------------------------------------------------------------------------------- /maple/torch/data_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/data_management/__init__.py -------------------------------------------------------------------------------- /maple/torch/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import maple.torch.pytorch_util as ptu 3 | import numpy as np 4 | 5 | from maple.data_management.normalizer import Normalizer, FixedNormalizer 6 | 7 | 8 | class TorchNormalizer(Normalizer): 9 | """ 10 | Update with np array, but de/normalize pytorch Tensors. 
11 | """ 12 | def normalize(self, v, clip_range=None): 13 | if not self.synchronized: 14 | self.synchronize() 15 | if clip_range is None: 16 | clip_range = self.default_clip_range 17 | mean = ptu.from_numpy(self.mean) 18 | std = ptu.from_numpy(self.std) 19 | if v.dim() == 2: 20 | # Unsqueeze along the batch use automatic broadcasting 21 | mean = mean.unsqueeze(0) 22 | std = std.unsqueeze(0) 23 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 24 | 25 | def denormalize(self, v): 26 | if not self.synchronized: 27 | self.synchronize() 28 | mean = ptu.from_numpy(self.mean) 29 | std = ptu.from_numpy(self.std) 30 | if v.dim() == 2: 31 | mean = mean.unsqueeze(0) 32 | std = std.unsqueeze(0) 33 | return mean + v * std 34 | 35 | 36 | class TorchFixedNormalizer(FixedNormalizer): 37 | def normalize(self, v, clip_range=None): 38 | if clip_range is None: 39 | clip_range = self.default_clip_range 40 | mean = ptu.from_numpy(self.mean) 41 | std = ptu.from_numpy(self.std) 42 | if v.dim() == 2: 43 | # Unsqueeze along the batch use automatic broadcasting 44 | mean = mean.unsqueeze(0) 45 | std = std.unsqueeze(0) 46 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 47 | 48 | def normalize_scale(self, v): 49 | """ 50 | Only normalize the scale. Do not subtract the mean. 51 | """ 52 | std = ptu.from_numpy(self.std) 53 | if v.dim() == 2: 54 | std = std.unsqueeze(0) 55 | return v / std 56 | 57 | def denormalize(self, v): 58 | mean = ptu.from_numpy(self.mean) 59 | std = ptu.from_numpy(self.std) 60 | if v.dim() == 2: 61 | mean = mean.unsqueeze(0) 62 | std = std.unsqueeze(0) 63 | return mean + v * std 64 | 65 | def denormalize_scale(self, v): 66 | """ 67 | Only denormalize the scale. Do not add the mean. 68 | """ 69 | std = ptu.from_numpy(self.std) 70 | if v.dim() == 2: 71 | std = std.unsqueeze(0) 72 | return v * std 73 | -------------------------------------------------------------------------------- /maple/torch/lvm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/lvm/__init__.py -------------------------------------------------------------------------------- /maple/torch/lvm/bear_vae.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import maple.torch.pytorch_util as ptu 10 | from maple.policies.base import ExplorationPolicy 11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from maple.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from maple.torch.networks import Mlp, CNN 16 | from maple.torch.networks.basic import MultiInputSequential 17 | from maple.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | from maple.torch.lvm.latent_variable_model import LatentVariableModel 21 | 22 | 23 | class VAEPolicy(LatentVariableModel): 24 | def __init__( 25 | self, 26 | obs_dim, 27 | action_dim, 28 | latent_dim, 29 | ): 30 | encoder = Encoder(obs_dim, latent_dim, action_dim) 31 | decoder = Decoder(obs_dim, latent_dim, action_dim) 32 | super().__init__(encoder, decoder) 33 | 34 | self.latent_dim = latent_dim 35 | 36 | def forward(self, state, action): 37 | z = F.relu(self.encoder.e1(torch.cat([state, action], 1))) 38 | 
z = F.relu(self.encoder.e2(z)) 39 | 40 | mean = self.encoder.mean(z) 41 | # Clamped for numerical stability 42 | log_std = self.encoder.log_std(z).clamp(-4, 15) 43 | std = torch.exp(log_std) 44 | z = mean + std * ptu.from_numpy( 45 | np.random.normal(0, 1, size=(std.size()))) 46 | 47 | u = self.decode(state, z) 48 | 49 | return u, mean, std 50 | 51 | def decode(self, state, z=None): 52 | if z is None: 53 | z = ptu.from_numpy(np.random.normal(0, 1, size=( 54 | state.size(0), self.latent_dim))).clamp(-0.5, 0.5) 55 | 56 | a = F.relu(self.decoder.d1(torch.cat([state, z], 1))) 57 | a = F.relu(self.decoder.d2(a)) 58 | return torch.tanh(self.decoder.d3(a)) 59 | 60 | def decode_multiple(self, state, z=None, num_decode=10): 61 | if z is None: 62 | z = ptu.from_numpy(np.random.normal(0, 1, size=( 63 | state.size(0), num_decode, self.latent_dim))).clamp(-0.5, 0.5) 64 | 65 | a = F.relu(self.decoder.d1(torch.cat( 66 | [state.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2), z], 67 | 2))) 68 | a = F.relu(self.decoder.d2(a)) 69 | return torch.tanh(self.decoder.d3(a)), self.decoder.d3(a) 70 | 71 | 72 | class Encoder(nn.Module): 73 | def __init__(self, obs_dim, latent_dim, action_dim): 74 | super().__init__() 75 | self.latent_dim = latent_dim 76 | 77 | self.e1 = torch.nn.Linear(obs_dim + action_dim, 750) 78 | self.e2 = torch.nn.Linear(750, 750) 79 | 80 | self.mean = torch.nn.Linear(750, self.latent_dim) 81 | self.log_std = torch.nn.Linear(750, self.latent_dim) 82 | 83 | def forward(self, state, action): 84 | z = F.relu(self.e1(torch.cat([state, action], 1))) 85 | z = F.relu(self.e2(z)) 86 | 87 | mean = self.mean(z) 88 | # Clamped for numerical stability 89 | log_std = self.log_std(z).clamp(-4, 15) 90 | std = torch.exp(log_std) 91 | return MultivariateDiagonalNormal(mean, std) 92 | 93 | 94 | class Decoder(nn.Module): 95 | def __init__(self, obs_dim, latent_dim, action_dim): 96 | super().__init__() 97 | self.latent_dim = latent_dim 98 | 99 | self.d1 = torch.nn.Linear(obs_dim + self.latent_dim, 750) 100 | self.d2 = torch.nn.Linear(750, 750) 101 | self.d3 = torch.nn.Linear(750, action_dim) 102 | 103 | def forward(self, state, z): 104 | a = F.relu(self.d1(torch.cat([state, z], 1))) 105 | a = F.relu(self.d2(a)) 106 | return Delta(torch.tanh(self.d3(a))) 107 | -------------------------------------------------------------------------------- /maple/torch/lvm/latent_variable_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import maple.torch.pytorch_util as ptu 10 | from maple.policies.base import ExplorationPolicy 11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from maple.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from maple.torch.networks import Mlp, CNN 16 | from maple.torch.networks.basic import MultiInputSequential 17 | from maple.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | from maple.torch.sac.policies.base import ( 21 | TorchStochasticPolicy, 22 | PolicyFromDistributionGenerator, 23 | MakeDeterministic, 24 | ) 25 | 26 | 27 | class LatentVariableModel(nn.Module): 28 | def __init__( 29 | self, 30 | encoder, 31 | decoder, 32 | **kwargs 33 | ): 34 | super().__init__() 35 | self.encoder = encoder 36 | self.decoder = decoder 37 | 
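# [Editor's note -- illustrative, not repository code.] LatentVariableModel is a thin
# (encoder, decoder) container; VAEPolicy in bear_vae.py above is one concrete
# subclass. A minimal sketch of how its pieces fit together, with made-up dimensions:
#
#     import torch
#     from maple.torch.lvm.bear_vae import VAEPolicy
#
#     vae = VAEPolicy(obs_dim=17, action_dim=6, latent_dim=12)
#     states = torch.zeros(32, 17)
#     actions = torch.rand(32, 6) * 2 - 1
#     recon, mean, std = vae(states, actions)  # encode, reparameterize, decode
#     prior_actions = vae.decode(states)       # decode with z ~ N(0, 1), clamped to [-0.5, 0.5]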
-------------------------------------------------------------------------------- /maple/torch/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class HuberLoss(nn.Module): 9 | def __init__(self, delta=1): 10 | super().__init__() 11 | self.huber_loss_delta1 = nn.SmoothL1Loss() 12 | self.delta = delta 13 | 14 | def forward(self, x, x_hat): 15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta) 16 | return loss * self.delta * self.delta 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | """ 21 | Simple 1D LayerNorm. 22 | """ 23 | 24 | def __init__(self, features, center=True, scale=False, eps=1e-6): 25 | super().__init__() 26 | self.center = center 27 | self.scale = scale 28 | self.eps = eps 29 | if self.scale: 30 | self.scale_param = nn.Parameter(torch.ones(features)) 31 | else: 32 | self.scale_param = None 33 | if self.center: 34 | self.center_param = nn.Parameter(torch.zeros(features)) 35 | else: 36 | self.center_param = None 37 | 38 | def forward(self, x): 39 | mean = x.mean(-1, keepdim=True) 40 | std = x.std(-1, keepdim=True) 41 | output = (x - mean) / (std + self.eps) 42 | if self.scale: 43 | output = output * self.scale_param 44 | if self.center: 45 | output = output + self.center_param 46 | return output 47 | -------------------------------------------------------------------------------- /maple/torch/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 5 | """ 6 | from maple.torch.networks.basic import ( 7 | Clamp, ConcatTuple, Detach, Flatten, FlattenEach, Split, Reshape, 8 | ) 9 | from maple.torch.networks.cnn import BasicCNN, CNN, MergedCNN, CNNPolicy 10 | from maple.torch.networks.dcnn import DCNN, TwoHeadDCNN 11 | from maple.torch.networks.feat_point_mlp import FeatPointMlp 12 | from maple.torch.networks.image_state import ImageStatePolicy, ImageStateQ 13 | from maple.torch.networks.linear_transform import LinearTransform 14 | from maple.torch.networks.normalization import LayerNorm 15 | from maple.torch.networks.mlp import ( 16 | Mlp, ConcatMlp, MlpPolicy, TanhMlpPolicy, 17 | MlpQf, 18 | MlpQfWithObsProcessor, 19 | ConcatMultiHeadedMlp, 20 | ) 21 | from maple.torch.networks.pretrained_cnn import PretrainedCNN 22 | from maple.torch.networks.two_headed_mlp import TwoHeadMlp 23 | 24 | __all__ = [ 25 | 'Clamp', 26 | 'ConcatMlp', 27 | 'ConcatMultiHeadedMlp', 28 | 'ConcatTuple', 29 | 'BasicCNN', 30 | 'CNN', 31 | 'CNNPolicy', 32 | 'DCNN', 33 | 'Detach', 34 | 'FeatPointMlp', 35 | 'Flatten', 36 | 'FlattenEach', 37 | 'LayerNorm', 38 | 'LinearTransform', 39 | 'ImageStatePolicy', 40 | 'ImageStateQ', 41 | 'MergedCNN', 42 | 'Mlp', 43 | 'PretrainedCNN', 44 | 'Reshape', 45 | 'Split', 46 | 'TwoHeadDCNN', 47 | 'TwoHeadMlp', 48 | ] 49 | 50 | -------------------------------------------------------------------------------- /maple/torch/networks/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Clamp(nn.Module): 6 | def __init__(self, **kwargs): 7 | super().__init__() 8 | self.kwargs = kwargs 9 | self.__name__ = "Clamp" 10 | 11 | def forward(self, x): 12 | return torch.clamp(x, **self.kwargs) 13 | 14 | 15 | class Split(nn.Module): 16 | """ 17 | Split input and process each chunk with a 
separate module. 18 | """ 19 | def __init__(self, module1, module2, split_idx): 20 | super().__init__() 21 | self.module1 = module1 22 | self.module2 = module2 23 | self.split_idx = split_idx 24 | 25 | def forward(self, x): 26 | in1 = x[:, :self.split_idx] 27 | out1 = self.module1(in1) 28 | 29 | in2 = x[:, self.split_idx:] 30 | out2 = self.module2(in2) 31 | 32 | return out1, out2 33 | 34 | 35 | class FlattenEach(nn.Module): 36 | def forward(self, inputs): 37 | return tuple(x.view(x.size(0), -1) for x in inputs) 38 | 39 | 40 | class FlattenEachParallel(nn.Module): 41 | def forward(self, *inputs): 42 | return tuple(x.view(x.size(0), -1) for x in inputs) 43 | 44 | 45 | class Flatten(nn.Module): 46 | def forward(self, inputs): 47 | return inputs.view(inputs.size(0), -1) 48 | 49 | 50 | class Map(nn.Module): 51 | """Apply a module to each input.""" 52 | def __init__(self, module): 53 | super().__init__() 54 | self.module = module 55 | 56 | def forward(self, inputs): 57 | return tuple(self.module(x) for x in inputs) 58 | 59 | 60 | class ApplyMany(nn.Module): 61 | """Apply many modules to one input.""" 62 | def __init__(self, *modules): 63 | super().__init__() 64 | self.modules_to_apply = nn.ModuleList(modules) 65 | 66 | def forward(self, inputs): 67 | return tuple(m(inputs) for m in self.modules_to_apply) 68 | 69 | 70 | class LearnedPositiveConstant(nn.Module): 71 | def __init__(self, init_value): 72 | super().__init__() 73 | self._constant = nn.Parameter(init_value) 74 | 75 | def forward(self, _): 76 | return self._constant 77 | 78 | 79 | class Reshape(nn.Module): 80 | def __init__(self, *output_shape): 81 | super().__init__() 82 | self._output_shape_with_batch_size = (-1, *output_shape) 83 | 84 | def forward(self, inputs): 85 | return inputs.view(self._output_shape_with_batch_size) 86 | 87 | 88 | class ConcatTuple(nn.Module): 89 | def __init__(self, dim=1): 90 | super().__init__() 91 | self.dim = dim 92 | 93 | def forward(self, inputs): 94 | return torch.cat(inputs, dim=self.dim) 95 | 96 | 97 | class Concat(nn.Module): 98 | def __init__(self, dim=1): 99 | super().__init__() 100 | self.dim = dim 101 | 102 | def forward(self, *inputs): 103 | return torch.cat(inputs, dim=self.dim) 104 | 105 | 106 | class MultiInputSequential(nn.Sequential): 107 | def forward(self, *input): 108 | for module in self._modules.values(): 109 | if isinstance(input, tuple): 110 | input = module(*input) 111 | else: 112 | input = module(input) 113 | return input 114 | 115 | 116 | class Detach(nn.Module): 117 | def __init__(self, wrapped_mlp): 118 | super().__init__() 119 | self.wrapped_mlp = wrapped_mlp 120 | 121 | def forward(self, inputs): 122 | return self.wrapped_mlp.forward(inputs).detach() 123 | 124 | def __getattr__(self, attr_name): 125 | try: 126 | return super().__getattr__(attr_name) 127 | except AttributeError: 128 | return getattr(self.wrapped_mlp, attr_name) 129 | -------------------------------------------------------------------------------- /maple/torch/networks/custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Random networks 3 | """ 4 | -------------------------------------------------------------------------------- /maple/torch/networks/feat_point_mlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from maple.pythonplusplus import identity 7 | from maple.torch import pytorch_util as ptu 8 | from 
maple.torch.core import PyTorchModule 9 | 10 | 11 | class FeatPointMlp(PyTorchModule): 12 | def __init__( 13 | self, 14 | downsample_size, 15 | input_channels, 16 | num_feat_points, 17 | temperature=1.0, 18 | init_w=1e-3, 19 | input_size=32, 20 | hidden_init=ptu.fanin_init, 21 | output_activation=identity, 22 | ): 23 | super().__init__() 24 | 25 | self.downsample_size = downsample_size 26 | self.temperature = temperature 27 | self.num_feat_points = num_feat_points 28 | self.hidden_init = hidden_init 29 | self.output_activation = output_activation 30 | self.input_channels = input_channels 31 | self.input_size = input_size 32 | 33 | # self.bn1 = nn.BatchNorm2d(1) 34 | self.conv1 = nn.Conv2d(input_channels, 48, kernel_size=5, stride=2) 35 | # self.bn1 = nn.BatchNorm2d(16) 36 | self.conv2 = nn.Conv2d(48, 48, kernel_size=5, stride=1) 37 | self.conv3 = nn.Conv2d(48, self.num_feat_points, kernel_size=5, stride=1) 38 | 39 | test_mat = ptu.zeros(1, self.input_channels, self.input_size, self.input_size) 40 | test_mat = self.conv1(test_mat) 41 | test_mat = self.conv2(test_mat) 42 | test_mat = self.conv3(test_mat) 43 | self.out_size = int(np.prod(test_mat.shape)) 44 | self.fc1 = nn.Linear(2 * self.num_feat_points, 400) 45 | self.fc2 = nn.Linear(400, 300) 46 | self.last_fc = nn.Linear(300, self.input_channels * self.downsample_size * self.downsample_size) 47 | 48 | self.init_weights(init_w) 49 | self.i = 0 50 | 51 | def init_weights(self, init_w): 52 | self.hidden_init(self.conv1.weight) 53 | self.conv1.bias.data.fill_(0) 54 | self.hidden_init(self.conv2.weight) 55 | self.conv2.bias.data.fill_(0) 56 | 57 | def forward(self, input): 58 | h = self.encoder(input) 59 | out = self.decoder(h) 60 | return out 61 | 62 | def encoder(self, input): 63 | x = input.contiguous().view(-1, self.input_channels, self.input_size, self.input_size) 64 | x = F.relu(self.conv1(x)) 65 | x = F.relu(self.conv2(x)) 66 | x = self.conv3(x) 67 | d = int((self.out_size // self.num_feat_points) ** (1 / 2)) 68 | x = x.view(-1, self.num_feat_points, d * d) 69 | x = F.softmax(x / self.temperature, 2) 70 | x = x.view(-1, self.num_feat_points, d, d) 71 | 72 | maps_x = torch.sum(x, 2) 73 | maps_y = torch.sum(x, 3) 74 | 75 | weights = ptu.from_numpy(np.arange(d) / (d + 1)) 76 | 77 | fp_x = torch.sum(maps_x * weights, 2) 78 | fp_y = torch.sum(maps_y * weights, 2) 79 | 80 | x = torch.cat([fp_x, fp_y], 1) 81 | # h = x.view(-1, 2, self.num_feat_points).transpose(1, 2).contiguous().view(-1, self.num_feat_points * 2) 82 | h = x.view(-1, self.num_feat_points * 2) 83 | return h 84 | 85 | def decoder(self, input): 86 | h = input 87 | h = F.relu(self.fc1(h)) 88 | h = F.relu(self.fc2(h)) 89 | h = self.last_fc(h) 90 | return h 91 | 92 | def history_encoder(self, input, history_length): 93 | input = input.contiguous().view(-1, 94 | self.input_channels, 95 | self.input_size, 96 | self.input_size) 97 | latent = self.encoder(input) 98 | 99 | assert latent.shape[0] % history_length == 0 100 | n_samples = latent.shape[0] // history_length 101 | latent = latent.view(n_samples, -1) 102 | return latent 103 | 104 | -------------------------------------------------------------------------------- /maple/torch/networks/image_state.py: -------------------------------------------------------------------------------- 1 | from maple.policies.base import Policy 2 | from maple.torch.core import PyTorchModule, eval_np 3 | 4 | 5 | class ImageStatePolicy(PyTorchModule, Policy): 6 | """Switches between image or state inputs""" 7 | 8 | def __init__( 9 | self, 10 | 
image_conv_net, 11 | state_fc_net, 12 | ): 13 | super().__init__() 14 | 15 | assert image_conv_net is None or state_fc_net is None 16 | self.image_conv_net = image_conv_net 17 | self.state_fc_net = state_fc_net 18 | 19 | def forward(self, input, return_preactivations=False): 20 | if self.image_conv_net is not None: 21 | image = input[:, :21168] 22 | return self.image_conv_net(image) 23 | if self.state_fc_net is not None: 24 | state = input[:, 21168:] 25 | return self.state_fc_net(state) 26 | 27 | def get_action(self, obs_np): 28 | actions = self.get_actions(obs_np[None]) 29 | return actions[0, :], {} 30 | 31 | def get_actions(self, obs): 32 | return eval_np(self, obs) 33 | 34 | 35 | class ImageStateQ(PyTorchModule): 36 | """Switches between image or state inputs""" 37 | 38 | def __init__( 39 | self, 40 | # obs_dim, 41 | # action_dim, 42 | # goal_dim, 43 | image_conv_net, # assumed to be a MergedCNN 44 | state_fc_net, 45 | ): 46 | super().__init__() 47 | 48 | assert image_conv_net is None or state_fc_net is None 49 | # self.obs_dim = obs_dim 50 | # self.action_dim = action_dim 51 | # self.goal_dim = goal_dim 52 | self.image_conv_net = image_conv_net 53 | self.state_fc_net = state_fc_net 54 | 55 | def forward(self, input, action, return_preactivations=False): 56 | if self.image_conv_net is not None: 57 | image = input[:, :21168] 58 | return self.image_conv_net(image, action) 59 | if self.state_fc_net is not None: 60 | state = input[:, 21168:] # action + state 61 | return self.state_fc_net(state, action) 62 | 63 | 64 | -------------------------------------------------------------------------------- /maple/torch/networks/linear_transform.py: -------------------------------------------------------------------------------- 1 | from maple.torch.core import PyTorchModule 2 | 3 | 4 | class LinearTransform(PyTorchModule): 5 | def __init__(self, m, b): 6 | super().__init__() 7 | self.m = m 8 | self.b = b 9 | 10 | def __call__(self, t): 11 | return self.m * t + self.b 12 | -------------------------------------------------------------------------------- /maple/torch/networks/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. Maybe depend on pytorch_util. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from maple.torch import pytorch_util as ptu 7 | 8 | 9 | class LayerNorm(nn.Module): 10 | """ 11 | Simple 1D LayerNorm. 
12 | """ 13 | def __init__(self, features, center=True, scale=False, eps=1e-6): 14 | super().__init__() 15 | self.center = center 16 | self.scale = scale 17 | self.eps = eps 18 | if self.scale: 19 | self.scale_param = nn.Parameter(torch.ones(features)) 20 | else: 21 | self.scale_param = None 22 | if self.center: 23 | self.center_param = nn.Parameter(torch.zeros(features)) 24 | else: 25 | self.center_param = None 26 | 27 | def forward(self, x): 28 | mean = x.mean(-1, keepdim=True) 29 | std = x.std(-1, keepdim=True) 30 | output = (x - mean) / (std + self.eps) 31 | if self.scale: 32 | output = output * self.scale_param 33 | if self.center: 34 | output = output + self.center_param 35 | return output 36 | -------------------------------------------------------------------------------- /maple/torch/networks/pretrained_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.models as models 4 | from torch import nn as nn 5 | 6 | from maple.pythonplusplus import identity 7 | from maple.torch.core import PyTorchModule 8 | 9 | 10 | class PretrainedCNN(PyTorchModule): 11 | # Uses a pretrained CNN architecture from torchvision 12 | def __init__( 13 | self, 14 | input_width, 15 | input_height, 16 | input_channels, 17 | output_size, 18 | hidden_sizes=None, 19 | added_fc_input_size=0, 20 | batch_norm_fc=False, 21 | init_w=1e-4, 22 | hidden_init=nn.init.xavier_uniform_, 23 | hidden_activation=nn.ReLU(), 24 | output_activation=identity, 25 | output_conv_channels=False, 26 | model_architecture=models.resnet18, 27 | model_pretrained=True, 28 | model_freeze=False, 29 | ): 30 | if hidden_sizes is None: 31 | hidden_sizes = [] 32 | super().__init__() 33 | 34 | self.hidden_sizes = hidden_sizes 35 | self.input_width = input_width 36 | self.input_height = input_height 37 | self.input_channels = input_channels 38 | self.output_size = output_size 39 | self.output_activation = output_activation 40 | self.hidden_activation = hidden_activation 41 | self.batch_norm_fc = batch_norm_fc 42 | self.added_fc_input_size = added_fc_input_size 43 | self.conv_input_length = self.input_width * self.input_height * self.input_channels 44 | self.output_conv_channels = output_conv_channels 45 | 46 | self.pretrained_model = nn.Sequential(*list(model_architecture( 47 | pretrained=model_pretrained).children())[:-1]) 48 | if model_freeze: 49 | for child in self.pretrained_model.children(): 50 | for param in child.parameters(): 51 | param.requires_grad = False 52 | self.fc_layers = nn.ModuleList() 53 | self.fc_norm_layers = nn.ModuleList() 54 | 55 | # use torch rather than ptu because initially the model is on CPU 56 | test_mat = torch.zeros( 57 | 1, 58 | self.input_channels, 59 | self.input_width, 60 | self.input_height, 61 | ) 62 | # find output dim of conv_layers by trial and add norm conv layers 63 | test_mat = self.pretrained_model(test_mat) 64 | 65 | self.conv_output_flat_size = int(np.prod(test_mat.shape)) 66 | if self.output_conv_channels: 67 | self.last_fc = None 68 | else: 69 | fc_input_size = self.conv_output_flat_size 70 | # used only for injecting input directly into fc layers 71 | fc_input_size += added_fc_input_size 72 | for idx, hidden_size in enumerate(hidden_sizes): 73 | fc_layer = nn.Linear(fc_input_size, hidden_size) 74 | fc_input_size = hidden_size 75 | 76 | fc_layer.weight.data.uniform_(-init_w, init_w) 77 | fc_layer.bias.data.uniform_(-init_w, init_w) 78 | 79 | self.fc_layers.append(fc_layer) 80 | 81 | if 
self.batch_norm_fc: 82 | norm_layer = nn.BatchNorm1d(hidden_size) 83 | self.fc_norm_layers.append(norm_layer) 84 | 85 | self.last_fc = nn.Linear(fc_input_size, output_size) 86 | self.last_fc.weight.data.uniform_(-init_w, init_w) 87 | self.last_fc.bias.data.uniform_(-init_w, init_w) 88 | 89 | def forward(self, input, return_last_activations=False): 90 | conv_input = input.narrow(start=0, 91 | length=self.conv_input_length, 92 | dim=1).contiguous() 93 | # reshape from batch of flattened images into (channels, w, h) 94 | h = conv_input.view(conv_input.shape[0], 95 | self.input_channels, 96 | self.input_height, 97 | self.input_width) 98 | 99 | h = self.apply_forward_conv(h) 100 | 101 | if self.output_conv_channels: 102 | return h 103 | 104 | # flatten channels for fc layers 105 | h = h.view(h.size(0), -1) 106 | if self.added_fc_input_size != 0: 107 | extra_fc_input = input.narrow( 108 | start=self.conv_input_length, 109 | length=self.added_fc_input_size, 110 | dim=1, 111 | ) 112 | h = torch.cat((h, extra_fc_input), dim=1) 113 | h = self.apply_forward_fc(h) 114 | 115 | if return_last_activations: 116 | return h 117 | return self.output_activation(self.last_fc(h)) 118 | 119 | def apply_forward_conv(self, h): 120 | return self.pretrained_model(h) 121 | 122 | def apply_forward_fc(self, h): 123 | for i, layer in enumerate(self.fc_layers): 124 | h = layer(h) 125 | if self.batch_norm_fc: 126 | h = self.fc_norm_layers[i](h) 127 | h = self.hidden_activation(h) 128 | return h 129 | 130 | 131 | -------------------------------------------------------------------------------- /maple/torch/networks/stochastic/distribution_generator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from torch import nn 4 | 5 | from maple.torch.distributions import ( 6 | Bernoulli, 7 | Beta, 8 | Distribution, 9 | Independent, 10 | GaussianMixture as GaussianMixtureDistribution, 11 | GaussianMixtureFull as GaussianMixtureFullDistribution, 12 | MultivariateDiagonalNormal, 13 | TanhNormal, 14 | ) 15 | from maple.torch.networks.basic import MultiInputSequential 16 | 17 | 18 | class DistributionGenerator(nn.Module, metaclass=abc.ABCMeta): 19 | def forward(self, *input, **kwarg) -> Distribution: 20 | raise NotImplementedError 21 | 22 | 23 | class ModuleToDistributionGenerator( 24 | MultiInputSequential, 25 | DistributionGenerator, 26 | metaclass=abc.ABCMeta 27 | ): 28 | pass 29 | 30 | 31 | class Beta(ModuleToDistributionGenerator): 32 | def forward(self, *input): 33 | alpha, beta = super().forward(*input) 34 | return Beta(alpha, beta) 35 | 36 | 37 | class Gaussian(ModuleToDistributionGenerator): 38 | def __init__(self, module, std=None, reinterpreted_batch_ndims=1): 39 | super().__init__(module) 40 | self.std = std 41 | self.reinterpreted_batch_ndims = reinterpreted_batch_ndims 42 | 43 | def forward(self, *input): 44 | if self.std: 45 | mean = super().forward(*input) 46 | std = self.std 47 | else: 48 | mean, log_std = super().forward(*input) 49 | std = log_std.exp() 50 | return MultivariateDiagonalNormal( 51 | mean, std, reinterpreted_batch_ndims=self.reinterpreted_batch_ndims) 52 | 53 | 54 | class BernoulliGenerator(ModuleToDistributionGenerator): 55 | def forward(self, *input): 56 | probs = super().forward(*input) 57 | return Bernoulli(probs) 58 | 59 | 60 | class IndependentGenerator(ModuleToDistributionGenerator): 61 | def __init__(self, *args, reinterpreted_batch_ndims=1): 62 | super().__init__(*args) 63 | self.reinterpreted_batch_ndims = reinterpreted_batch_ndims 64 | 
65 | def forward(self, *input): 66 | distribution = super().forward(*input) 67 | return Independent( 68 | distribution, 69 | reinterpreted_batch_ndims=self.reinterpreted_batch_ndims, 70 | ) 71 | 72 | 73 | class GaussianMixture(ModuleToDistributionGenerator): 74 | def forward(self, *input): 75 | mixture_means, mixture_stds, weights = super().forward(*input) 76 | return GaussianMixtureDistribution(mixture_means, mixture_stds, weights) 77 | 78 | 79 | class GaussianMixtureFull(ModuleToDistributionGenerator): 80 | def forward(self, *input): 81 | mixture_means, mixture_stds, weights = super().forward(*input) 82 | return GaussianMixtureFullDistribution(mixture_means, mixture_stds, weights) 83 | 84 | 85 | class TanhGaussian(ModuleToDistributionGenerator): 86 | def forward(self, *input): 87 | mean, log_std = super().forward(*input) 88 | std = log_std.exp() 89 | return TanhNormal(mean, std) 90 | -------------------------------------------------------------------------------- /maple/torch/networks/two_headed_mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from maple.pythonplusplus import identity 5 | from maple.torch import pytorch_util as ptu 6 | from maple.torch.core import PyTorchModule 7 | from maple.torch.networks import LayerNorm 8 | 9 | 10 | class TwoHeadMlp(PyTorchModule): 11 | def __init__( 12 | self, 13 | hidden_sizes, 14 | first_head_size, 15 | second_head_size, 16 | input_size, 17 | init_w=3e-3, 18 | hidden_activation=F.relu, 19 | output_activation=identity, 20 | hidden_init=ptu.fanin_init, 21 | b_init_value=0., 22 | layer_norm=False, 23 | layer_norm_kwargs=None, 24 | ): 25 | super().__init__() 26 | 27 | if layer_norm_kwargs is None: 28 | layer_norm_kwargs = dict() 29 | 30 | self.input_size = input_size 31 | self.first_head_size = first_head_size 32 | self.second_head_size = second_head_size 33 | self.hidden_activation = hidden_activation 34 | self.output_activation = output_activation 35 | self.layer_norm = layer_norm 36 | self.fcs = [] 37 | self.layer_norms = [] 38 | in_size = input_size 39 | 40 | for i, next_size in enumerate(hidden_sizes): 41 | fc = nn.Linear(in_size, next_size) 42 | in_size = next_size 43 | hidden_init(fc.weight) 44 | fc.bias.data.fill_(b_init_value) 45 | self.__setattr__("fc{}".format(i), fc) 46 | self.fcs.append(fc) 47 | 48 | if self.layer_norm: 49 | ln = LayerNorm(next_size) 50 | self.__setattr__("layer_norm{}".format(i), ln) 51 | self.layer_norms.append(ln) 52 | 53 | self.first_head = nn.Linear(in_size, self.first_head_size) 54 | self.first_head.weight.data.uniform_(-init_w, init_w) 55 | self.first_head.bias.data.fill_(0) 56 | 57 | self.second_head = nn.Linear(in_size, self.second_head_size) 58 | self.second_head.weight.data.uniform_(-init_w, init_w) 59 | self.second_head.bias.data.fill_(0) 60 | 61 | def forward(self, input, return_preactivations=False): 62 | h = input 63 | for i, fc in enumerate(self.fcs): 64 | h = fc(h) 65 | if self.layer_norm and i < len(self.fcs) - 1: 66 | h = self.layer_norms[i](h) 67 | h = self.hidden_activation(h) 68 | preactivation = self.first_head(h) 69 | first_output = self.output_activation(preactivation) 70 | preactivation = self.second_head(h) 71 | second_output = self.output_activation(preactivation) 72 | 73 | return first_output, second_output 74 | -------------------------------------------------------------------------------- /maple/torch/sac/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/maple/c13f584d4184e5796bd0adb822b8761ff4ebeee0/maple/torch/sac/__init__.py -------------------------------------------------------------------------------- /maple/torch/sac/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from maple.torch.sac.policies.base import ( 2 | TorchStochasticPolicy, 3 | PolicyFromDistributionGenerator, 4 | MakeDeterministic, 5 | ) 6 | from maple.torch.sac.policies.gaussian_policy import ( 7 | TanhGaussianPolicyAdapter, 8 | TanhGaussianPolicy, 9 | GaussianPolicy, 10 | GaussianCNNPolicy, 11 | GaussianMixturePolicy, 12 | BinnedGMMPolicy, 13 | TanhGaussianObsProcessorPolicy, 14 | TanhCNNGaussianPolicy, 15 | ) 16 | from maple.torch.sac.policies.lvm_policy import LVMPolicy 17 | from maple.torch.sac.policies.policy_from_q import PolicyFromQ 18 | from maple.torch.sac.policies.pamdp_policy import PAMDPPolicy 19 | 20 | __all__ = [ 21 | 'TorchStochasticPolicy', 22 | 'PolicyFromDistributionGenerator', 23 | 'MakeDeterministic', 24 | 'TanhGaussianPolicyAdapter', 25 | 'TanhGaussianPolicy', 26 | 'GaussianPolicy', 27 | 'GaussianCNNPolicy', 28 | 'GaussianMixturePolicy', 29 | 'BinnedGMMPolicy', 30 | 'TanhGaussianObsProcessorPolicy', 31 | 'TanhCNNGaussianPolicy', 32 | 'LVMPolicy', 33 | 'PolicyFromQ', 34 | 'PAMDPPolicy', 35 | ] 36 | -------------------------------------------------------------------------------- /maple/torch/sac/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import maple.torch.pytorch_util as ptu 10 | from maple.policies.base import ExplorationPolicy 11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from maple.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from maple.torch.networks import Mlp, CNN 16 | from maple.torch.networks.basic import MultiInputSequential 17 | from maple.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | 21 | 22 | class TorchStochasticPolicy( 23 | DistributionGenerator, 24 | ExplorationPolicy, metaclass=abc.ABCMeta 25 | ): 26 | def get_action(self, obs_np, return_dist=False): 27 | info = {} 28 | if return_dist: 29 | actions, dist = self.get_actions(obs_np[None], return_dist=return_dist) 30 | info['dist'] = dist 31 | else: 32 | actions = self.get_actions(obs_np[None], return_dist=return_dist) 33 | return actions[0, :], info 34 | 35 | def get_actions(self, obs_np, return_dist=False): 36 | dist = self._get_dist_from_np(obs_np) 37 | actions = dist.sample() 38 | if return_dist: 39 | return elem_or_tuple_to_numpy(actions), dist 40 | else: 41 | return elem_or_tuple_to_numpy(actions) 42 | 43 | def _get_dist_from_np(self, *args, **kwargs): 44 | torch_args = tuple(torch_ify(x) for x in args) 45 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 46 | dist = self(*torch_args, **torch_kwargs) 47 | return dist 48 | 49 | 50 | class PolicyFromDistributionGenerator( 51 | MultiInputSequential, 52 | TorchStochasticPolicy, 53 | ): 54 | """ 55 | Usage: 56 | ``` 57 | distribution_generator = FancyGenerativeModel() 58 | policy = PolicyFromBatchDistributionModule(distribution_generator) 59 | ``` 60 | """ 61 | pass 62 | 63 | 
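# [Editor's note -- illustrative, not repository code.] Any TorchStochasticPolicy
# subclass exposes the numpy-facing helpers defined above; a hypothetical call pattern:
#
#     action, info = policy.get_action(obs_np)                    # samples from the forward() distribution
#     action, info = policy.get_action(obs_np, return_dist=True)  # info['dist'] holds the torch distribution
#
# MakeDeterministic (below) wraps such a policy and replaces sampling with the
# distribution's mle_estimate(), which is typically how a deterministic
# evaluation-time policy is obtained.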
64 | class MakeDeterministic(TorchStochasticPolicy): 65 | def __init__( 66 | self, 67 | action_distribution_generator: DistributionGenerator, 68 | ): 69 | super().__init__() 70 | self._action_distribution_generator = action_distribution_generator 71 | 72 | def forward(self, *args, **kwargs): 73 | dist = self._action_distribution_generator.forward(*args, **kwargs) 74 | return Delta(dist.mle_estimate()) 75 | -------------------------------------------------------------------------------- /maple/torch/sac/policies/lvm_policy.py: -------------------------------------------------------------------------------- 1 | from maple.torch.networks.stochastic.distribution_generator import ( 2 | DistributionGenerator 3 | ) 4 | from maple.torch.sac.policies.base import ( 5 | TorchStochasticPolicy, 6 | PolicyFromDistributionGenerator, 7 | MakeDeterministic, 8 | ) 9 | 10 | from maple.torch.lvm.latent_variable_model import LatentVariableModel 11 | 12 | 13 | class LVMPolicy(LatentVariableModel, TorchStochasticPolicy): 14 | """Expects encoder p(z|s) and decoder p(u|s,z)""" 15 | 16 | def forward(self, obs): 17 | z_dist = self.encoder(obs) 18 | z = z_dist.sample() 19 | return self.decoder(obs, z) 20 | -------------------------------------------------------------------------------- /maple/torch/sac/policies/pamdp_policy.py: -------------------------------------------------------------------------------- 1 | from maple.torch.sac.policies.base import TorchStochasticPolicy 2 | from maple.torch.core import PyTorchModule 3 | from maple.torch.networks import Mlp 4 | 5 | from maple.torch.distributions import ( 6 | Softmax, TanhNormal, 7 | ConcatDistribution, 8 | HierarchicalDistribution, 9 | ) 10 | 11 | import torch 12 | from torch import nn as nn 13 | 14 | LOGITS_SCALE = 10 15 | LOG_SIG_MAX = 2 16 | LOG_SIG_MIN = -20 17 | 18 | class PAMDPPolicy(PyTorchModule, TorchStochasticPolicy): 19 | def __init__( 20 | self, 21 | hidden_sizes, 22 | obs_dim, 23 | action_dim_s, 24 | action_dim_p, 25 | one_hot_s, 26 | ): 27 | super().__init__() 28 | 29 | task_policy = CategoricalPolicy( 30 | obs_dim=obs_dim, 31 | hidden_sizes=hidden_sizes, 32 | action_dim=action_dim_s, 33 | one_hot=one_hot_s, 34 | prefix='task', 35 | ) 36 | 37 | param_policy = ParallelHybridPolicy( 38 | obs_dim=obs_dim, 39 | hidden_sizes=hidden_sizes, 40 | num_networks=action_dim_s, 41 | action_dim_c=action_dim_p, 42 | prefix_c='param', 43 | ) 44 | 45 | self.policy = HierarchicalPolicy( 46 | task_policy, param_policy, 47 | policy1_obs_dim=obs_dim 48 | ) 49 | 50 | self.obs_dim = obs_dim 51 | self.action_dim_s = action_dim_s 52 | self.action_dim_p = action_dim_p 53 | 54 | self.one_hot_s = one_hot_s 55 | 56 | def forward(self, obs): 57 | return self.policy(obs) 58 | 59 | class ConcatPolicy(PyTorchModule, TorchStochasticPolicy): 60 | def __init__( 61 | self, 62 | policy1, 63 | policy2, 64 | policy1_obs_dim 65 | ): 66 | super().__init__() 67 | 68 | self.policy1 = policy1 69 | self.policy2 = policy2 70 | 71 | self.policy1_obs_dim = policy1_obs_dim 72 | 73 | def forward(self, obs): 74 | return ConcatDistribution( 75 | distr1=self.policy1(obs[:,-self.policy1_obs_dim:]), 76 | distr2=self.policy2(obs), 77 | ) 78 | 79 | class HierarchicalPolicy(PyTorchModule, TorchStochasticPolicy): 80 | def __init__(self, policy1, policy2, policy1_obs_dim): 81 | super().__init__() 82 | 83 | self.policy1 = policy1 84 | self.policy2 = policy2 85 | 86 | self.policy1_obs_dim = policy1_obs_dim 87 | 88 | def forward(self, obs): 89 | assert obs.dim() == 2 90 | 91 | def distr2_cond_fn(inputs): 92 | 
obs_for_p = obs 93 | if inputs.dim() == 3: 94 | tile_dim = inputs.shape[1] 95 | obs_for_p = obs.unsqueeze(1).repeat((1, tile_dim, 1)) 96 | if isinstance(self.policy2, ParallelHybridPolicy): 97 | id = torch.argmax(inputs, dim=-1) 98 | return self.policy2(obs_for_p, id) 99 | else: 100 | return self.policy2(torch.cat([obs_for_p, inputs], dim=-1)) 101 | 102 | return HierarchicalDistribution( 103 | distr1=self.policy1(obs[:,-self.policy1_obs_dim:]), 104 | distr2_cond_fn=distr2_cond_fn, 105 | ) 106 | 107 | class CategoricalPolicy(Mlp, TorchStochasticPolicy): 108 | def __init__( 109 | self, 110 | hidden_sizes, 111 | obs_dim, 112 | action_dim, 113 | prefix='', 114 | init_w=1e-3, 115 | one_hot=False, 116 | **kwargs 117 | ): 118 | super().__init__( 119 | hidden_sizes, 120 | input_size=obs_dim, 121 | output_size=action_dim, 122 | init_w=init_w, 123 | **kwargs 124 | ) 125 | 126 | self.prefix = prefix 127 | self.one_hot = one_hot 128 | 129 | def forward(self, obs): 130 | h = obs 131 | for i, fc in enumerate(self.fcs): 132 | h = self.hidden_activation(fc(h)) 133 | logits = self.last_fc(h) 134 | logits = torch.clamp(logits, -LOGITS_SCALE, LOGITS_SCALE) 135 | return Softmax(logits, one_hot=self.one_hot, prefix=self.prefix) 136 | 137 | 138 | class ParallelHybridPolicy(PyTorchModule, TorchStochasticPolicy): 139 | """ 140 | Usage: 141 | 142 | ``` 143 | policy = ParallelHybridPolicy(...) 144 | """ 145 | 146 | def __init__( 147 | self, 148 | hidden_sizes, 149 | obs_dim, 150 | num_networks, 151 | action_dim_c, 152 | prefix_c='', 153 | init_w=1e-3, 154 | ): 155 | super().__init__() 156 | 157 | self.log_std = None 158 | self.num_networks = num_networks 159 | 160 | mlp_list = [] 161 | output_size = 2*action_dim_c 162 | for i in range(num_networks): 163 | mlp = Mlp( 164 | hidden_sizes, 165 | input_size=obs_dim, 166 | output_size=output_size, 167 | init_w=init_w, 168 | ) 169 | mlp_list.append(mlp) 170 | self.mlp_list = nn.ModuleList(mlp_list) 171 | 172 | self.action_dim_c = action_dim_c 173 | self.prefix_c = prefix_c 174 | 175 | def forward(self, obs, id): 176 | if torch.numel(id) == 1: 177 | mean, std = self.get_mean_std(obs, id) 178 | else: 179 | input_dims = id.shape 180 | obs = obs.reshape((-1, obs.shape[-1])) 181 | id = id.reshape(-1) 182 | 183 | means = [] 184 | stds = [] 185 | for i in range(self.num_networks): 186 | mean, std = self.get_mean_std(obs, i) 187 | means.append(mean) 188 | stds.append(std) 189 | 190 | means = torch.stack(means, dim=1) 191 | stds = torch.stack(stds, dim=1) 192 | 193 | mean = means[torch.arange(obs.size(0)), id] 194 | std = stds[torch.arange(obs.size(0)), id] 195 | 196 | ### reshape to original input dims 197 | mean = mean.reshape((*input_dims, -1)) 198 | std = std.reshape((*input_dims, -1)) 199 | 200 | distr_c = TanhNormal(mean, std, prefix=self.prefix_c) 201 | return distr_c 202 | 203 | def get_mean_std(self, obs, id): 204 | c_dim = self.action_dim_c 205 | nn_output = self.mlp_list[id](obs) 206 | mean = nn_output[..., -2 * c_dim:-c_dim] 207 | log_std = nn_output[..., -c_dim:] 208 | 209 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 210 | std = torch.exp(log_std) 211 | 212 | return mean, std -------------------------------------------------------------------------------- /maple/torch/sac/policies/policy_from_q.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import 
maple.torch.pytorch_util as ptu 10 | from maple.policies.base import ExplorationPolicy 11 | from maple.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from maple.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from maple.torch.networks import Mlp, CNN 16 | from maple.torch.networks.basic import MultiInputSequential 17 | from maple.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | from maple.torch.sac.policies.base import ( 21 | TorchStochasticPolicy, 22 | PolicyFromDistributionGenerator, 23 | MakeDeterministic, 24 | ) 25 | 26 | 27 | class PolicyFromQ(TorchStochasticPolicy): 28 | def __init__( 29 | self, 30 | qf, 31 | policy, 32 | num_samples=10, 33 | **kwargs 34 | ): 35 | super().__init__() 36 | self.qf = qf 37 | self.policy = policy 38 | self.num_samples = num_samples 39 | 40 | def forward(self, obs): 41 | with torch.no_grad(): 42 | state = obs.repeat(self.num_samples, 1) 43 | action = self.policy(state).sample() 44 | q_values = self.qf(state, action) 45 | ind = q_values.max(0)[1] 46 | return Delta(action[ind]) 47 | -------------------------------------------------------------------------------- /maple/torch/torch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | from typing import Iterable 5 | from torch import nn as nn 6 | 7 | from maple.core.batch_rl_algorithm import BatchRLAlgorithm 8 | from maple.core.online_rl_algorithm import OnlineRLAlgorithm 9 | from maple.core.trainer import Trainer 10 | from maple.torch.core import np_to_pytorch_batch 11 | 12 | 13 | class TorchOnlineRLAlgorithm(OnlineRLAlgorithm): 14 | def to(self, device): 15 | for net in self.trainer.networks: 16 | net.to(device) 17 | 18 | def training_mode(self, mode): 19 | for net in self.trainer.networks: 20 | net.train(mode) 21 | 22 | 23 | class TorchBatchRLAlgorithm(BatchRLAlgorithm): 24 | def to(self, device): 25 | for net in self.trainer.networks: 26 | net.to(device) 27 | 28 | def training_mode(self, mode): 29 | for net in self.trainer.networks: 30 | net.train(mode) 31 | 32 | 33 | class TorchTrainer(Trainer, metaclass=abc.ABCMeta): 34 | def __init__(self): 35 | self._num_train_steps = 0 36 | 37 | def train(self, np_batch): 38 | self._num_train_steps += 1 39 | batch = np_to_pytorch_batch(np_batch) 40 | self.train_from_torch(batch) 41 | 42 | def get_diagnostics(self): 43 | return OrderedDict([ 44 | ('num train calls', self._num_train_steps), 45 | ]) 46 | 47 | @abc.abstractmethod 48 | def train_from_torch(self, batch): 49 | pass 50 | 51 | @property 52 | @abc.abstractmethod 53 | def networks(self) -> Iterable[nn.Module]: 54 | pass 55 | -------------------------------------------------------------------------------- /maple/util/data_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for writing and loading data. 
3 | """ 4 | import json 5 | import numpy as np 6 | import os 7 | import os.path as osp 8 | from collections import defaultdict, namedtuple 9 | 10 | from maple.pythonplusplus import nested_dict_to_dot_map_dict 11 | 12 | 13 | Trial = namedtuple("Trial", ["data", "variant", "directory"]) 14 | 15 | 16 | def matches_dict(criteria_dict, test_dict): 17 | for k, v in criteria_dict.items(): 18 | if k not in test_dict: 19 | return False 20 | else: 21 | if test_dict[k] != v: 22 | return False 23 | return True 24 | 25 | 26 | class Experiment(object): 27 | """ 28 | Represents an experiment, which consists of many Trials. 29 | """ 30 | def __init__(self, base_dir, criteria=None): 31 | """ 32 | :param base_dir: A path. Directory structure should be something like: 33 | ``` 34 | base_dir/ 35 | foo/ 36 | bar/ 37 | arbtrarily_deep/ 38 | trial_one/ 39 | variant.json 40 | progress.csv 41 | trial_two/ 42 | variant.json 43 | progress.csv 44 | trial_three/ 45 | variant.json 46 | progress.csv 47 | ... 48 | variant.json # <-- base_dir/foo/bar has its own Trial 49 | progress.csv 50 | variant.json # <-- base_dir/foo has its own Trial 51 | progress.csv 52 | variant.json # <-- base_dir has its own Trial 53 | progress.csv 54 | ``` 55 | 56 | The important thing is that `variant.json` and `progress.csv` are 57 | in the same sub-directory for each Trial. 58 | :param criteria: A dictionary of allowable values for the given keys. 59 | """ 60 | if criteria is None: 61 | criteria = {} 62 | self.trials = get_trials(base_dir, criteria=criteria) 63 | assert len(self.trials) > 0, "Nothing loaded." 64 | self.label = 'AverageReturn' 65 | 66 | def get_trials(self, criteria=None): 67 | """ 68 | Return a list of Trials that match a criteria. 69 | :param criteria: A dictionary from key to value that must be matches 70 | in the trial's variant. e.g. 71 | ``` 72 | >>> print(exp.trials) 73 | [ 74 | (X, {'a': True, ...}) 75 | (Y, {'a': False, ...}) 76 | (Z, {'a': True, ...}) 77 | ] 78 | >>> print(exp.get_trials({'a': True})) 79 | [ 80 | (X, {'a': True, ...}) 81 | (Z, {'a': True, ...}) 82 | ] 83 | ``` 84 | If a trial does not have the key, the trial is filtered out. 85 | :return: 86 | """ 87 | if criteria is None: 88 | criteria = {} 89 | return [trial for trial in self.trials 90 | if matches_dict(criteria, trial.variant)] 91 | 92 | 93 | def get_dirs(root): 94 | """ 95 | Get a list of all the directories under this directory. 96 | """ 97 | yield root 98 | for root, directories, filenames in os.walk(root): 99 | for directory in directories: 100 | yield os.path.join(root, directory) 101 | 102 | 103 | def get_trials(base_dir, verbose=False, criteria=None, excluded_seeds=()): 104 | """ 105 | Get a list of (data, variant, directory) tuples, loaded from 106 | - process.csv 107 | - variant.json 108 | files under this directory. 109 | :param base_dir: root directory 110 | :param criteria: dictionary of keys and values. Only load experiemnts 111 | that match this criteria. 112 | :return: List of tuples. Each tuple has: 113 | 1. Progress data (nd.array) 114 | 2. 
Variant dictionary 115 | """ 116 | if criteria is None: 117 | criteria = {} 118 | 119 | trials = [] 120 | # delimiter = ',' 121 | delimiter = ',' 122 | for dir_name in get_dirs(base_dir): 123 | variant_file_name = osp.join(dir_name, 'variant.json') 124 | if not os.path.exists(variant_file_name): 125 | continue 126 | 127 | with open(variant_file_name) as variant_file: 128 | variant = json.load(variant_file) 129 | variant = nested_dict_to_dot_map_dict(variant) 130 | 131 | if 'seed' in variant and int(variant['seed']) in excluded_seeds: 132 | continue 133 | 134 | if not matches_dict(criteria, variant): 135 | continue 136 | 137 | data_file_name = osp.join(dir_name, 'progress.csv') 138 | # Hack for iclr 2018 deadline 139 | if not os.path.exists(data_file_name) or os.stat( 140 | data_file_name).st_size == 0: 141 | data_file_name = osp.join(dir_name, 'log.txt') 142 | if not os.path.exists(data_file_name): 143 | continue 144 | delimiter = '\t' 145 | if verbose: 146 | print("Reading {}".format(data_file_name)) 147 | num_lines = sum(1 for _ in open(data_file_name)) 148 | if num_lines < 2: 149 | continue 150 | # print(delimiter) 151 | data = np.genfromtxt( 152 | data_file_name, 153 | delimiter=delimiter, 154 | dtype=None, 155 | names=True, 156 | ) 157 | trials.append(Trial(data, variant, dir_name)) 158 | return trials 159 | 160 | 161 | def get_all_csv(base_dir, verbose=False): 162 | """ 163 | Get a list of all csv data under a directory. 164 | :param base_dir: root directory 165 | """ 166 | data = [] 167 | delimiter = ',' 168 | for dir_name in get_dirs(base_dir): 169 | for data_file_name in os.listdir(dir_name): 170 | if data_file_name.endswith(".csv"): 171 | full_path = os.path.join(dir_name, data_file_name) 172 | if verbose: 173 | print("Reading {}".format(full_path)) 174 | data.append(np.genfromtxt( 175 | full_path, delimiter=delimiter, dtype=None, names=True 176 | )) 177 | return data 178 | 179 | 180 | def get_unique_param_to_values(all_variants): 181 | variant_key_to_values = defaultdict(set) 182 | for variant in all_variants: 183 | for k, v in variant.items(): 184 | if type(v) == list: 185 | v = str(v) 186 | variant_key_to_values[k].add(v) 187 | unique_key_to_values = { 188 | k: variant_key_to_values[k] 189 | for k in variant_key_to_values 190 | if len(variant_key_to_values[k]) > 1 191 | } 192 | return unique_key_to_values -------------------------------------------------------------------------------- /maple/util/hyperparameter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom hyperparameter functions. 
3 | """ 4 | import abc 5 | import copy 6 | import math 7 | import random 8 | import itertools 9 | from typing import List 10 | 11 | import maple.pythonplusplus as ppp 12 | 13 | 14 | class Hyperparameter(metaclass=abc.ABCMeta): 15 | def __init__(self, name): 16 | self._name = name 17 | 18 | @property 19 | def name(self): 20 | return self._name 21 | 22 | 23 | class RandomHyperparameter(Hyperparameter): 24 | def __init__(self, name): 25 | super().__init__(name) 26 | self._last_value = None 27 | 28 | @abc.abstractmethod 29 | def generate_next_value(self): 30 | """Return a value for the hyperparameter""" 31 | return 32 | 33 | def generate(self): 34 | self._last_value = self.generate_next_value() 35 | return self._last_value 36 | 37 | 38 | class EnumParam(RandomHyperparameter): 39 | def __init__(self, name, possible_values): 40 | super().__init__(name) 41 | self.possible_values = possible_values 42 | 43 | def generate_next_value(self): 44 | return random.choice(self.possible_values) 45 | 46 | 47 | class LogFloatParam(RandomHyperparameter): 48 | """ 49 | Return something ranging from [min_value + offset, max_value + offset], 50 | distributed with a log. 51 | """ 52 | def __init__(self, name, min_value, max_value, *, offset=0): 53 | super(LogFloatParam, self).__init__(name) 54 | self._linear_float_param = LinearFloatParam("log_" + name, 55 | math.log(min_value), 56 | math.log(max_value)) 57 | self.offset = offset 58 | 59 | def generate_next_value(self): 60 | return math.e ** (self._linear_float_param.generate()) + self.offset 61 | 62 | 63 | class LinearFloatParam(RandomHyperparameter): 64 | def __init__(self, name, min_value, max_value): 65 | super(LinearFloatParam, self).__init__(name) 66 | self._min = min_value 67 | self._delta = max_value - min_value 68 | 69 | def generate_next_value(self): 70 | return random.random() * self._delta + self._min 71 | 72 | 73 | class LogIntParam(RandomHyperparameter): 74 | def __init__(self, name, min_value, max_value, *, offset=0): 75 | super().__init__(name) 76 | self._linear_float_param = LinearFloatParam("log_" + name, 77 | math.log(min_value), 78 | math.log(max_value)) 79 | self.offset = offset 80 | 81 | def generate_next_value(self): 82 | return int( 83 | math.e ** (self._linear_float_param.generate()) + self.offset 84 | ) 85 | 86 | 87 | class LinearIntParam(RandomHyperparameter): 88 | def __init__(self, name, min_value, max_value): 89 | super(LinearIntParam, self).__init__(name) 90 | self._min = min_value 91 | self._max = max_value 92 | 93 | def generate_next_value(self): 94 | return random.randint(self._min, self._max) 95 | 96 | 97 | class FixedParam(RandomHyperparameter): 98 | def __init__(self, name, value): 99 | super().__init__(name) 100 | self._value = value 101 | 102 | def generate_next_value(self): 103 | return self._value 104 | 105 | 106 | class Sweeper(object): 107 | pass 108 | 109 | 110 | class RandomHyperparameterSweeper(Sweeper): 111 | def __init__(self, hyperparameters=None, default_kwargs=None): 112 | if default_kwargs is None: 113 | default_kwargs = {} 114 | self._hyperparameters = hyperparameters or [] 115 | self._validate_hyperparameters() 116 | self._default_kwargs = default_kwargs 117 | 118 | def _validate_hyperparameters(self): 119 | names = set() 120 | for hp in self._hyperparameters: 121 | name = hp.name 122 | if name in names: 123 | raise Exception("Hyperparameter '{0}' already added.".format( 124 | name)) 125 | names.add(name) 126 | 127 | def set_default_parameters(self, default_kwargs): 128 | self._default_kwargs = default_kwargs 
129 | 130 | def generate_random_hyperparameters(self): 131 | hyperparameters = {} 132 | for hp in self._hyperparameters: 133 | hyperparameters[hp.name] = hp.generate() 134 | hyperparameters = ppp.dot_map_dict_to_nested_dict(hyperparameters) 135 | return ppp.merge_recursive_dicts( 136 | hyperparameters, 137 | copy.deepcopy(self._default_kwargs), 138 | ignore_duplicate_keys_in_second_dict=True, 139 | ) 140 | 141 | def sweep_hyperparameters(self, function, num_configs): 142 | returned_value_and_params = [] 143 | for _ in range(num_configs): 144 | kwargs = self.generate_random_hyperparameters() 145 | score = function(**kwargs) 146 | returned_value_and_params.append((score, kwargs)) 147 | 148 | return returned_value_and_params 149 | 150 | 151 | class DeterministicHyperparameterSweeper(Sweeper): 152 | """ 153 | Do a grid search over hyperparameters based on a predefined set of 154 | hyperparameters. 155 | """ 156 | def __init__(self, hyperparameters, default_parameters=None): 157 | """ 158 | 159 | :param hyperparameters: A dictionary of the form 160 | ``` 161 | { 162 | 'hp_1': [value1, value2, value3], 163 | 'hp_2': [value1, value2, value3], 164 | ... 165 | } 166 | ``` 167 | This format is like the param_grid in scikit-learn: 168 | http://scikit-learn.org/stable/modules/grid_search.html#exhaustive-grid-search 169 | :param default_parameters: Default key-value pairs to add to the 170 | dictionary. 171 | """ 172 | self._hyperparameters = hyperparameters 173 | self._default_kwargs = default_parameters or {} 174 | named_hyperparameters = [] 175 | for name, values in self._hyperparameters.items(): 176 | named_hyperparameters.append( 177 | [(name, v) for v in values] 178 | ) 179 | self._hyperparameters_dicts = [ 180 | ppp.dot_map_dict_to_nested_dict(dict(tuple_list)) 181 | for tuple_list in itertools.product(*named_hyperparameters) 182 | ] 183 | 184 | def iterate_hyperparameters(self): 185 | """ 186 | Iterate over the hyperparameters in a grid manner. 187 | 188 | :return: List of dictionaries. Each dictionary is a map from name to 189 | hyperparameter value. 190 | """ 191 | return [ 192 | ppp.merge_recursive_dicts( 193 | hyperparameters, 194 | copy.deepcopy(self._default_kwargs), 195 | ignore_duplicate_keys_in_second_dict=True, 196 | ) 197 | for hyperparameters in self._hyperparameters_dicts 198 | ] 199 | 200 | 201 | # TODO(vpong): Test this 202 | class DeterministicSweeperCombiner(object): 203 | """ 204 | A simple wrapper to combine multiple DeterministicHyperparameterSweeper instances. 205 | """ 206 | def __init__(self, sweepers: List[DeterministicHyperparameterSweeper]): 207 | self._sweepers = sweepers 208 | 209 | def iterate_list_of_hyperparameters(self): 210 | """ 211 | Usage: 212 | 213 | ``` 214 | sweeper1 = DeterministicHyperparameterSweeper(...) 215 | sweeper2 = DeterministicHyperparameterSweeper(...) 216 | combiner = DeterministicSweeperCombiner([sweeper1, sweeper2]) 217 | 218 | for params_1, params_2 in combiner.iterate_list_of_hyperparameters(): 219 | # params_1 = {...} 220 | # params_2 = {...} 221 | ``` 222 | :return: Iterator over tuples of hyperparameter dicts, one per sweeper, 223 | in the same order as the provided sweepers. 224 | """ 225 | return itertools.product(*( 226 | sweeper.iterate_hyperparameters() 227 | for sweeper in self._sweepers 228 | ))
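The grid sweeper above is the one driven by `scripts/eval.py` via `hyp.DeterministicHyperparameterSweeper`. A minimal illustrative sketch of how a dot-map search space expands (not part of the repository; the parameter names and the `run_trial` callable are hypothetical):

```
import maple.util.hyperparameter as hyp

search_space = {
    'trainer_kwargs.lr': [1e-4, 3e-4],
    'seed': [0, 1, 2],
}
sweeper = hyp.DeterministicHyperparameterSweeper(
    search_space,
    default_parameters={'algorithm_kwargs': {'num_epochs': 100}},
)
# iterate_hyperparameters() returns 2 x 3 = 6 variants; dot-map keys are
# expanded into nested dicts and merged with the defaults, e.g.
# {'trainer_kwargs': {'lr': 0.0001}, 'seed': 0,
#  'algorithm_kwargs': {'num_epochs': 100}}
for variant in sweeper.iterate_hyperparameters():
    run_trial(**variant)
```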
224 | """ 225 | return itertools.product( 226 | sweeper.iterate_hyperparameters() 227 | for sweeper in self._sweepers 228 | ) -------------------------------------------------------------------------------- /maple/util/io.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | import pickle 4 | 5 | import boto3 6 | 7 | from maple.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH 8 | import os 9 | 10 | PICKLE = 'pickle' 11 | NUMPY = 'numpy' 12 | JOBLIB = 'joblib' 13 | 14 | 15 | def local_path_from_s3_or_local_path(filename): 16 | relative_filename = os.path.join(LOCAL_LOG_DIR, filename) 17 | if os.path.isfile(filename): 18 | return filename 19 | elif os.path.isfile(relative_filename): 20 | return relative_filename 21 | else: 22 | return sync_down(filename) 23 | 24 | 25 | def sync_down(path, check_exists=True): 26 | is_docker = os.path.isfile("/.dockerenv") 27 | if is_docker: 28 | local_path = "/tmp/%s" % (path) 29 | else: 30 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path) 31 | 32 | if check_exists and os.path.isfile(local_path): 33 | return local_path 34 | 35 | local_dir = os.path.dirname(local_path) 36 | os.makedirs(local_dir, exist_ok=True) 37 | 38 | if is_docker: 39 | from doodad.ec2.autoconfig import AUTOCONFIG 40 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key() 41 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret() 42 | 43 | full_s3_path = os.path.join(AWS_S3_PATH, path) 44 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path) 45 | try: 46 | bucket = boto3.resource('s3').Bucket(bucket_name) 47 | bucket.download_file(bucket_relative_path, local_path) 48 | except Exception as e: 49 | local_path = None 50 | print("Failed to sync! 
path: ", path) 51 | print("Exception: ", e) 52 | return local_path 53 | 54 | 55 | def sync_down_folder(path): 56 | is_docker = os.path.isfile("/.dockerenv") 57 | if is_docker: 58 | local_path = "/tmp/%s" % (path) 59 | else: 60 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path) 61 | 62 | local_dir = os.path.dirname(local_path) 63 | os.makedirs(local_dir, exist_ok=True) 64 | 65 | if is_docker: 66 | from doodad.ec2.autoconfig import AUTOCONFIG 67 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key() 68 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret() 69 | 70 | full_s3_path = os.path.join(AWS_S3_PATH, path) 71 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path) 72 | command = "aws s3 sync s3://%s/%s %s" % (bucket_name, bucket_relative_path, local_path) 73 | print(command) 74 | stream = os.popen(command) 75 | output = stream.read() 76 | print(output) 77 | return local_path 78 | 79 | 80 | def split_s3_full_path(s3_path): 81 | """ 82 | Split "s3://foo/bar/baz" into "foo" and "bar/baz" 83 | """ 84 | bucket_name_and_directories = s3_path.split('//')[1] 85 | bucket_name, *directories = bucket_name_and_directories.split('/') 86 | directory_path = '/'.join(directories) 87 | return bucket_name, directory_path 88 | 89 | 90 | class CPU_Unpickler(pickle.Unpickler): 91 | """Utility for loading a pickled model on CPU machine saved from a GPU""" 92 | def find_class(self, module, name): 93 | if module == 'torch.storage' and name == '_load_from_bytes': 94 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 95 | else: return super().find_class(module, name) 96 | 97 | 98 | def load_local_or_remote_file(filepath, file_type=None): 99 | local_path = local_path_from_s3_or_local_path(filepath) 100 | if file_type is None: 101 | extension = local_path.split('.')[-1] 102 | if extension == 'npy': 103 | file_type = NUMPY 104 | else: 105 | file_type = PICKLE 106 | else: 107 | file_type = PICKLE 108 | if file_type == NUMPY: 109 | object = np.load(open(local_path, "rb"), allow_pickle=True) 110 | elif file_type == JOBLIB: 111 | object = joblib.load(local_path) 112 | else: 113 | #f = open(local_path, 'rb') 114 | #object = CPU_Unpickler(f).load() 115 | object = pickle.load(open(local_path, "rb")) 116 | print("loaded", local_path) 117 | return object 118 | 119 | 120 | def get_absolute_path(path): 121 | if path[0] == "/": 122 | return path 123 | else: 124 | is_docker = os.path.isfile("/.dockerenv") 125 | if is_docker: 126 | local_path = "/tmp/%s" % (path) 127 | else: 128 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path) 129 | return local_path 130 | 131 | 132 | if __name__ == "__main__": 133 | p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl") 134 | print("got", p) 135 | -------------------------------------------------------------------------------- /maple/util/ml_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility functions for machine learning. 3 | """ 4 | import abc 5 | import math 6 | import numpy as np 7 | 8 | 9 | class ScalarSchedule(object, metaclass=abc.ABCMeta): 10 | @abc.abstractmethod 11 | def get_value(self, t): 12 | pass 13 | 14 | 15 | class ConstantSchedule(ScalarSchedule): 16 | def __init__(self, value): 17 | self._value = value 18 | 19 | def get_value(self, t): 20 | return self._value 21 | 22 | 23 | class LinearSchedule(ScalarSchedule): 24 | """ 25 | Linearly interpolate and then stop at a final value. 
26 | """ 27 | def __init__( 28 | self, 29 | init_value, 30 | final_value, 31 | ramp_duration, 32 | ): 33 | self._init_value = init_value 34 | self._final_value = final_value 35 | self._ramp_duration = ramp_duration 36 | 37 | def get_value(self, t): 38 | return ( 39 | self._init_value 40 | + (self._final_value - self._init_value) 41 | * min(1.0, t * 1.0 / self._ramp_duration) 42 | ) 43 | 44 | 45 | class IntLinearSchedule(LinearSchedule): 46 | """ 47 | Same as RampUpSchedule but round output to an int 48 | """ 49 | def get_value(self, t): 50 | return int(super().get_value(t)) 51 | 52 | 53 | class PiecewiseLinearSchedule(ScalarSchedule): 54 | """ 55 | Given a list of (x, t) value-time pairs, return value x at time t, 56 | and linearly interpolate between the two 57 | """ 58 | def __init__( 59 | self, 60 | x_values, 61 | y_values, 62 | ): 63 | self._x_values = x_values 64 | self._y_values = y_values 65 | 66 | def get_value(self, t): 67 | return np.interp(t, self._x_values, self._y_values) 68 | 69 | 70 | class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule): 71 | def get_value(self, t): 72 | return int(super().get_value(t)) 73 | 74 | 75 | def none_to_infty(bounds): 76 | if bounds is None: 77 | bounds = -math.inf, math.inf 78 | lb, ub = bounds 79 | if lb is None: 80 | lb = -math.inf 81 | if ub is None: 82 | ub = math.inf 83 | return lb, ub 84 | -------------------------------------------------------------------------------- /maple/util/slurm_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pathlib 4 | 5 | def create_sbatch_script(args, use_variants=True): 6 | # Create a new directory path if it doesn't exist and create a new filename that we will write to 7 | exp_dir = args.exp_dir 8 | sbatch_dir = os.path.join(exp_dir, "sbatch") 9 | new_sbatch_fpath = os.path.join(sbatch_dir, "{}.sbatch".format(args.job_name)) 10 | if not os.path.isdir(sbatch_dir): 11 | os.mkdir(sbatch_dir) 12 | 13 | if use_variants: 14 | base_variant = os.path.join(exp_dir, "variants", args.env, "base.json") 15 | variant_update = os.path.join(exp_dir, "variants", args.env, "{}.json".format(args.config)) 16 | assert os.path.exists(base_variant) and os.path.exists(variant_update) 17 | 18 | command = "" 19 | for i in range(args.num_seeds): 20 | # Compose main command to be executed in script 21 | python_script = args.python_script 22 | line_command = "{{\nsleep {}\n".format(30*i) 23 | line_command += "python {python_script} --env {env} --label {label}".format( 24 | python_script=python_script, 25 | env=args.env, 26 | label=args.label, 27 | ) 28 | if use_variants: 29 | line_command = "{line_command} --base_variant {base_variant} --variant_update {variant_update}".format( 30 | line_command=line_command, 31 | base_variant=base_variant, 32 | variant_update=variant_update, 33 | ) 34 | if args.no_gpu: 35 | line_command += " --no_gpu" 36 | 37 | command += "{line_command}\n}} & \n".format(line_command=line_command) 38 | command += "wait" 39 | 40 | if args.partition in ["titans", "dgx"]: 41 | log_dir = "/scratch/cluster/soroush/logs" 42 | elif args.partition in ["svl", "tibet", "napoli-gpu"]: 43 | log_dir = "/cvgl2/u/soroush/logs" 44 | else: 45 | raise ValueError 46 | 47 | if args.exclude is None: 48 | if args.partition == "titans": 49 | args.exclude = "titan-5,titan-12" 50 | else: 51 | args.exclude = "" 52 | 53 | # Define a dict to map expected fill-ins with replacement values 54 | fill_ins = { 55 | "{{PARTITION}}": args.partition, 56 | 
"{{EXCLUDE}}": args.exclude, 57 | "{{NUM_GPU}}": 0 if args.no_gpu else 1, 58 | "{{NUM_CPU}}": args.num_seeds, 59 | "{{MEM}}": args.mem * args.num_seeds, 60 | "{{JOB_NAME}}": args.job_name, 61 | "{{LOG_DIR}}": log_dir, 62 | "{{HOURS}}": args.max_hours, 63 | "{{CMD}}": command, 64 | "{{CONDA_ENV}}": args.conda_env, 65 | } 66 | 67 | # Open the template file 68 | with open(args.slurm_template) as template: 69 | # Open the new sbatch file 70 | print(new_sbatch_fpath) 71 | with open(new_sbatch_fpath, 'w+') as new_file: 72 | # Loop through template and write to this new file 73 | for line in template: 74 | wrote = False 75 | # Check for various cases 76 | for k, v in fill_ins.items(): 77 | # If the key is found in the line, replace it with its value and pop it from the dict 78 | if k in line: 79 | new_file.write(line.replace(k, str(v))) 80 | wrote = True 81 | break 82 | # Otherwise, we just write the line from the template directly 83 | if not wrote: 84 | new_file.write(line) 85 | 86 | # Execute this file! 87 | # TODO: Fix! (Permission denied error) 88 | #os.system(new_sbatch_fpath) 89 | 90 | if __name__ == "__main__": 91 | # noinspection PyTypeChecker 92 | parser = argparse.ArgumentParser() 93 | 94 | parser.add_argument('--env', type=str, default='can') 95 | parser.add_argument('--config', type=str, default='base') 96 | parser.add_argument('--label', type=str, default=None) 97 | parser.add_argument('--job_name', type=str, default=None) 98 | 99 | parser.add_argument('--num_seeds', type=int, default=4) 100 | parser.add_argument('--no_video', action='store_true') 101 | parser.add_argument('--no_gpu', action='store_true') 102 | 103 | parser.add_argument('--mem', type=int, default=9) 104 | parser.add_argument('--max_hours', type=int, default=504) #168 105 | parser.add_argument('--partition', type=str, default="titans") 106 | parser.add_argument('--exclude', type=str, default=None) 107 | 108 | args = parser.parse_args() 109 | 110 | if args.label is None: 111 | args.label = args.config 112 | 113 | if args.job_name is None: 114 | args.job_name = "rl_{}_{}".format(args.env, args.label) 115 | 116 | if args.exclude is None: 117 | if args.partition == "titans": 118 | args.exclude = "titan-5,titan-12" 119 | else: 120 | args.exclude = "" 121 | 122 | create_sbatch_script(args) 123 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | from maple.launchers.launcher_util import run_experiment 2 | from maple.launchers.robosuite_launcher import experiment 3 | import maple.util.hyperparameter as hyp 4 | import os.path as osp 5 | import argparse 6 | import json 7 | import collections 8 | import copy 9 | 10 | from maple.launchers.conf import LOCAL_LOG_DIR 11 | 12 | base_variant = dict( 13 | algorithm_kwargs=dict( 14 | eval_only=True, 15 | num_epochs=5000, 16 | eval_epoch_freq=100, 17 | ), 18 | replay_buffer_size=int(1E2), 19 | vis_expl=False, 20 | dump_video_kwargs=dict( 21 | rows=1, 22 | columns=6, 23 | pad_length=5, 24 | pad_color=0, 25 | ), 26 | num_eval_rollouts=50, 27 | 28 | # ckpt_epoch=100, #### uncomment if you want to evaluate a specific epoch ckeckpoint only ### 29 | ) 30 | 31 | env_params = dict( 32 | lift={ 33 | 'ckpt_path': [ 34 | ### Add paths here ### 35 | ], 36 | }, 37 | door={ 38 | 'ckpt_path': [ 39 | ### Add paths here ### 40 | ], 41 | }, 42 | pnp={ 43 | 'ckpt_path': [ 44 | ### Add paths here ### 45 | ], 46 | }, 47 | wipe={ 48 | 'ckpt_path': [ 49 | ### Add paths here ### 
50 | ], 51 | }, 52 | stack={ 53 | 'ckpt_path': [ 54 | ### Add paths here ### 55 | ], 56 | }, 57 | nut_round={ 58 | 'ckpt_path': [ 59 | ### Add paths here ### 60 | ], 61 | }, 62 | cleanup={ 63 | 'ckpt_path': [ 64 | ### Add paths here ### 65 | ], 66 | }, 67 | peg_ins={ 68 | 'ckpt_path': [ 69 | ### Add paths here ### 70 | ], 71 | }, 72 | ) 73 | 74 | def process_variant(eval_variant): 75 | ckpt_path = eval_variant['ckpt_path'] 76 | json_path = osp.join(LOCAL_LOG_DIR, ckpt_path, 'variant.json') 77 | with open(json_path) as f: 78 | ckpt_variant = json.load(f) 79 | deep_update(ckpt_variant, eval_variant) 80 | variant = copy.deepcopy(ckpt_variant) 81 | 82 | if args.debug: 83 | mpl = variant['algorithm_kwargs']['max_path_length'] 84 | variant['algorithm_kwargs']['num_eval_steps_per_epoch'] = mpl * 3 85 | variant['dump_video_kwargs']['rows'] = 1 86 | variant['dump_video_kwargs']['columns'] = 2 87 | else: 88 | mpl = variant['algorithm_kwargs']['max_path_length'] 89 | variant['algorithm_kwargs']['num_eval_steps_per_epoch'] = mpl * variant['num_eval_rollouts'] 90 | 91 | variant['save_video_period'] = variant['algorithm_kwargs']['eval_epoch_freq'] 92 | 93 | if args.no_video: 94 | variant['save_video'] = False 95 | 96 | variant['exp_label'] = args.label 97 | return variant 98 | 99 | def deep_update(source, overrides): 100 | ''' 101 | Update a nested dictionary or similar mapping. 102 | Modify ``source`` in place. 103 | Copied from: https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth 104 | ''' 105 | for key, value in overrides.items(): 106 | if isinstance(value, collections.Mapping) and value: 107 | returned = deep_update(source.get(key, {}), value) 108 | source[key] = returned 109 | else: 110 | source[key] = overrides[key] 111 | return source 112 | 113 | if __name__ == "__main__": 114 | # noinspection PyTypeChecker 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--env', type=str) 117 | parser.add_argument('--label', type=str, default='test') 118 | parser.add_argument('--num_seeds', type=int, default=1) 119 | parser.add_argument('--no_video', action='store_true') 120 | parser.add_argument('--no_gpu', action='store_true') 121 | parser.add_argument('--gpu_id', type=int, default=0) 122 | parser.add_argument('--debug', action='store_true') 123 | parser.add_argument('--first_variant', action='store_true') 124 | args = parser.parse_args() 125 | 126 | search_space = env_params[args.env] 127 | sweeper = hyp.DeterministicHyperparameterSweeper( 128 | search_space, default_parameters=base_variant, 129 | ) 130 | for exp_id, eval_variant in enumerate(sweeper.iterate_hyperparameters()): 131 | variant = process_variant(eval_variant) 132 | 133 | run_experiment( 134 | experiment, 135 | exp_folder=args.env, 136 | exp_prefix=args.label, 137 | variant=variant, 138 | snapshot_mode='gap_and_last', 139 | snapshot_gap=200, 140 | exp_id=exp_id, 141 | use_gpu=(not args.no_gpu), 142 | gpu_id=args.gpu_id, 143 | ) 144 | 145 | if args.first_variant: 146 | exit() -------------------------------------------------------------------------------- /scripts/run_experiment_from_doodad.py: -------------------------------------------------------------------------------- 1 | import doodad as dd 2 | # import torch.multiprocessing as mp 3 | 4 | from maple.launchers.launcher_util import run_experiment_here 5 | 6 | if __name__ == "__main__": 7 | import matplotlib 8 | matplotlib.use('agg') 9 | 10 | # mp.set_start_method('forkserver') 11 | args_dict = dd.get_args() 12 | method_call = 
args_dict['method_call'] 13 | run_experiment_kwargs = args_dict['run_experiment_kwargs'] 14 | output_dir = args_dict['output_dir'] 15 | run_mode = args_dict.get('mode', None) 16 | if run_mode and run_mode in ['slurm_singularity', 'sss']: 17 | import os 18 | run_experiment_kwargs['variant']['slurm-job-id'] = os.environ.get( 19 | 'SLURM_JOB_ID', None 20 | ) 21 | if run_mode and (run_mode == 'ec2' or run_mode == 'gcp'): 22 | if run_mode == 'ec2': 23 | try: 24 | import urllib.request 25 | instance_id = urllib.request.urlopen( 26 | 'http://169.254.169.254/latest/meta-data/instance-id' 27 | ).read().decode() 28 | run_experiment_kwargs['variant']['EC2_instance_id'] = instance_id 29 | except Exception as e: 30 | print("Could not get AWS instance ID. Error was...") 31 | print(e) 32 | if run_mode == 'gcp': 33 | try: 34 | import urllib.request 35 | request = urllib.request.Request( 36 | "http://metadata/computeMetadata/v1/instance/name", 37 | ) 38 | # See this URL for why we need this header: 39 | # https://cloud.google.com/compute/docs/storing-retrieving-metadata 40 | request.add_header("Metadata-Flavor", "Google") 41 | instance_name = urllib.request.urlopen(request).read().decode() 42 | run_experiment_kwargs['variant']['GCP_instance_name'] = ( 43 | instance_name 44 | ) 45 | except Exception as e: 46 | print("Could not get GCP instance name. Error was...") 47 | print(e) 48 | # Do this in case base_log_dir was already set 49 | run_experiment_kwargs['base_log_dir'] = output_dir 50 | run_experiment_here( 51 | method_call, 52 | include_exp_prefix_sub_dir=False, 53 | **run_experiment_kwargs 54 | ) 55 | else: 56 | run_experiment_here( 57 | method_call, 58 | log_dir=output_dir, 59 | **run_experiment_kwargs 60 | ) 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='maple', 6 | version='0.2.1dev', 7 | packages=find_packages(), 8 | license='MIT License', 9 | long_description=open('README.md').read(), 10 | ) 11 | --------------------------------------------------------------------------------