├── .gitignore ├── LICENSE ├── README.md ├── docs ├── HER.md ├── ReNN.md ├── goal_based_envs.md └── images │ ├── FetchReach-v1_HER-TD3.png │ ├── SawyerReachXYZEnv-v0_HER-TD3.png │ └── stack6_29.gif ├── environment ├── docker │ ├── Dockerfile │ └── vendor │ │ ├── 10_nvidia.json │ │ ├── Xdummy │ │ └── Xdummy-entrypoint ├── linux-cpu-env.yml ├── linux-gpu-env.yml └── mac-env.yml ├── examples ├── ddpg.py ├── doodad │ ├── ec2_example.py │ └── gcp_example.py ├── dqn_and_double_dqn.py ├── her │ ├── her_td3_gym_fetch_reach.py │ └── her_td3_multiworld_sawyer_reach.py ├── relationalrl │ ├── trace.html │ ├── train_pickandplace1.py │ ├── train_sequentialtransfer.py │ ├── train_uniform.py │ └── train_uniformtransfer.py ├── sac.py ├── td3.py └── tsac.py ├── requirements.txt ├── rlkit ├── __init__.py ├── core │ ├── __init__.py │ ├── eval_util.py │ ├── logging.py │ ├── rl_algorithm.py │ ├── serializable.py │ └── tabulate.py ├── data_management │ ├── __init__.py │ ├── env_replay_buffer.py │ ├── normalizer.py │ ├── obs_dict_replay_buffer.py │ ├── path_builder.py │ ├── replay_buffer.py │ └── simple_replay_buffer.py ├── envs │ ├── ant.py │ ├── assets │ │ ├── low_gear_ratio_ant.xml │ │ └── reacher_7dof.xml │ ├── mujoco_env.py │ ├── multi_env_wrapper.py │ ├── vae_wrapper.py │ └── wrappers.py ├── exploration_strategies │ ├── __init__.py │ ├── base.py │ ├── epsilon_greedy.py │ ├── gaussian_and_epsilon_strategy.py │ ├── gaussian_strategy.py │ └── ou_strategy.py ├── launchers │ ├── __init__.py │ ├── config_template.py │ ├── launcher_util.py │ ├── rig_experiments.py │ └── state_based_goal_experiments.py ├── policies │ ├── __init__.py │ ├── argmax.py │ ├── base.py │ └── simple.py ├── pythonplusplus.py ├── samplers │ ├── __init__.py │ ├── in_place.py │ ├── rollout_functions.py │ └── util.py ├── torch │ ├── __init__.py │ ├── conv_networks.py │ ├── core.py │ ├── data_management │ │ ├── __init__.py │ │ └── normalizer.py │ ├── ddpg │ │ ├── __init__.py │ │ └── ddpg.py │ ├── distributions.py │ ├── dqn │ │ ├── __init__.py │ │ ├── double_dqn.py │ │ └── dqn.py │ ├── her │ │ ├── __init__.py │ │ ├── her.py │ │ └── her_replay_buffer.py │ ├── modules.py │ ├── networks.py │ ├── optim │ │ ├── mpi_adam.py │ │ └── util.py │ ├── pytorch_util.py │ ├── relational │ │ ├── modules.py │ │ ├── networks.py │ │ └── relational_util.py │ ├── sac │ │ ├── __init__.py │ │ ├── policies.py │ │ ├── sac.py │ │ └── twin_sac.py │ ├── td3 │ │ ├── __init__.py │ │ └── td3.py │ └── torch_rl_algorithm.py └── util │ ├── hyperparameter.py │ ├── io.py │ └── video.py └── scripts ├── download_s3.py ├── resume_training_with_new_env.py ├── run_experiment_from_doodad.py ├── sim_goal_conditioned_policy.py ├── sim_policy.py └── sim_tdm_policy.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | !.gitignore 4 | videos/ 5 | relationalrl_venv/ 6 | rlkit/launchers/config.py 7 | examples/relationalrl/pkls/ 8 | *.pyc 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vitchyr Pong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit 
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rlkit-relational 2 | Framework for relational reinforcement learning implemented in PyTorch. 3 | 4 | We include additional features, beyond RLKit, aimed at supporting high-step-complexity tasks and relational RL: 5 | - Training with multiple parallel workers via MPI 6 | - Concise implementations of inductive graph neural net architectures 7 | - Padded and masked minibatching for simultaneous training over variable-sized observation graphs 8 | - Observation graph input module for the block construction task 9 | - Tuned example scripts for the block construction task 10 | 11 | Implemented algorithms: 12 | 13 | 14 | 15 | - ReNN (*Towards Practical Multi-object Manipulation using Relational Reinforcement Learning*) 16 | - [example script](examples/relationalrl/train_pickandplace1.py) 17 | - [ReNN paper](https://arxiv.org/pdf/1912.11032.pdf) 18 | - [Documentation](docs/ReNN.md) 19 | 20 | To get started, check out the example scripts linked above. 21 | 22 | If you find this code useful, please cite: 23 | 24 | @inproceedings{li19relationalrl, 25 | Author = {Li, Richard and 26 | Jabri, Allan and Darrell, Trevor and Agrawal, Pulkit}, 27 | Title = {Towards Practical Multi-object Manipulation using Relational Reinforcement Learning}, 28 | Booktitle = {ICRA}, 29 | Year = {2020} 30 | } 31 | 32 | ## Installation 33 | 34 | Note: These settings have only been tested on Ubuntu 18, which is the recommended OS. 35 | 36 | 1. Install and activate a new python3.6+ virtualenv. (3.6+ is only needed because f-strings are used liberally in the code; you can rewrite the f-strings to support lower versions of Python.) 37 | ``` 38 | virtualenv -p python3 relationalrl_venv 39 | ``` 40 | 41 | ``` 42 | source relationalrl_venv/bin/activate 43 | ``` 44 | 45 | For the following steps, make sure you are sourced inside the `relationalrl_venv` virtualenv. 46 | 47 | 2. Install numpy. 48 | ``` 49 | pip install numpy 50 | ``` 51 | 3. Prepare for [mujoco-py](https://github.com/openai/mujoco-py) installation. 52 | 1. Download [mjpro150](https://www.roboti.us/index.html) 53 | 2. `cd ~` 54 | 3. `mkdir .mujoco` 55 | 4. Move the mjpro150 folder to `.mujoco` 56 | 5. Move the mujoco license key `mjkey.txt` to `~/.mujoco/mjkey.txt` 57 | 6. Set LD_LIBRARY_PATH: 58 | 59 | `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mjpro150/bin` 60 | 61 | 7. For Ubuntu, run: 62 | 63 | `sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3` 64 | 65 | `sudo apt install -y patchelf` 66 | 67 | 4. Install supporting packages: 68 | ``` 69 | pip install -r requirements.txt 70 | ``` 71 | 72 | Make sure `pip` is the Python 3 pip from your virtualenv. 73 | 74 | 5. 
Install the Fetch Block Construction environment: 75 | ``` 76 | git clone https://github.com/richardrl/fetch-block-construction 77 | ``` 78 | 79 | ``` 80 | cd fetch-block-construction 81 | ``` 82 | 83 | ``` 84 | pip install -e . 85 | ``` 86 | 87 | 6. Copy `config_template.py` to `config.py` and fill out `config.py` with your desired config settings: 88 | ``` 89 | cp rlkit/launchers/config_template.py rlkit/launchers/config.py 90 | ``` 91 | 92 | 7. Set PYTHONPATH (the placeholder below stands for the path to your `rlkit-relational` checkout): 93 | 94 | `export PYTHONPATH=$PYTHONPATH:<path to rlkit-relational>` 95 | 96 | 8. Add the export statements above to `.bashrc` to avoid needing to run them every time you log in. 97 | 98 | 9. Optional: to save videos with the policy visualization script, install ffmpeg: 99 | 100 | `sudo apt-get install ffmpeg` 101 | 102 | ## Running scripts 103 | Make sure to set `mode` in the scripts: 104 | - `here_no_doodad`: run locally, without Docker 105 | - `local_docker`: run locally, with Docker 106 | - `ec2`: run on Amazon EC2 107 | 108 | To run multiple workers under the `here_no_doodad` setting, run the following command, where `<num_workers>` is the number of MPI workers (see the MPI sanity-check sketch at the end of this README if you want to verify your MPI setup first): 109 | ``` 110 | mpirun -np <num_workers> python examples/relationalrl/train_pickandplace1.py 111 | ``` 112 | 113 | ## Using a GPU 114 | You can use a GPU by setting 115 | `mode="gpu_opt"` in the example scripts. 116 | 117 | ## Visualizing a policy and seeing results 118 | During training, results are saved to a folder under 119 | ``` 120 | LOCAL_LOG_DIR/<exp_prefix>/<foldername> 121 | ``` 122 | - `LOCAL_LOG_DIR` is the directory set by `rlkit.launchers.config.LOCAL_LOG_DIR`. Default name is 'output'. 123 | - `<exp_prefix>` is the experiment name given to `setup_logger` or `run_experiment`. 124 | - `<foldername>` is auto-generated and based off of `exp_prefix`. 125 | - Inside this folder, you should see a file called `params.pkl`. To visualize a policy, run 126 | 127 | ``` 128 | (rlkit) $ python scripts/sim_policy.py LOCAL_LOG_DIR/<exp_prefix>/<foldername>/params.pkl 129 | ``` 130 | 131 | To visualize results, download [viskit](https://github.com/vitchyr/viskit). You can visualize results with: 132 | ```bash 133 | python viskit/viskit/frontend.py LOCAL_LOG_DIR/<exp_prefix>/ 134 | ``` 135 | This `viskit` repo also has a few extra nice features, like plotting multiple Y-axis values at once, figure-splitting on multiple keys, and being able to filter hyperparameters out. 136 | 137 | ## Launching jobs with `doodad` 138 | The `run_experiment` function makes it easy to run Python code on Amazon Web 139 | Services (AWS) or Google Cloud Platform (GCP) by using 140 | [doodad](https://github.com/justinjfu/doodad/). 141 | 142 | It's as easy as: 143 | ``` 144 | from rlkit.launchers.launcher_util import run_experiment 145 | 146 | def function_to_run(variant): 147 | learning_rate = variant['learning_rate'] 148 | ... 149 | 150 | run_experiment( 151 | function_to_run, 152 | exp_prefix="my-experiment-name", 153 | mode='ec2', # or 'gcp' 154 | variant={'learning_rate': 1e-3}, 155 | ) 156 | ``` 157 | You will need to set up parameters in config.py (see step 6 of Installation). 158 | This requires some knowledge of AWS and/or GCP, which is beyond the scope of 159 | this README. 160 | To learn more about `doodad`, [go to the repository](https://github.com/justinjfu/doodad/). 161 | 162 | ## Credits 163 | Much of the coding infrastructure and base algorithm implementations are courtesy of [RLKit](https://github.com/vitchyr/rlkit). 164 | 165 | The Dockerfile is based on the [OpenAI mujoco-py Dockerfile](https://github.com/openai/mujoco-py/blob/master/Dockerfile).
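## Appendix: MPI sanity check
Before launching a long multi-worker run with `mpirun`, it can be worth confirming that MPI and `mpi4py` (pinned in `requirements.txt`) are working. The snippet below is a minimal, illustrative check; the filename `mpi_check.py` is just a suggestion and the script is not part of this repository.
```
# mpi_check.py -- minimal MPI sanity check (illustrative, not part of this repo)
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()   # index of this worker, 0 .. size-1
size = comm.Get_size()   # total number of workers started by mpirun

# Every worker sends its rank to worker 0, which prints a summary.
all_ranks = comm.gather(rank, root=0)
if rank == 0:
    print("MPI OK: {} workers checked in: {}".format(size, sorted(all_ranks)))
```
Run it with, for example, `mpirun -np 4 python mpi_check.py`; you should see a single line reporting that 4 workers checked in.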
166 | -------------------------------------------------------------------------------- /docs/HER.md: -------------------------------------------------------------------------------- 1 | # Hindsight Experience Replay 2 | Some notes on the implementation of 3 | [Hindsight Experience Replay](https://arxiv.org/abs/1707.01495). 4 | ## Expected Results 5 | If you run the [Fetch example](examples/her/her_td3_gym_fetch_reach.py), then 6 | you should get results like this: 7 | ![Fetch HER results](images/FetchReach-v1_HER-TD3.png) 8 | 9 | If you run the [Sawyer example](examples/her/her_td3_multiworld_sawyer_reach.py), 10 | then you should get results like this: 11 | ![Sawyer HER results](images/SawyerReachXYZEnv-v0_HER-TD3.png) 12 | 13 | Note that these examples use HER combined with TD3, and not DDPG. 14 | TD3 is a newer method that came out after the HER paper, and it seems to work 15 | better than DDPG. 16 | 17 | ## Goal-based environments and `ObsDictRelabelingBuffer` 18 | [See here.](goal_based_envs.md) 19 | 20 | ## Implementation Difference 21 | This HER implementation is slightly different from the one presented in the paper. 22 | Rather than relabeling goals when saving data to the replay buffer, the goals 23 | are relabeled when sampling from the replay buffer. 24 | 25 | 26 | In other words, HER in the paper does this: 27 | 28 | Data collection 29 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$. 30 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$. 31 | For i = 1, ..., K: 32 | Sample $g_i$ using the future strategy. 33 | Recompute rewards $r_i = f(s', g_i)$. 34 | Save $(s, a, r_i, s', g_i)$ into replay buffer $\mathcal B$. 35 | Train time 36 | 1. Sample $(s, a, r, s', g)$ from the replay buffer. 37 | 2. Train the Q function on $(s, a, r, s', g)$. 38 | 39 | The implementation here does: 40 | 41 | Data collection 42 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$. 43 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$. 44 | Train time 45 | 1. Sample $(s, a, r, s', g)$ from the replay buffer. 46 | 2a. With probability 1/(K+1): 47 | Train the Q function on $(s, a, r, s', g)$. 48 | 2b. With probability 1 - 1/(K+1): 49 | Sample $g'$ using the future strategy. 50 | Recompute rewards $r' = f(s', g')$. 51 | Train the Q function on $(s, a, r', s', g')$. 52 | 53 | Both implementations effectively do the same thing: with probability 1/(K+1), 54 | you train the policy on the goal used during rollout. Otherwise, train the 55 | policy on a resampled goal. 56 | 57 |
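To make the sampling-time variant concrete, here is a minimal sketch of relabeling goals when a minibatch is drawn. It is not the actual `ObsDictRelabelingBuffer` code: the transition keys, the `future_achieved_goals` field, and the `compute_reward` argument are all assumptions made for this illustration.
```
import numpy as np

def sample_relabeled_batch(buffer, batch_size, compute_reward, k=4):
    """Illustrative sampling-time HER relabeling (not the actual rlkit buffer code).

    Each buffer entry is assumed to be a dict with keys 's', 'a', 'r', 's_next',
    'g', and 'future_achieved_goals' (achieved goals from later steps of the same
    episode, assumed non-empty). compute_reward(next_state, goal) plays the role
    of f(s', g') above.
    """
    idxs = np.random.randint(len(buffer), size=batch_size)
    batch = [dict(buffer[i]) for i in idxs]  # copy so relabeling does not mutate stored data
    for transition in batch:
        # Case 2a: with probability 1/(K+1), keep the goal used during the rollout.
        if np.random.rand() < 1.0 / (k + 1):
            continue
        # Case 2b: relabel with the "future" strategy and recompute the reward.
        future_goals = transition['future_achieved_goals']
        new_goal = future_goals[np.random.randint(len(future_goals))]
        transition['g'] = new_goal
        transition['r'] = compute_reward(transition['s_next'], new_goal)  # r' = f(s', g')
    return batch
```
In the actual code, `ObsDictRelabelingBuffer` performs this relabeling over the whole minibatch at once; the `fraction_goals_rollout_goals=0.2` setting in the example scripts corresponds to K = 4 (see the comment in `examples/her/her_td3_gym_fetch_reach.py`).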
-------------------------------------------------------------------------------- /docs/ReNN.md: -------------------------------------------------------------------------------- 1 | ## ReNN 2 | 3 | #### Towards Practical Multi-object Manipulation using Relational Reinforcement Learning 4 | 5 | To replicate the main result of the paper, you should run the following with **35** workers. You should have at least as many available threads as workers. Having more threads can speed up wall-clock training time by 30%+. 6 | 7 | 1. Train on the Pick and Place task (equivalent to the Single Tower task with `stackonly=False`) with 1 block: 8 | 9 | 1. Run with 35 workers. 10 | 11 | `mpirun -np 35 python examples/relationalrl/train_pickandplace1.py` 12 | 13 | Your results should appear in the `data` folder under the `rlkit-relational` project root. 14 | 15 | 2. Transfer to the Pick and Place task with 2 blocks: 16 | 17 | 1. Select a .PKL file from the run's folder in `data` at the desired checkpoint. Out of the 35 worker folders, only one is checkpointing weights. You can find it by selecting the worker folder with the biggest size on disk. You should select a checkpoint where the policy is beginning to converge. 18 | 19 | 2. Assign the filename on line 92 of `examples/relationalrl/train_sequentialtransfer.py` to be the .PKL selected in step 1. 20 | 21 | 3. Run with 35 workers. 22 | 23 | `mpirun -np 35 python examples/relationalrl/train_sequentialtransfer.py` 24 | 25 | 3. Transfer to the Single Tower task (`stackonly=True`) for 2+ blocks by repeating Step 2 with the most recently trained .PKL and changing `num_blocks` to the desired number of blocks when prompted. -------------------------------------------------------------------------------- /docs/goal_based_envs.md: -------------------------------------------------------------------------------- 1 | # Goal-based environments and `ObsDictRelabelingBuffer` 2 | Some algorithms, like HER, are for goal-conditioned environments, like 3 | the [OpenAI Gym GoalEnv](https://blog.openai.com/ingredients-for-robotics-research/) 4 | or the [multiworld MultitaskEnv](https://github.com/vitchyr/multiworld/) 5 | environments. 6 | 7 | These environments are different from normal gym environments in that they 8 | return dictionaries for observations, like so: 9 | 10 | ``` 11 | env = CarEnv() 12 | obs = env.reset() 13 | next_obs, reward, done, info = env.step(action) 14 | print(obs) 15 | 16 | # Output: 17 | # { 18 | # 'observation': ..., 19 | # 'desired_goal': ..., 20 | # 'achieved_goal': ..., 21 | # } 22 | ``` 23 | The `GoalEnv` environments also have a function with signature 24 | ``` 25 | def compute_rewards(achieved_goal, desired_goal): 26 | # achieved_goal and desired_goal are vectors 27 | ``` 28 | while the `MultitaskEnv` has a signature like 29 | ``` 30 | def compute_rewards(observation, action, next_observation): 31 | # observation and next_observation are dictionaries 32 | ``` 33 | To learn more about these environments, check out the URLs above. 34 | Because of these dictionary observations, normal RL algorithms won't even "type check" with these 35 | environments. 36 | 37 | `ObsDictRelabelingBuffer` performs hindsight experience replay with 38 | either type of environment and works by saving specific values in the 39 | observation dictionary.
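As a concrete illustration of how these dictionary observations end up feeding a flat-input network, here is a minimal rollout sketch. It mirrors the HER examples in `examples/her/`, which build policies with `input_size=obs_dim + goal_dim`; the `flatten_goal_obs` helper below is purely illustrative and not part of rlkit.
```
import numpy as np
import gym

def flatten_goal_obs(obs_dict):
    # Concatenate the state with the desired goal, matching the
    # input_size=obs_dim + goal_dim networks built in examples/her/.
    return np.concatenate([obs_dict['observation'], obs_dict['desired_goal']])

env = gym.make('FetchReach-v1')
obs = env.reset()
for _ in range(50):
    flat_obs = flatten_goal_obs(obs)
    action = env.action_space.sample()  # a trained policy would map flat_obs to an action
    obs, reward, done, info = env.step(action)
```
During training, `ObsDictRelabelingBuffer` stores these dictionary entries separately, which is what lets the desired goal be swapped out when transitions are relabeled at sampling time (see `docs/HER.md`).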
40 | 41 | -------------------------------------------------------------------------------- /docs/images/FetchReach-v1_HER-TD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/docs/images/FetchReach-v1_HER-TD3.png -------------------------------------------------------------------------------- /docs/images/SawyerReachXYZEnv-v0_HER-TD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png -------------------------------------------------------------------------------- /docs/images/stack6_29.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/docs/images/stack6_29.gif -------------------------------------------------------------------------------- /environment/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # We need the CUDA base dockerfile to enable GPU rendering 2 | # on hosts with GPUs. 3 | # The image below is a pinned version of nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04 (from Jan 2018) 4 | # If updating the base image, be sure to test on GPU since it has broken in the past. 5 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 6 | 7 | 8 | RUN apt-get update -q \ 9 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 10 | curl \ 11 | git \ 12 | libgl1-mesa-dev \ 13 | libgl1-mesa-glx \ 14 | libglew-dev \ 15 | libosmesa6-dev \ 16 | software-properties-common \ 17 | net-tools \ 18 | unzip \ 19 | vim \ 20 | virtualenv \ 21 | wget \ 22 | xpra \ 23 | xserver-xorg-dev \ 24 | && apt-get clean \ 25 | && rm -rf /var/lib/apt/lists/* 26 | 27 | RUN DEBIAN_FRONTEND=noninteractive add-apt-repository --yes ppa:deadsnakes/ppa && apt-get update 28 | RUN DEBIAN_FRONTEND=noninteractive apt-get install --yes python3.5-dev python3.5 python3-pip 29 | RUN virtualenv --python=python3.5 env 30 | 31 | RUN rm /usr/bin/python 32 | RUN ln -s /env/bin/python3.5 /usr/bin/python 33 | RUN ln -s /env/bin/pip3.5 /usr/bin/pip 34 | RUN ln -s /env/bin/pytest /usr/bin/pytest 35 | 36 | RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \ 37 | && chmod +x /usr/local/bin/patchelf 38 | 39 | ENV LANG C.UTF-8 40 | 41 | RUN mkdir -p /root/.mujoco \ 42 | && wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \ 43 | && unzip mujoco.zip -d /root/.mujoco \ 44 | && rm mujoco.zip 45 | COPY ./mjkey.txt /root/.mujoco/ 46 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH} 47 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH} 48 | 49 | COPY vendor/Xdummy /usr/local/bin/Xdummy 50 | RUN chmod +x /usr/local/bin/Xdummy 51 | 52 | # Workaround for https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-375/+bug/1674677 53 | COPY ./vendor/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json 54 | 55 | RUN apt-get update && apt-get install -y libav-tools 56 | 57 | # For some reason this works despite an error showing up... 
58 | RUN DEBIAN_FRONTEND=noninteractive apt-get -qy install nvidia-384; exit 0 59 | ENV LD_LIBRARY_PATH ${LD_LIBRARY_PATH}:/usr/lib/nvidia-384 60 | 61 | RUN mkdir /root/code 62 | WORKDIR /root/code 63 | 64 | WORKDIR /mujoco_py 65 | 66 | # For atari-py 67 | RUN apt-get install -y zlib1g-dev swig cmake 68 | 69 | # Previous versions installed from a requirements.txt, but direct pip 70 | # install seems cleaner 71 | RUN pip install glfw>=1.4.0 72 | RUN pip install numpy>=1.11 73 | RUN pip install Cython>=0.27.2 74 | RUN pip install imageio>=2.1.2 75 | RUN pip install cffi>=1.10 76 | RUN pip install imagehash>=3.4 77 | RUN pip install ipdb 78 | RUN pip install Pillow>=4.0.0 79 | RUN pip install pycparser>=2.17.0 80 | RUN pip install pytest>=3.0.5 81 | RUN pip install pytest-instafail==0.3.0 82 | RUN pip install scipy>=0.18.0 83 | RUN pip install sphinx 84 | RUN pip install sphinx_rtd_theme 85 | RUN pip install numpydoc 86 | RUN pip install cloudpickle==0.5.2 87 | RUN pip install cached-property==1.3.1 88 | RUN pip install gym[all]==0.10.5 89 | RUN pip install gitpython==2.1.7 90 | RUN pip install gtimer==1.0.0b5 91 | RUN pip install awscli==1.11.179 92 | RUN pip install boto3==1.4.8 93 | RUN pip install ray==0.2.2 94 | RUN pip install path.py==10.3.1 95 | RUN pip install http://download.pytorch.org/whl/cu90/torch-0.4.1-cp35-cp35m-linux_x86_64.whl 96 | RUN pip install joblib==0.9.4 97 | RUN pip install opencv-python==3.4.0.12 98 | RUN pip install torchvision==0.2.0 99 | RUN pip install sk-video==1.1.10 100 | -------------------------------------------------------------------------------- /environment/docker/vendor/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /environment/docker/vendor/Xdummy-entrypoint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import argparse 3 | import os 4 | import sys 5 | import subprocess 6 | 7 | parser = argparse.ArgumentParser() 8 | args, extra_args = parser.parse_known_args() 9 | subprocess.Popen(["nohup", "Xdummy"], stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 10 | os.environ['DISPLAY'] = ':0' 11 | if not extra_args: 12 | sys.argv = ['/bin/bash'] 13 | else: 14 | sys.argv = extra_args 15 | # Explicitly flush right before the exec since otherwise things might get 16 | # lost in Python's buffers around stdout/stderr (!). 
17 | sys.stdout.flush() 18 | sys.stderr.flush() 19 | os.execvpe(sys.argv[0], sys.argv, os.environ) 20 | 21 | -------------------------------------------------------------------------------- /environment/linux-cpu-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch-cpu=0.4.1 21 | - scipy=1.0.1 22 | - patchelf 23 | - pip: 24 | - cloudpickle==0.5.2 25 | - gym[all]==0.10.5 26 | - gitpython==2.1.7 27 | - gtimer==1.0.0b5 28 | - pygame==1.9.2 29 | - ipdb # technically unnecessary 30 | -------------------------------------------------------------------------------- /environment/linux-gpu-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch=0.4.1 21 | - scipy=1.0.1 22 | - patchelf 23 | - pip: 24 | - cloudpickle==0.5.2 25 | - gym[all]==0.10.5 26 | - gitpython==2.1.7 27 | - gtimer==1.0.0b5 28 | - pygame==1.9.2 29 | - ipdb # technically unnecessary 30 | -------------------------------------------------------------------------------- /environment/mac-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch=0.4.1 21 | - scipy=1.0.1 22 | - pip: 23 | - cloudpickle==0.5.2 24 | - gym[all]==0.10.5 25 | - gitpython==2.1.7 26 | - gtimer==1.0.0b5 27 | - pygame==1.9.2 28 | - ipdb # technically unnecessary 29 | -------------------------------------------------------------------------------- /examples/ddpg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running PyTorch implementation of DDPG on HalfCheetah. 
3 | """ 4 | from gym.envs.mujoco import HalfCheetahEnv 5 | 6 | from rlkit.envs.wrappers import NormalizedBoxEnv 7 | from rlkit.exploration_strategies.base import ( 8 | PolicyWrappedWithExplorationStrategy 9 | ) 10 | from rlkit.exploration_strategies.ou_strategy import OUStrategy 11 | from rlkit.launchers.launcher_util import setup_logger 12 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 13 | from rlkit.torch.ddpg.ddpg import DDPG 14 | import rlkit.torch.pytorch_util as ptu 15 | 16 | 17 | def experiment(variant): 18 | env = NormalizedBoxEnv(HalfCheetahEnv()) 19 | # Or for a specific version: 20 | # import gym 21 | # env = NormalizedBoxEnv(gym.make('HalfCheetah-v1')) 22 | es = OUStrategy(action_space=env.action_space) 23 | obs_dim = env.observation_space.low.size 24 | action_dim = env.action_space.low.size 25 | qf = FlattenMlp( 26 | input_size=obs_dim + action_dim, 27 | output_size=1, 28 | hidden_sizes=[400, 300], 29 | ) 30 | policy = TanhMlpPolicy( 31 | input_size=obs_dim, 32 | output_size=action_dim, 33 | hidden_sizes=[400, 300], 34 | ) 35 | exploration_policy = PolicyWrappedWithExplorationStrategy( 36 | exploration_strategy=es, 37 | policy=policy, 38 | ) 39 | algorithm = DDPG( 40 | env, 41 | qf=qf, 42 | policy=policy, 43 | exploration_policy=exploration_policy, 44 | **variant['algo_params'] 45 | ) 46 | algorithm.to(ptu.device) 47 | algorithm.train() 48 | 49 | 50 | if __name__ == "__main__": 51 | # noinspection PyTypeChecker 52 | variant = dict( 53 | algo_params=dict( 54 | num_epochs=1000, 55 | num_steps_per_epoch=1000, 56 | num_steps_per_eval=1000, 57 | use_soft_update=True, 58 | tau=1e-2, 59 | batch_size=128, 60 | max_path_length=1000, 61 | discount=0.99, 62 | qf_learning_rate=1e-3, 63 | policy_learning_rate=1e-4, 64 | ), 65 | ) 66 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 67 | setup_logger('name-of-experiment', variant=variant) 68 | experiment(variant) 69 | -------------------------------------------------------------------------------- /examples/doodad/ec2_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running stuff on EC2 3 | """ 4 | import time 5 | 6 | from rlkit.core import logger 7 | from rlkit.launchers.launcher_util import run_experiment 8 | from datetime import datetime 9 | from pytz import timezone 10 | import pytz 11 | 12 | 13 | def example(variant): 14 | import torch 15 | logger.log(torch.__version__) 16 | date_format = '%m/%d/%Y %H:%M:%S %Z' 17 | date = datetime.now(tz=pytz.utc) 18 | logger.log("start") 19 | logger.log('Current date & time is: {}'.format(date.strftime(date_format))) 20 | if torch.cuda.is_available(): 21 | x = torch.randn(3) 22 | logger.log(str(x.to(ptu.device))) 23 | 24 | date = date.astimezone(timezone('US/Pacific')) 25 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 26 | for i in range(variant['num_seconds']): 27 | logger.log("Tick, {}".format(i)) 28 | time.sleep(1) 29 | logger.log("end") 30 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 31 | 32 | logger.log("start mujoco") 33 | from gym.envs.mujoco import HalfCheetahEnv 34 | e = HalfCheetahEnv() 35 | img = e.sim.render(32, 32) 36 | logger.log(str(sum(img))) 37 | logger.log("end mujocoy") 38 | 39 | 40 | if __name__ == "__main__": 41 | # noinspection PyTypeChecker 42 | date_format = '%m/%d/%Y %H:%M:%S %Z' 43 | date = datetime.now(tz=pytz.utc) 44 | logger.log("start") 45 | variant = dict( 46 | num_seconds=10, 47 | 
launch_time=str(date.strftime(date_format)), 48 | ) 49 | run_experiment( 50 | example, 51 | exp_prefix="ec2-test", 52 | mode='ec2', 53 | variant=variant, 54 | # use_gpu=True, # GPUs are much more expensive! 55 | ) 56 | -------------------------------------------------------------------------------- /examples/doodad/gcp_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running stuff on GCP 3 | """ 4 | import time 5 | 6 | from rlkit.core import logger 7 | from rlkit.launchers.launcher_util import run_experiment 8 | from datetime import datetime 9 | from pytz import timezone 10 | import pytz 11 | 12 | 13 | def example(variant): 14 | import torch 15 | import rlkit.torch.pytorch_util as ptu 16 | print("Starting") 17 | logger.log(torch.__version__) 18 | date_format = '%m/%d/%Y %H:%M:%S %Z' 19 | date = datetime.now(tz=pytz.utc) 20 | logger.log("start") 21 | logger.log('Current date & time is: {}'.format(date.strftime(date_format))) 22 | logger.log("Cuda available: {}".format(torch.cuda.is_available())) 23 | if torch.cuda.is_available(): 24 | x = torch.randn(3) 25 | logger.log(str(x.to(ptu.device))) 26 | 27 | date = date.astimezone(timezone('US/Pacific')) 28 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 29 | for i in range(variant['num_seconds']): 30 | logger.log("Tick, {}".format(i)) 31 | time.sleep(1) 32 | logger.log("end") 33 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 34 | 35 | logger.log("start mujoco") 36 | from gym.envs.mujoco import HalfCheetahEnv 37 | e = HalfCheetahEnv() 38 | img = e.sim.render(32, 32) 39 | logger.log(str(sum(img))) 40 | logger.log("end mujoco") 41 | 42 | logger.record_tabular('Epoch', 1) 43 | logger.dump_tabular() 44 | logger.record_tabular('Epoch', 2) 45 | logger.dump_tabular() 46 | logger.record_tabular('Epoch', 3) 47 | logger.dump_tabular() 48 | print("Done") 49 | 50 | 51 | if __name__ == "__main__": 52 | # noinspection PyTypeChecker 53 | date_format = '%m/%d/%Y %H:%M:%S %Z' 54 | date = datetime.now(tz=pytz.utc) 55 | logger.log("start") 56 | variant = dict( 57 | num_seconds=10, 58 | launch_time=str(date.strftime(date_format)), 59 | ) 60 | run_experiment( 61 | example, 62 | exp_prefix="gcp-test", 63 | mode='gcp', 64 | variant=variant, 65 | # use_gpu=True, # GPUs are much more expensive! 66 | ) 67 | -------------------------------------------------------------------------------- /examples/dqn_and_double_dqn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run DQN on grid world. 
3 | """ 4 | 5 | import gym 6 | import numpy as np 7 | from torch import nn as nn 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.launchers.launcher_util import setup_logger 11 | from rlkit.torch.dqn.dqn import DQN 12 | from rlkit.torch.networks import Mlp 13 | 14 | 15 | def experiment(variant): 16 | env = gym.make('CartPole-v0') 17 | training_env = gym.make('CartPole-v0') 18 | 19 | qf = Mlp( 20 | hidden_sizes=[32, 32], 21 | input_size=int(np.prod(env.observation_space.shape)), 22 | output_size=env.action_space.n, 23 | ) 24 | qf_criterion = nn.MSELoss() 25 | # Use this to switch to DoubleDQN 26 | # algorithm = DoubleDQN( 27 | algorithm = DQN( 28 | env, 29 | training_env=training_env, 30 | qf=qf, 31 | qf_criterion=qf_criterion, 32 | **variant['algo_params'] 33 | ) 34 | algorithm.to(ptu.device) 35 | algorithm.train() 36 | 37 | 38 | if __name__ == "__main__": 39 | # noinspection PyTypeChecker 40 | variant = dict( 41 | algo_params=dict( 42 | num_epochs=500, 43 | num_steps_per_epoch=1000, 44 | num_steps_per_eval=1000, 45 | batch_size=128, 46 | max_path_length=200, 47 | discount=0.99, 48 | epsilon=0.2, 49 | tau=0.001, 50 | hard_update_period=1000, 51 | save_environment=False, # Can't serialize CartPole for some reason 52 | ), 53 | ) 54 | setup_logger('name-of-experiment', variant=variant) 55 | experiment(variant) 56 | -------------------------------------------------------------------------------- /examples/her/her_td3_gym_fetch_reach.py: -------------------------------------------------------------------------------- 1 | """ 2 | This should results in an average return of ~3000 by the end of training. 3 | 4 | Usually hits 3000 around epoch 80-100. Within a see, the performance will be 5 | a bit noisy from one epoch to the next (occasionally dips dow to ~2000). 6 | 7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps. 
8 | """ 9 | import gym 10 | 11 | import rlkit.torch.pytorch_util as ptu 12 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 13 | from rlkit.exploration_strategies.base import ( 14 | PolicyWrappedWithExplorationStrategy 15 | ) 16 | from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import ( 17 | GaussianAndEpislonStrategy 18 | ) 19 | from rlkit.launchers.launcher_util import setup_logger 20 | from rlkit.torch.her.her import HerTd3 21 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 22 | 23 | 24 | def experiment(variant): 25 | env = gym.make('FetchReach-v1') 26 | es = GaussianAndEpislonStrategy( 27 | action_space=env.action_space, 28 | max_sigma=.2, 29 | min_sigma=.2, # constant sigma 30 | epsilon=.3, 31 | ) 32 | obs_dim = env.observation_space.spaces['observation'].low.size 33 | goal_dim = env.observation_space.spaces['desired_goal'].low.size 34 | action_dim = env.action_space.low.size 35 | qf1 = FlattenMlp( 36 | input_size=obs_dim + goal_dim + action_dim, 37 | output_size=1, 38 | hidden_sizes=[400, 300], 39 | ) 40 | qf2 = FlattenMlp( 41 | input_size=obs_dim + goal_dim + action_dim, 42 | output_size=1, 43 | hidden_sizes=[400, 300], 44 | ) 45 | policy = TanhMlpPolicy( 46 | input_size=obs_dim + goal_dim, 47 | output_size=action_dim, 48 | hidden_sizes=[400, 300], 49 | ) 50 | exploration_policy = PolicyWrappedWithExplorationStrategy( 51 | exploration_strategy=es, 52 | policy=policy, 53 | ) 54 | replay_buffer = ObsDictRelabelingBuffer( 55 | env=env, 56 | **variant['replay_buffer_kwargs'] 57 | ) 58 | algorithm = HerTd3( 59 | env=env, 60 | qf1=qf1, 61 | qf2=qf2, 62 | policy=policy, 63 | exploration_policy=exploration_policy, 64 | replay_buffer=replay_buffer, 65 | **variant['algo_kwargs'] 66 | ) 67 | algorithm.to(ptu.device) 68 | algorithm.train() 69 | 70 | 71 | if __name__ == "__main__": 72 | variant = dict( 73 | algo_kwargs=dict( 74 | num_epochs=100, 75 | num_steps_per_epoch=1000, 76 | num_steps_per_eval=1000, 77 | max_path_length=50, 78 | batch_size=128, 79 | discount=0.99, 80 | ), 81 | replay_buffer_kwargs=dict( 82 | max_size=100000, 83 | fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper 84 | fraction_goals_env_goals=0.0, 85 | ), 86 | ) 87 | setup_logger('her-td3-fetch-experiment', variant=variant) 88 | experiment(variant) 89 | -------------------------------------------------------------------------------- /examples/her/her_td3_multiworld_sawyer_reach.py: -------------------------------------------------------------------------------- 1 | """ 2 | This should results in an average return of ~3000 by the end of training. 3 | 4 | Usually hits 3000 around epoch 80-100. Within a see, the performance will be 5 | a bit noisy from one epoch to the next (occasionally dips dow to ~2000). 6 | 7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps. 
8 | """ 9 | import gym 10 | 11 | import rlkit.torch.pytorch_util as ptu 12 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 13 | from rlkit.exploration_strategies.base import \ 14 | PolicyWrappedWithExplorationStrategy 15 | from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import ( 16 | GaussianAndEpislonStrategy 17 | ) 18 | from rlkit.launchers.launcher_util import setup_logger 19 | from rlkit.torch.her.her import HerTd3 20 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 21 | 22 | 23 | def experiment(variant): 24 | env = gym.make('SawyerReachXYZEnv-v0') 25 | es = GaussianAndEpislonStrategy( 26 | action_space=env.action_space, 27 | max_sigma=.2, 28 | min_sigma=.2, # constant sigma 29 | epsilon=.3, 30 | ) 31 | obs_dim = env.observation_space.spaces['observation'].low.size 32 | goal_dim = env.observation_space.spaces['desired_goal'].low.size 33 | action_dim = env.action_space.low.size 34 | qf1 = FlattenMlp( 35 | input_size=obs_dim + goal_dim + action_dim, 36 | output_size=1, 37 | hidden_sizes=[400, 300], 38 | ) 39 | qf2 = FlattenMlp( 40 | input_size=obs_dim + goal_dim + action_dim, 41 | output_size=1, 42 | hidden_sizes=[400, 300], 43 | ) 44 | policy = TanhMlpPolicy( 45 | input_size=obs_dim + goal_dim, 46 | output_size=action_dim, 47 | hidden_sizes=[400, 300], 48 | ) 49 | exploration_policy = PolicyWrappedWithExplorationStrategy( 50 | exploration_strategy=es, 51 | policy=policy, 52 | ) 53 | replay_buffer = ObsDictRelabelingBuffer( 54 | env=env, 55 | achieved_goal_key='state_achieved_goal', 56 | desired_goal_key='state_desired_goal', 57 | **variant['replay_buffer_kwargs'] 58 | ) 59 | algorithm = HerTd3( 60 | env=env, 61 | qf1=qf1, 62 | qf2=qf2, 63 | policy=policy, 64 | exploration_policy=exploration_policy, 65 | replay_buffer=replay_buffer, 66 | **variant['algo_kwargs'] 67 | ) 68 | algorithm.to(ptu.device) 69 | algorithm.train() 70 | 71 | 72 | if __name__ == "__main__": 73 | variant = dict( 74 | algo_kwargs=dict( 75 | num_epochs=100, 76 | num_steps_per_epoch=1000, 77 | num_steps_per_eval=1000, 78 | max_path_length=50, 79 | batch_size=128, 80 | discount=0.99, 81 | ), 82 | replay_buffer_kwargs=dict( 83 | max_size=100000, 84 | fraction_goals_rollout_goals=0.2, 85 | fraction_goals_env_goals=0.0, 86 | ), 87 | ) 88 | setup_logger('her-td3-sawyer-experiment', variant=variant) 89 | experiment(variant) 90 | -------------------------------------------------------------------------------- /examples/relationalrl/trace.html: -------------------------------------------------------------------------------- 1 |
# ThreadID: 140296524760832
 2 | File: "/usr/lib/python3.6/threading.py", line 884, in _bootstrap
 3 |   self._bootstrap_inner()
 4 | File: "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
 5 |   self.run()
 6 | File: "/home/richard/rlkit_fresh/stacktracer.py", line 64, in run
 7 |   self.stacktraces()
 8 | File: "/home/richard/rlkit_fresh/stacktracer.py", line 78, in stacktraces
 9 |   fout.write(stacktraces())
10 | File: "/home/richard/rlkit_fresh/stacktracer.py", line 26, in stacktraces
11 |   for filename, lineno, name, line in traceback.extract_stack(stack):
12 | 
13 | # ThreadID: 140296657909568
14 | File: "/home/richard/rlkit-relational/examples/relationalrl/train_sequential_transfer.py", line 225, in <module>
15 | File: "/home/richard/rlkit-relational/rlkit/launchers/launcher_util.py", line 590, in run_experiment
16 |   **run_experiment_kwargs
17 | File: "/home/richard/rlkit-relational/rlkit/launchers/launcher_util.py", line 168, in run_experiment_here
18 |   return experiment_function(variant)
19 | File: "/home/richard/rlkit-relational/examples/relationalrl/train_sequential_transfer.py", line 101, in experiment
20 | File: "/home/richard/rlkit-relational/rlkit/core/rl_algorithm.py", line 169, in train
21 |   self.train_batch(start_epoch=start_epoch)
22 | File: "/home/richard/rlkit-relational/rlkit/core/rl_algorithm.py", line 215, in train_batch
23 |   self._try_to_train()
24 | File: "/home/richard/rlkit-relational/rlkit/core/rl_algorithm.py", line 282, in _try_to_train
25 |   self._do_training()
26 | File: "/home/richard/rlkit-relational/rlkit/torch/sac/twin_sac.py", line 281, in _do_training
27 |   self.vf_optimizer.step()
28 | File: "/home/richard/rlkit-relational/rlkit/torch/optim/mpi_adam.py", line 123, in step
29 |   self.set_params_from_flat((self.get_flat_params() + step_update).to(device=torch.device("cpu")))
30 | File: "/home/richard/rlkit-relational/rlkit/torch/optim/util.py", line 58, in __call__
31 |   param.data.copy_(flattened_parameters[start:start+size].view(shape))
32 | 
33 | -------------------------------------------------------------------------------- /examples/sac.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run PyTorch Soft Actor Critic on HalfCheetahEnv. 3 | 4 | NOTE: You need PyTorch 0.3 or more (to have torch.distributions) 5 | """ 6 | import numpy as np 7 | from gym.envs.mujoco import HalfCheetahEnv 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.envs.wrappers import NormalizedBoxEnv 11 | from rlkit.launchers.launcher_util import setup_logger 12 | from rlkit.torch.sac.policies import TanhGaussianPolicy 13 | from rlkit.torch.sac.sac import SoftActorCritic 14 | from rlkit.torch.networks import FlattenMlp 15 | 16 | 17 | def experiment(variant): 18 | env = NormalizedBoxEnv(HalfCheetahEnv()) 19 | # Or for a specific version: 20 | # import gym 21 | # env = NormalizedBoxEnv(gym.make('HalfCheetah-v1')) 22 | 23 | obs_dim = int(np.prod(env.observation_space.shape)) 24 | action_dim = int(np.prod(env.action_space.shape)) 25 | 26 | net_size = variant['net_size'] 27 | qf = FlattenMlp( 28 | hidden_sizes=[net_size, net_size], 29 | input_size=obs_dim + action_dim, 30 | output_size=1, 31 | ) 32 | vf = FlattenMlp( 33 | hidden_sizes=[net_size, net_size], 34 | input_size=obs_dim, 35 | output_size=1, 36 | ) 37 | policy = TanhGaussianPolicy( 38 | hidden_sizes=[net_size, net_size], 39 | obs_dim=obs_dim, 40 | action_dim=action_dim, 41 | ) 42 | algorithm = SoftActorCritic( 43 | env=env, 44 | policy=policy, 45 | qf=qf, 46 | vf=vf, 47 | **variant['algo_params'] 48 | ) 49 | algorithm.to(ptu.device) 50 | algorithm.train() 51 | 52 | 53 | if __name__ == "__main__": 54 | # noinspection PyTypeChecker 55 | variant = dict( 56 | algo_params=dict( 57 | num_epochs=1000, 58 | num_steps_per_epoch=1000, 59 | num_steps_per_eval=1000, 60 | batch_size=128, 61 | max_path_length=999, 62 | discount=0.99, 63 | reward_scale=1, 64 | 65 | soft_target_tau=0.001, 66 | policy_lr=3E-4, 67 | qf_lr=3E-4, 68 | vf_lr=3E-4, 69 | ), 70 | net_size=300, 71 | ) 72 | setup_logger('name-of-experiment', variant=variant) 73 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 74 | experiment(variant) 75 | -------------------------------------------------------------------------------- /examples/td3.py: -------------------------------------------------------------------------------- 1 | """ 2 | This should results in an average return of ~3000 by the end of training. 3 | 4 | Usually hits 3000 around epoch 80-100. Within a see, the performance will be 5 | a bit noisy from one epoch to the next (occasionally dips dow to ~2000). 6 | 7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps. 
8 | """ 9 | from gym.envs.mujoco import HopperEnv 10 | 11 | import rlkit.torch.pytorch_util as ptu 12 | from rlkit.envs.wrappers import NormalizedBoxEnv 13 | from rlkit.exploration_strategies.base import \ 14 | PolicyWrappedWithExplorationStrategy 15 | from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy 16 | from rlkit.launchers.launcher_util import setup_logger 17 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 18 | from rlkit.torch.td3.td3 import TD3 19 | 20 | 21 | def experiment(variant): 22 | env = NormalizedBoxEnv(HopperEnv()) 23 | es = GaussianStrategy( 24 | action_space=env.action_space, 25 | max_sigma=0.1, 26 | min_sigma=0.1, # Constant sigma 27 | ) 28 | obs_dim = env.observation_space.low.size 29 | action_dim = env.action_space.low.size 30 | qf1 = FlattenMlp( 31 | input_size=obs_dim + action_dim, 32 | output_size=1, 33 | hidden_sizes=[400, 300], 34 | ) 35 | qf2 = FlattenMlp( 36 | input_size=obs_dim + action_dim, 37 | output_size=1, 38 | hidden_sizes=[400, 300], 39 | ) 40 | policy = TanhMlpPolicy( 41 | input_size=obs_dim, 42 | output_size=action_dim, 43 | hidden_sizes=[400, 300], 44 | ) 45 | exploration_policy = PolicyWrappedWithExplorationStrategy( 46 | exploration_strategy=es, 47 | policy=policy, 48 | ) 49 | algorithm = TD3( 50 | env, 51 | qf1=qf1, 52 | qf2=qf2, 53 | policy=policy, 54 | exploration_policy=exploration_policy, 55 | **variant['algo_kwargs'] 56 | ) 57 | algorithm.to(ptu.device) 58 | algorithm.train() 59 | 60 | 61 | if __name__ == "__main__": 62 | variant = dict( 63 | algo_kwargs=dict( 64 | num_epochs=200, 65 | num_steps_per_epoch=5000, 66 | num_steps_per_eval=10000, 67 | max_path_length=1000, 68 | batch_size=100, 69 | discount=0.99, 70 | replay_buffer_size=int(1E6), 71 | ), 72 | ) 73 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 74 | setup_logger('name-of-td3-experiment', variant=variant) 75 | experiment(variant) 76 | -------------------------------------------------------------------------------- /examples/tsac.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run PyTorch Soft Actor Critic on HalfCheetahEnv with the "Twin" architecture 3 | from TD3: https://arxiv.org/pdf/1802.09477.pdf 4 | """ 5 | import numpy as np 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.envs.wrappers import NormalizedBoxEnv 9 | from rlkit.launchers.launcher_util import setup_logger 10 | from rlkit.torch.sac.policies import TanhGaussianPolicy 11 | from rlkit.torch.sac.sac import SoftActorCritic 12 | from rlkit.torch.networks import FlattenMlp 13 | from rlkit.torch.sac.twin_sac import TwinSAC 14 | 15 | 16 | def experiment(variant): 17 | import gym 18 | env = NormalizedBoxEnv(gym.make('HalfCheetah-v2')) 19 | 20 | obs_dim = int(np.prod(env.observation_space.shape)) 21 | action_dim = int(np.prod(env.action_space.shape)) 22 | 23 | net_size = variant['net_size'] 24 | qf1 = FlattenMlp( 25 | hidden_sizes=[net_size, net_size], 26 | input_size=obs_dim + action_dim, 27 | output_size=1, 28 | ) 29 | qf2 = FlattenMlp( 30 | hidden_sizes=[net_size, net_size], 31 | input_size=obs_dim + action_dim, 32 | output_size=1, 33 | ) 34 | vf = FlattenMlp( 35 | hidden_sizes=[net_size, net_size], 36 | input_size=obs_dim, 37 | output_size=1, 38 | ) 39 | policy = TanhGaussianPolicy( 40 | hidden_sizes=[net_size, net_size], 41 | obs_dim=obs_dim, 42 | action_dim=action_dim, 43 | ) 44 | algorithm = TwinSAC( 45 | env=env, 46 | policy=policy, 47 | qf1=qf1, 48 | qf2=qf2, 49 | vf=vf, 50 | 
**variant['algo_params'] 51 | ) 52 | algorithm.to(ptu.device) 53 | algorithm.train() 54 | 55 | 56 | if __name__ == "__main__": 57 | # noinspection PyTypeChecker 58 | variant = dict( 59 | algo_params=dict( 60 | num_epochs=1000, 61 | num_steps_per_epoch=1000, 62 | num_steps_per_eval=1000, 63 | max_path_length=1000, 64 | batch_size=128, 65 | discount=0.99, 66 | 67 | soft_target_tau=0.001, 68 | policy_lr=3E-4, 69 | qf_lr=3E-4, 70 | vf_lr=3E-4, 71 | ), 72 | net_size=300, 73 | ) 74 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 75 | setup_logger('name-of-experiment', variant=variant) 76 | experiment(variant) 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atomicwrites==1.2.1 2 | attrs==18.2.0 3 | awscli==1.16.96 4 | backcall==0.1.0 5 | boto==2.49.0 6 | boto3==1.9.84 7 | botocore==1.12.86 8 | certifi==2018.11.29 9 | cffi==1.11.5 10 | chardet==3.0.4 11 | Click==7.0 12 | cloudpickle==0.5.2 13 | colorama==0.3.9 14 | cycler==0.10.0 15 | Cython==0.29.2 16 | decorator==4.4.0 17 | dill==0.2.9 18 | docutils==0.14 19 | -e git+git@github.com:richardrl/doodad-2019-fresh.git@93bf5ff595d10f36b1a72419b434858b94489302#egg=doodad 20 | ffmpeg==1.4 21 | filelock==3.0.10 22 | flatbuffers==1.10 23 | funcsigs==1.0.2 24 | future==0.17.1 25 | gitdb2==2.0.5 26 | GitPython==2.1.11 27 | glfw==1.7.0 28 | graphviz==0.12 29 | gtimer==1.0.0b5 30 | gym==0.10.9 31 | h5py==2.9.0 32 | idna==2.8 33 | imageio==2.4.1 34 | ipdb==0.12 35 | ipykernel==5.1.0 36 | ipython==7.4.0 37 | ipython-genutils==0.2.0 38 | jedi==0.13.3 39 | jmespath==0.9.3 40 | joblib==0.13.1 41 | jupyter-client==5.2.4 42 | jupyter-core==4.4.0 43 | kiwisolver==1.0.1 44 | lockfile==0.12.2 45 | matplotlib==3.0.2 46 | more-itertools==5.0.0 47 | mpi4py==3.0.1 48 | mujoco-py==1.50.1.68 49 | numpy==1.16.0 50 | numpy-stl==2.10.0 51 | opencv-python==4.0.0.21 52 | pandas==0.24.1 53 | parso==0.4.0 54 | pexpect==4.7.0 55 | pickleshare==0.7.5 56 | Pillow==5.4.1 57 | pkg-resources==0.0.0 58 | pluggy==0.8.1 59 | prompt-toolkit==2.0.9 60 | psutil==5.5.0 61 | ptyprocess==0.6.0 62 | py==1.7.0 63 | pyasn1==0.4.5 64 | pycparser==2.19 65 | pygame==1.9.4 66 | pyglet==1.3.2 67 | Pygments==2.3.1 68 | pyparsing==2.3.1 69 | pyquaternion==0.9.5 70 | pytest==4.2.0 71 | python-dateutil==2.7.5 72 | python-utils==2.3.0 73 | pytz==2018.9 74 | PyYAML==3.13 75 | pyzmq==18.0.1 76 | ray==0.6.2 77 | redis==3.1.0 78 | requests==2.21.0 79 | rsa==3.4.2 80 | s3transfer==0.1.13 81 | scikit-video==1.1.11 82 | scipy==1.2.0 83 | seaborn==0.9.0 84 | six==1.12.0 85 | smmap2==2.0.5 86 | tk==0.1.0 87 | torch==1.1.0 88 | torchtest==0.4 89 | torchvision==0.2.1 90 | -e git+git@github.com:szagoruyko/pytorchviz.git@46add7f2c071b6d29fc3d56e9d2d21e1c0a3af1d#egg=torchviz 91 | tornado==6.0.2 92 | traitlets==4.3.2 93 | urllib3==1.24.1 94 | wcwidth==0.1.7 -------------------------------------------------------------------------------- /rlkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/__init__.py -------------------------------------------------------------------------------- /rlkit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout rlkit. 
3 | """ 4 | from rlkit.core.logging import logger 5 | 6 | __all__ = ['logger'] 7 | 8 | -------------------------------------------------------------------------------- /rlkit/core/eval_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common evaluation utilities. 3 | """ 4 | 5 | from collections import OrderedDict 6 | from numbers import Number 7 | 8 | import numpy as np 9 | 10 | 11 | def get_generic_path_information(paths, stat_prefix='', num_blocks=None): 12 | """ 13 | Get an OrderedDict with a bunch of statistic names and values. 14 | """ 15 | # assert num_blocks is not None 16 | # assert len(paths) == 21, len(paths) 17 | statistics = OrderedDict() 18 | returns = [sum(path["rewards"]) for path in paths] 19 | statistics.update(create_stats_ordered_dict('Returns', returns, 20 | stat_prefix=stat_prefix)) 21 | 22 | rewards = np.vstack([path["rewards"] for path in paths]) 23 | statistics.update(create_stats_ordered_dict('Rewards', rewards, 24 | stat_prefix=stat_prefix)) 25 | 26 | assert np.all([path['mask'][0] == path['mask'][x] for path in paths for x in range(len(path))]) 27 | 28 | final_num_blocks_stacked = [path['mask'][0].sum() - np.clip(np.abs(path["rewards"][-1]), None, path['mask'][0].sum()) for path in paths] 29 | statistics[F'{stat_prefix} Final Num Blocks Stacked'] = np.mean(final_num_blocks_stacked) 30 | 31 | mean_num_blocks_stacked = [(path['mask'][0].sum() - np.clip(np.abs(path['rewards']), None, path['mask'][0].sum())).mean() for path in paths] 32 | assert all(x >= 0 for x in mean_num_blocks_stacked), mean_num_blocks_stacked 33 | statistics[F'{stat_prefix} Mean Num Blocks Stacked'] = np.mean(mean_num_blocks_stacked) 34 | 35 | if isinstance(paths[0], dict) and num_blocks: 36 | # Keys are block IDs, values are final goal distances. 
37 | # Each block ID is a list of final goal distances for all paths 38 | final_goal_dist = dict() 39 | seq = [] 40 | for block_id in range(num_blocks): 41 | final_goal_dist[block_id] = [np.linalg.norm(path['observations'][-1]['achieved_goal'][block_id*3:(block_id+1)*3] - path['observations'][-1]['desired_goal'][block_id*3:(block_id+1)*3]) for path in paths] 42 | # statistics.update(create_stats_ordered_dict(F"Fin Goal Dist Blk {block_id}", final_goal_dist[block_id], 43 | # stat_prefix=stat_prefix)) 44 | seq.append(np.array([np.linalg.norm(path['observations'][-1]['achieved_goal'][block_id*3:(block_id+1)*3] - path['observations'][-1]['desired_goal'][block_id*3:(block_id+1)*3]) for path in paths])) 45 | 46 | block_dists = np.vstack(seq) 47 | assert len(block_dists.shape) == 2 48 | sorted = np.sort(block_dists, axis=0) 49 | # sorted = block_dists 50 | 51 | # for block_id in range(num_blocks): 52 | # statistics.update(create_stats_ordered_dict(F"Fin Goal Dist Blk {block_id}", sorted[block_id], stat_prefix=stat_prefix)) 53 | 54 | total_solved = 0 55 | goal_threshold = .05 56 | for path_fd_tuple_across_blocks in zip(*list(final_goal_dist.values())): 57 | total_solved +=all(fd_blocki < goal_threshold for fd_blocki in path_fd_tuple_across_blocks) 58 | 59 | assert len(paths) == len(final_goal_dist[0]) 60 | percent_solved = total_solved/len(paths) 61 | assert 0 <= percent_solved <= 1, (total_solved, len(paths), final_goal_dist) 62 | statistics[F"{stat_prefix} Percent Solved"] = percent_solved 63 | 64 | actions = [path["actions"] for path in paths] 65 | if len(actions[0].shape) == 1: 66 | actions = np.hstack([path["actions"] for path in paths]) 67 | else: 68 | actions = np.vstack([path["actions"] for path in paths]) 69 | statistics.update(create_stats_ordered_dict( 70 | 'Actions', actions, stat_prefix=stat_prefix 71 | )) 72 | statistics['Num Paths'] = len(paths) 73 | 74 | return statistics 75 | 76 | 77 | def get_average_returns(paths): 78 | returns = [sum(path["rewards"]) for path in paths] 79 | return np.mean(returns) 80 | 81 | 82 | def create_stats_ordered_dict( 83 | name, 84 | data, 85 | stat_prefix=None, 86 | always_show_all_stats=True, 87 | exclude_max_min=False, 88 | ): 89 | if stat_prefix is not None: 90 | name = "{} {}".format(stat_prefix, name) 91 | if isinstance(data, Number): 92 | return OrderedDict({name: data}) 93 | 94 | if len(data) == 0: 95 | return OrderedDict() 96 | 97 | if isinstance(data, tuple): 98 | ordered_dict = OrderedDict() 99 | for number, d in enumerate(data): 100 | sub_dict = create_stats_ordered_dict( 101 | "{0}_{1}".format(name, number), 102 | d, 103 | ) 104 | ordered_dict.update(sub_dict) 105 | return ordered_dict 106 | 107 | if isinstance(data, list): 108 | try: 109 | iter(data[0]) 110 | except TypeError: 111 | pass 112 | else: 113 | data = np.concatenate(data) 114 | 115 | if (isinstance(data, np.ndarray) and data.size == 1 116 | and not always_show_all_stats): 117 | return OrderedDict({name: float(data)}) 118 | 119 | stats = OrderedDict([ 120 | (name + ' Mean', np.mean(data)), 121 | (name + ' Std', np.std(data)), 122 | ]) 123 | if not exclude_max_min: 124 | stats[name + ' Max'] = np.max(data) 125 | stats[name + ' Min'] = np.min(data) 126 | return stats 127 | -------------------------------------------------------------------------------- /rlkit/core/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | 
import sys 9 | 10 | 11 | class Serializable(object): 12 | 13 | def __init__(self, *args, **kwargs): 14 | self.__args = args 15 | self.__kwargs = kwargs 16 | 17 | def quick_init(self, locals_): 18 | if getattr(self, "_serializable_initialized", False): 19 | return 20 | if sys.version_info >= (3, 0): 21 | spec = inspect.getfullargspec(self.__init__) 22 | # Exclude the first "self" parameter 23 | if spec.varkw: 24 | kwargs = locals_[spec.varkw].copy() 25 | else: 26 | kwargs = dict() 27 | if spec.kwonlyargs: 28 | for key in spec.kwonlyargs: 29 | kwargs[key] = locals_[key] 30 | else: 31 | spec = inspect.getargspec(self.__init__) 32 | if spec.keywords: 33 | kwargs = locals_[spec.keywords] 34 | else: 35 | kwargs = dict() 36 | if spec.varargs: 37 | varargs = locals_[spec.varargs] 38 | else: 39 | varargs = tuple() 40 | in_order_args = [locals_[arg] for arg in spec.args][1:] 41 | self.__args = tuple(in_order_args) + varargs 42 | self.__kwargs = kwargs 43 | setattr(self, "_serializable_initialized", True) 44 | 45 | def __getstate__(self): 46 | return {"__args": self.__args, "__kwargs": self.__kwargs} 47 | 48 | def __setstate__(self, d): 49 | # convert all __args to keyword-based arguments 50 | if sys.version_info >= (3, 0): 51 | spec = inspect.getfullargspec(self.__init__) 52 | else: 53 | spec = inspect.getargspec(self.__init__) 54 | in_order_args = spec.args[1:] 55 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 56 | self.__dict__.update(out.__dict__) 57 | self.__dict__.update() 58 | 59 | @classmethod 60 | def clone(cls, obj, **kwargs): 61 | assert isinstance(obj, Serializable) 62 | d = obj.__getstate__() 63 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 64 | out = type(obj).__new__(type(obj)) 65 | out.__setstate__(d) 66 | return out 67 | -------------------------------------------------------------------------------- /rlkit/data_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/data_management/__init__.py -------------------------------------------------------------------------------- /rlkit/data_management/env_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer 3 | from gym.spaces import Box, Discrete, Tuple 4 | 5 | 6 | class EnvReplayBuffer(SimpleReplayBuffer): 7 | def __init__( 8 | self, 9 | max_replay_buffer_size, 10 | env, 11 | ): 12 | """ 13 | :param max_replay_buffer_size: 14 | :param env: 15 | """ 16 | self.env = env 17 | self._ob_space = env.observation_space 18 | self._action_space = env.action_space 19 | super().__init__( 20 | max_replay_buffer_size=max_replay_buffer_size, 21 | observation_dim=get_dim(self._ob_space), 22 | action_dim=get_dim(self._action_space), 23 | ) 24 | 25 | def add_sample(self, observation, action, reward, terminal, 26 | next_observation, **kwargs): 27 | 28 | if isinstance(self._action_space, Discrete): 29 | action = np.eye(self._action_space.n)[action] 30 | super(EnvReplayBuffer, self).add_sample( 31 | observation, action, reward, terminal, 32 | next_observation, **kwargs) 33 | 34 | 35 | def get_dim(space): 36 | if isinstance(space, Box): 37 | return space.low.size 38 | elif isinstance(space, Discrete): 39 | return space.n 40 | elif isinstance(space, Tuple): 41 | return sum(get_dim(subspace) for subspace in space.spaces) 42 | elif 
hasattr(space, 'flat_dim'): 43 | return space.flat_dim 44 | else: 45 | raise TypeError("Unknown space: {}".format(space)) 46 | -------------------------------------------------------------------------------- /rlkit/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on code from Marcin Andrychowicz 3 | """ 4 | import numpy as np 5 | 6 | 7 | class Normalizer(object): 8 | def __init__( 9 | self, 10 | size, 11 | eps=1e-8, 12 | default_clip_range=np.inf, 13 | mean=0, 14 | std=1, 15 | ): 16 | self.size = size 17 | self.eps = eps 18 | self.default_clip_range = default_clip_range 19 | self.sum = np.zeros(self.size, np.float32) 20 | self.sumsq = np.zeros(self.size, np.float32) 21 | self.count = np.ones(1, np.float32) 22 | self.mean = mean + np.zeros(self.size, np.float32) 23 | self.std = std * np.ones(self.size, np.float32) 24 | self.synchronized = True 25 | 26 | def update(self, v): 27 | if v.ndim == 1: 28 | v = np.expand_dims(v, 0) 29 | assert v.ndim == 2 30 | assert v.shape[1] == self.size 31 | self.sum += v.sum(axis=0) 32 | self.sumsq += (np.square(v)).sum(axis=0) 33 | self.count[0] += v.shape[0] 34 | self.synchronized = False 35 | 36 | def normalize(self, v, clip_range=None): 37 | if not self.synchronized: 38 | self.synchronize() 39 | if clip_range is None: 40 | clip_range = self.default_clip_range 41 | mean, std = self.mean, self.std 42 | if v.ndim == 2: 43 | mean = mean.reshape(1, -1) 44 | std = std.reshape(1, -1) 45 | return np.clip((v - mean) / std, -clip_range, clip_range) 46 | 47 | def denormalize(self, v): 48 | if not self.synchronized: 49 | self.synchronize() 50 | mean, std = self.mean, self.std 51 | if v.ndim == 2: 52 | mean = mean.reshape(1, -1) 53 | std = std.reshape(1, -1) 54 | return mean + v * std 55 | 56 | def synchronize(self): 57 | self.mean[...] = self.sum / self.count[0] 58 | self.std[...] 
= np.sqrt( 59 | np.maximum( 60 | np.square(self.eps), 61 | self.sumsq / self.count[0] - np.square(self.mean) 62 | ) 63 | ) 64 | self.synchronized = True 65 | 66 | 67 | class IdentityNormalizer(object): 68 | def __init__(self, *args, **kwargs): 69 | pass 70 | 71 | def update(self, v): 72 | pass 73 | 74 | def normalize(self, v, clip_range=None): 75 | return v 76 | 77 | def denormalize(self, v): 78 | return v 79 | 80 | 81 | class FixedNormalizer(object): 82 | def __init__( 83 | self, 84 | size, 85 | default_clip_range=np.inf, 86 | mean=0, 87 | std=1, 88 | eps=1e-8, 89 | ): 90 | assert std > 0 91 | std = std + eps 92 | self.size = size 93 | self.default_clip_range = default_clip_range 94 | self.mean = mean + np.zeros(self.size, np.float32) 95 | self.std = std + np.zeros(self.size, np.float32) 96 | self.eps = eps 97 | 98 | def set_mean(self, mean): 99 | self.mean = mean + np.zeros(self.size, np.float32) 100 | 101 | def set_std(self, std): 102 | std = std + self.eps 103 | self.std = std + np.zeros(self.size, np.float32) 104 | 105 | def normalize(self, v, clip_range=None): 106 | if clip_range is None: 107 | clip_range = self.default_clip_range 108 | mean, std = self.mean, self.std 109 | if v.ndim == 2: 110 | mean = mean.reshape(1, -1) 111 | std = std.reshape(1, -1) 112 | return np.clip((v - mean) / std, -clip_range, clip_range) 113 | 114 | def denormalize(self, v): 115 | mean, std = self.mean, self.std 116 | if v.ndim == 2: 117 | mean = mean.reshape(1, -1) 118 | std = std.reshape(1, -1) 119 | return mean + v * std 120 | 121 | def copy_stats(self, other): 122 | self.set_mean(other.mean) 123 | self.set_std(other.std) 124 | -------------------------------------------------------------------------------- /rlkit/data_management/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PathBuilder(dict): 5 | """ 6 | Usage: 7 | ``` 8 | path_builder = PathBuilder() 9 | path.add_sample( 10 | observations=1, 11 | actions=2, 12 | next_observations=3, 13 | ... 14 | ) 15 | path.add_sample( 16 | observations=4, 17 | actions=5, 18 | next_observations=6, 19 | ... 20 | ) 21 | 22 | path = path_builder.get_all_stacked() 23 | 24 | path['observations'] 25 | # output: [1, 4] 26 | path['actions'] 27 | # output: [2, 5] 28 | ``` 29 | 30 | Note that the key should be "actions" and not "action" since the 31 | resulting dictionary will have those keys. 32 | """ 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self._path_length = 0 37 | 38 | def add_all(self, **key_to_value): 39 | for k, v in key_to_value.items(): 40 | if k not in self: 41 | self[k] = [v] 42 | else: 43 | self[k].append(v) 44 | self._path_length += 1 45 | 46 | def get_all_stacked(self): 47 | output_dict = dict() 48 | for k, v in self.items(): 49 | output_dict[k] = stack_list(v) 50 | return output_dict 51 | 52 | def __len__(self): 53 | return self._path_length 54 | 55 | 56 | def stack_list(lst): 57 | if isinstance(lst[0], dict): 58 | return lst 59 | else: 60 | return np.array(lst) 61 | -------------------------------------------------------------------------------- /rlkit/data_management/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayBuffer(object, metaclass=abc.ABCMeta): 5 | """ 6 | A class used to save and replay data. 
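An aside on `rlkit/data_management/path_builder.py` above: the usage example in the `PathBuilder` docstring calls `path.add_sample(...)`, but the only method the class actually defines is `add_all`. A minimal sketch of the intended usage, with purely illustrative values:

```
from rlkit.data_management.path_builder import PathBuilder

path_builder = PathBuilder()
# Keys are plural ("observations", "actions", ...) because they become the
# keys of the stacked dictionary returned by get_all_stacked().
path_builder.add_all(observations=1, actions=2, next_observations=3)
path_builder.add_all(observations=4, actions=5, next_observations=6)

path = path_builder.get_all_stacked()
print(path['observations'])   # -> array([1, 4])
print(path['actions'])        # -> array([2, 5])
print(len(path_builder))      # -> 2, one per add_all call
```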
7 | """ 8 | 9 | @abc.abstractmethod 10 | def add_sample(self, observation, action, reward, next_observation, 11 | terminal, **kwargs): 12 | """ 13 | Add a transition tuple. 14 | """ 15 | pass 16 | 17 | @abc.abstractmethod 18 | def terminate_episode(self): 19 | """ 20 | Let the replay buffer know that the episode has terminated in case some 21 | special book-keeping has to happen. 22 | :return: 23 | """ 24 | pass 25 | 26 | @abc.abstractmethod 27 | def num_steps_can_sample(self, **kwargs): 28 | """ 29 | :return: # of unique items that can be sampled. 30 | """ 31 | pass 32 | 33 | def add_path(self, path): 34 | """ 35 | Add a path to the replay buffer. 36 | 37 | This default implementation naively goes through every step, but you 38 | may want to optimize this. 39 | 40 | NOTE: You should NOT call "terminate_episode" after calling add_path. 41 | It's assumed that this function handles the episode termination. 42 | 43 | :param path: Dict like one outputted by rlkit.samplers.util.rollout 44 | """ 45 | for i, ( 46 | obs, 47 | action, 48 | reward, 49 | next_obs, 50 | terminal, 51 | agent_info, 52 | env_info 53 | ) in enumerate(zip( 54 | path["observations"], 55 | path["actions"], 56 | path["rewards"], 57 | path["next_observations"], 58 | path["terminals"], 59 | path["agent_infos"], 60 | path["env_infos"], 61 | )): 62 | self.add_sample( 63 | obs, 64 | action, 65 | reward, 66 | next_obs, 67 | terminal, 68 | agent_info=agent_info, 69 | env_info=env_info, 70 | ) 71 | self.terminate_episode() 72 | 73 | @abc.abstractmethod 74 | def random_batch(self, batch_size): 75 | """ 76 | Return a batch of size `batch_size`. 77 | :param batch_size: 78 | :return: 79 | """ 80 | pass 81 | -------------------------------------------------------------------------------- /rlkit/data_management/simple_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.data_management.replay_buffer import ReplayBuffer 4 | 5 | 6 | class SimpleReplayBuffer(ReplayBuffer): 7 | def __init__( 8 | self, max_replay_buffer_size, observation_dim, action_dim, 9 | ): 10 | self._observation_dim = observation_dim 11 | self._action_dim = action_dim 12 | self._max_replay_buffer_size = max_replay_buffer_size 13 | self._observations = np.zeros((max_replay_buffer_size, observation_dim)) 14 | # It's a bit memory inefficient to save the observations twice, 15 | # but it makes the code *much* easier since you no longer have to 16 | # worry about termination conditions. 
17 | self._next_obs = np.zeros((max_replay_buffer_size, observation_dim)) 18 | self._actions = np.zeros((max_replay_buffer_size, action_dim)) 19 | # Make everything a 2D np array to make it easier for other code to 20 | # reason about the shape of the data 21 | self._rewards = np.zeros((max_replay_buffer_size, 1)) 22 | # self._terminals[i] = a terminal was received at time i 23 | self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8') 24 | self._top = 0 25 | self._size = 0 26 | 27 | def add_sample(self, observation, action, reward, terminal, 28 | next_observation, **kwargs): 29 | self._observations[self._top] = observation 30 | self._actions[self._top] = action 31 | self._rewards[self._top] = reward 32 | self._terminals[self._top] = terminal 33 | self._next_obs[self._top] = next_observation 34 | self._advance() 35 | 36 | def terminate_episode(self): 37 | pass 38 | 39 | def _advance(self): 40 | self._top = (self._top + 1) % self._max_replay_buffer_size 41 | if self._size < self._max_replay_buffer_size: 42 | self._size += 1 43 | 44 | def random_batch(self, batch_size): 45 | indices = np.random.randint(0, self._size, batch_size) 46 | return dict( 47 | observations=self._observations[indices], 48 | actions=self._actions[indices], 49 | rewards=self._rewards[indices], 50 | terminals=self._terminals[indices], 51 | next_observations=self._next_obs[indices], 52 | ) 53 | 54 | def num_steps_can_sample(self): 55 | return self._size 56 | -------------------------------------------------------------------------------- /rlkit/envs/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.mujoco_env import MujocoEnv 4 | 5 | 6 | class AntEnv(MujocoEnv): 7 | def __init__(self, use_low_gear_ratio=True): 8 | self.init_serialization(locals()) 9 | if use_low_gear_ratio: 10 | xml_path = 'low_gear_ratio_ant.xml' 11 | else: 12 | xml_path = 'normal_gear_ratio_ant.xml' 13 | super().__init__( 14 | xml_path, 15 | frame_skip=5, 16 | automatically_set_obs_and_action_space=True, 17 | ) 18 | 19 | def step(self, a): 20 | torso_xyz_before = self.get_body_com("torso") 21 | self.do_simulation(a, self.frame_skip) 22 | torso_xyz_after = self.get_body_com("torso") 23 | torso_velocity = torso_xyz_after - torso_xyz_before 24 | forward_reward = torso_velocity[0]/self.dt 25 | ctrl_cost = .5 * np.square(a).sum() 26 | contact_cost = 0.5 * 1e-3 * np.sum( 27 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 28 | survive_reward = 1.0 29 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 30 | state = self.state_vector() 31 | notdone = np.isfinite(state).all() \ 32 | and state[2] >= 0.2 and state[2] <= 1.0 33 | done = not notdone 34 | ob = self._get_obs() 35 | return ob, reward, done, dict( 36 | reward_forward=forward_reward, 37 | reward_ctrl=-ctrl_cost, 38 | reward_contact=-contact_cost, 39 | reward_survive=survive_reward, 40 | torso_velocity=torso_velocity, 41 | ) 42 | 43 | def _get_obs(self): 44 | return np.concatenate([ 45 | self.sim.data.qpos.flat[2:], 46 | self.sim.data.qvel.flat, 47 | ]) 48 | 49 | def reset_model(self): 50 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 51 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 52 | self.set_state(qpos, qvel) 53 | return self._get_obs() 54 | 55 | def viewer_setup(self): 56 | self.viewer.cam.distance = self.model.stat.extent * 0.5 57 | -------------------------------------------------------------------------------- 
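To make the ring-buffer behaviour of `SimpleReplayBuffer` (in `simple_replay_buffer.py` above) concrete, here is a small self-contained sketch with synthetic transitions; the buffer size and dimensions are arbitrary illustrative values:

```
import numpy as np
from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer

buffer = SimpleReplayBuffer(
    max_replay_buffer_size=5,   # deliberately tiny so the buffer wraps around
    observation_dim=3,
    action_dim=2,
)

for t in range(8):  # more samples than capacity: oldest entries are overwritten
    buffer.add_sample(
        observation=np.random.randn(3),
        action=np.random.randn(2),
        reward=float(t),
        terminal=False,
        next_observation=np.random.randn(3),
    )

print(buffer.num_steps_can_sample())   # 5: capped at max_replay_buffer_size
batch = buffer.random_batch(batch_size=4)
print(batch['observations'].shape)     # (4, 3)
print(batch['rewards'].shape)          # (4, 1)
```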
/rlkit/envs/assets/low_gear_ratio_ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 85 | -------------------------------------------------------------------------------- /rlkit/envs/assets/reacher_7dof.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 84 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from rlkit.core.serializable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | def __init__( 20 | self, 21 | model_path, 22 | frame_skip=1, 23 | model_path_is_local=True, 24 | automatically_set_obs_and_action_space=False, 25 | ): 26 | if model_path_is_local: 27 | model_path = get_asset_xml(model_path) 28 | if automatically_set_obs_and_action_space: 29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 30 | else: 31 | """ 32 | Code below is copy/pasted from MujocoEnv's __init__ function. 33 | """ 34 | if model_path.startswith("/"): 35 | fullpath = model_path 36 | else: 37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 38 | if not path.exists(fullpath): 39 | raise IOError("File %s does not exist" % fullpath) 40 | self.frame_skip = frame_skip 41 | self.model = mujoco_py.MjModel(fullpath) 42 | self.data = self.model.data 43 | self.viewer = None 44 | 45 | self.metadata = { 46 | 'render.modes': ['human', 'rgb_array'], 47 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 48 | } 49 | 50 | self.init_qpos = self.model.data.qpos.ravel().copy() 51 | self.init_qvel = self.model.data.qvel.ravel().copy() 52 | self._seed() 53 | 54 | def init_serialization(self, locals): 55 | Serializable.quick_init(self, locals) 56 | 57 | def log_diagnostics(self, paths): 58 | pass 59 | 60 | 61 | def get_asset_xml(xml_name): 62 | return os.path.join(ENV_ASSET_DIR, xml_name) 63 | -------------------------------------------------------------------------------- /rlkit/envs/multi_env_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrap Algorithm such that we can swap between environments in multi-task setting 3 | """ 4 | import numpy as np 5 | import gym 6 | from rlkit.torch.her.her import HerTwinSAC 7 | 8 | 9 | class MultiEnvWrapperHerTwinSAC(HerTwinSAC): 10 | def __init__(self, env_names, her_kwargs, tsac_kwargs, *args, env_probabilities=None, **kwargs): 11 | """ 12 | :param env_names: List of environment names 13 | """ 14 | HerTwinSAC.__init__(self, 15 | *args, 16 | her_kwargs=her_kwargs, 17 | tsac_kwargs=tsac_kwargs, 18 | **kwargs) 19 | self.env_names = env_names 20 | self.env_probabilities = env_probabilities 21 | assert [gym.make(self.env_names[i]).action_space == gym.make(self.env_names[0]).action_space for i in range(len(self.env_names))] 22 | 23 | def get_new_env(self): 24 | env_idx = np.random.choice(np.arange(len(self.env_names)), p=self.env_probabilities if self.env_probabilities else None) 25 | return gym.make(self.env_names[env_idx]), self.env_names[env_idx] 26 | 27 | def _handle_rollout_ending(self): 28 | 
super()._handle_rollout_ending() 29 | self.training_env, env_name = self.get_new_env() 30 | print(f"Loaded {env_name}") 31 | # self.training_env = pickle.loads(pickle.dumps(self.env)) 32 | self.replay_buffer.env = self.training_env -------------------------------------------------------------------------------- /rlkit/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import Env 3 | from gym.spaces import Box 4 | 5 | from rlkit.core.serializable import Serializable 6 | 7 | 8 | class ProxyEnv(Serializable, Env): 9 | def __init__(self, wrapped_env): 10 | Serializable.quick_init(self, locals()) 11 | self._wrapped_env = wrapped_env 12 | self.action_space = self._wrapped_env.action_space 13 | self.observation_space = self._wrapped_env.observation_space 14 | 15 | @property 16 | def wrapped_env(self): 17 | return self._wrapped_env 18 | 19 | def reset(self, **kwargs): 20 | return self._wrapped_env.reset(**kwargs) 21 | 22 | def step(self, action): 23 | return self._wrapped_env.step(action) 24 | 25 | def render(self, *args, **kwargs): 26 | return self._wrapped_env.render(*args, **kwargs) 27 | 28 | def log_diagnostics(self, paths, *args, **kwargs): 29 | if hasattr(self._wrapped_env, 'log_diagnostics'): 30 | self._wrapped_env.log_diagnostics(paths, *args, **kwargs) 31 | 32 | @property 33 | def horizon(self): 34 | return self._wrapped_env.horizon 35 | 36 | def terminate(self): 37 | if hasattr(self.wrapped_env, "terminate"): 38 | self.wrapped_env.terminate() 39 | 40 | 41 | class NormalizedBoxEnv(ProxyEnv, Serializable): 42 | """ 43 | Normalize action to in [-1, 1]. 44 | 45 | Optionally normalize observations and scale reward. 46 | """ 47 | def __init__( 48 | self, 49 | env, 50 | reward_scale=1., 51 | obs_mean=None, 52 | obs_std=None, 53 | ): 54 | # self._wrapped_env needs to be called first because 55 | # Serializable.quick_init calls getattr, on this class. And the 56 | # implementation of getattr (see below) calls self._wrapped_env. 57 | # Without setting this first, the call to self._wrapped_env would call 58 | # getattr again (since it's not set yet) and therefore loop forever. 59 | self._wrapped_env = env 60 | # Or else serialization gets delegated to the wrapped_env. Serialize 61 | # this env separately from the wrapped_env. 62 | self._serializable_initialized = False 63 | Serializable.quick_init(self, locals()) 64 | ProxyEnv.__init__(self, env) 65 | self._should_normalize = not (obs_mean is None and obs_std is None) 66 | if self._should_normalize: 67 | if obs_mean is None: 68 | obs_mean = np.zeros_like(env.observation_space.low) 69 | else: 70 | obs_mean = np.array(obs_mean) 71 | if obs_std is None: 72 | obs_std = np.ones_like(env.observation_space.low) 73 | else: 74 | obs_std = np.array(obs_std) 75 | self._reward_scale = reward_scale 76 | self._obs_mean = obs_mean 77 | self._obs_std = obs_std 78 | ub = np.ones(self._wrapped_env.action_space.shape) 79 | self.action_space = Box(-1 * ub, ub) 80 | 81 | def estimate_obs_stats(self, obs_batch, override_values=False): 82 | if self._obs_mean is not None and not override_values: 83 | raise Exception("Observation mean and std already set. 
To " 84 | "override, set override_values to True.") 85 | self._obs_mean = np.mean(obs_batch, axis=0) 86 | self._obs_std = np.std(obs_batch, axis=0) 87 | 88 | def _apply_normalize_obs(self, obs): 89 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 90 | 91 | def __getstate__(self): 92 | d = Serializable.__getstate__(self) 93 | # Add these explicitly in case they were modified 94 | d["_obs_mean"] = self._obs_mean 95 | d["_obs_std"] = self._obs_std 96 | d["_reward_scale"] = self._reward_scale 97 | return d 98 | 99 | def __setstate__(self, d): 100 | Serializable.__setstate__(self, d) 101 | self._obs_mean = d["_obs_mean"] 102 | self._obs_std = d["_obs_std"] 103 | self._reward_scale = d["_reward_scale"] 104 | 105 | def step(self, action): 106 | lb = self._wrapped_env.action_space.low 107 | ub = self._wrapped_env.action_space.high 108 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb) 109 | scaled_action = np.clip(scaled_action, lb, ub) 110 | 111 | wrapped_step = self._wrapped_env.step(scaled_action) 112 | next_obs, reward, done, info = wrapped_step 113 | if self._should_normalize: 114 | next_obs = self._apply_normalize_obs(next_obs) 115 | return next_obs, reward * self._reward_scale, done, info 116 | 117 | def __str__(self): 118 | return "Normalized: %s" % self._wrapped_env 119 | 120 | def log_diagnostics(self, paths, **kwargs): 121 | if hasattr(self._wrapped_env, "log_diagnostics"): 122 | return self._wrapped_env.log_diagnostics(paths, **kwargs) 123 | else: 124 | return None 125 | 126 | def __getattr__(self, attrname): 127 | return getattr(self._wrapped_env, attrname) 128 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/exploration_strategies/__init__.py -------------------------------------------------------------------------------- /rlkit/exploration_strategies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rlkit.policies.base import ExplorationPolicy, SerializablePolicy 4 | 5 | 6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta): 7 | @abc.abstractmethod 8 | def get_action(self, t, observation, policy, **kwargs): 9 | pass 10 | 11 | @abc.abstractmethod 12 | def get_actions(self, t, observation, policy, **kwargs): 13 | pass 14 | 15 | def reset(self): 16 | pass 17 | 18 | 19 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def get_action_from_raw_action(self, action, **kwargs): 22 | pass 23 | 24 | def get_action(self, t, policy, *args, **kwargs): 25 | action, agent_info = policy.get_action(*args, **kwargs) 26 | return self.get_action_from_raw_action(action, t=t), agent_info 27 | 28 | def get_actions(self, t, observation, policy, **kwargs): 29 | actions = policy.get_actions(observation) 30 | return self.get_actions_from_raw_actions(actions, t=t, **kwargs) 31 | 32 | def reset(self): 33 | pass 34 | 35 | 36 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy, SerializablePolicy): 37 | def __init__( 38 | self, 39 | exploration_strategy: ExplorationStrategy, 40 | policy: SerializablePolicy, 41 | ): 42 | self.es = exploration_strategy 43 | self.policy = policy 44 | self.t = 0 45 | 46 | def set_num_steps_total(self, t): 47 | self.t = t 48 | 49 | def get_action(self, *args, **kwargs): 50 
| return self.es.get_action(self.t, self.policy, *args, **kwargs) 51 | 52 | def get_actions(self, *args, **kwargs): 53 | return self.es.get_actions(self.t, self.policy, *args, **kwargs) 54 | 55 | def reset(self): 56 | self.es.reset() 57 | self.policy.reset() 58 | 59 | def get_param_values(self): 60 | return self.policy.get_param_values() 61 | 62 | def set_param_values(self, param_values): 63 | self.policy.set_param_values(param_values) 64 | 65 | def get_param_values_np(self): 66 | return self.policy.get_param_values_np() 67 | 68 | def set_param_values_np(self, param_values): 69 | self.policy.set_param_values_np(param_values) 70 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from gym.spaces import Discrete 4 | 5 | from rlkit.exploration_strategies.base import RawExplorationStrategy 6 | from rlkit.core.serializable import Serializable 7 | 8 | 9 | class EpsilonGreedy(RawExplorationStrategy, Serializable): 10 | """ 11 | Take a random discrete action with some probability. 12 | """ 13 | def __init__(self, action_space, prob_random_action=0.1): 14 | Serializable.quick_init(self, locals()) 15 | # assert isinstance(action_space, Discrete) 16 | Serializable.quick_init(self, locals()) 17 | self.prob_random_action = prob_random_action 18 | self.action_space = action_space 19 | 20 | def get_action_from_raw_action(self, action, **kwargs): 21 | if random.random() <= self.prob_random_action: 22 | return self.action_space.sample() 23 | return action 24 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_and_epsilon_strategy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from rlkit.exploration_strategies.base import RawExplorationStrategy 3 | from rlkit.core.serializable import Serializable 4 | import numpy as np 5 | 6 | 7 | class GaussianAndEpislonStrategy(RawExplorationStrategy, Serializable): 8 | """ 9 | With probability epsilon, take a completely random action. 10 | with probability 1-epsilon, add Gaussian noise to the action taken by a 11 | deterministic policy. 
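The wrapper pattern in `exploration_strategies/base.py` above is easiest to see end-to-end. The sketch below pairs `PolicyWrappedWithExplorationStrategy` with `EpsilonGreedy`; `ConstantPolicy` is a made-up stand-in for a learned policy, not part of the codebase:

```
import numpy as np
from gym.spaces import Discrete

from rlkit.exploration_strategies.base import (
    PolicyWrappedWithExplorationStrategy,
)
from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
from rlkit.policies.base import SerializablePolicy


class ConstantPolicy(SerializablePolicy):
    """Toy stand-in for a learned policy: always proposes action 0."""

    def get_action(self, obs):
        return 0, {}


es = EpsilonGreedy(action_space=Discrete(4), prob_random_action=0.25)
exploration_policy = PolicyWrappedWithExplorationStrategy(
    exploration_strategy=es,
    policy=ConstantPolicy(),
)

obs = np.zeros(3)  # the observation is ignored by the toy policy
actions = [exploration_policy.get_action(obs)[0] for _ in range(20)]
# Roughly 75% of the entries are the greedy action 0; the rest are sampled
# uniformly from the Discrete(4) action space.
print(actions)
```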
12 | """ 13 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None, 14 | decay_period=1000000): 15 | assert len(action_space.shape) == 1 16 | Serializable.quick_init(self, locals()) 17 | if min_sigma is None: 18 | min_sigma = max_sigma 19 | self._max_sigma = max_sigma 20 | self._epsilon = epsilon 21 | self._min_sigma = min_sigma 22 | self._decay_period = decay_period 23 | self._action_space = action_space 24 | 25 | def get_action_from_raw_action(self, action, t=None, **kwargs): 26 | if random.random() < self._epsilon: 27 | return self._action_space.sample() 28 | else: 29 | sigma = ( 30 | self._max_sigma - (self._max_sigma - self._min_sigma) 31 | * min(1.0, t * 1.0 / self._decay_period) 32 | ) 33 | return np.clip( 34 | action + np.random.normal(size=len(action)) * sigma, 35 | self._action_space.low, 36 | self._action_space.high, 37 | ) -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_strategy.py: -------------------------------------------------------------------------------- 1 | from rlkit.exploration_strategies.base import RawExplorationStrategy 2 | from rlkit.core.serializable import Serializable 3 | import numpy as np 4 | 5 | 6 | class GaussianStrategy(RawExplorationStrategy, Serializable): 7 | """ 8 | This strategy adds Gaussian noise to the action taken by the deterministic policy. 9 | 10 | Based on the rllab implementation. 11 | """ 12 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None, 13 | decay_period=1000000): 14 | assert len(action_space.shape) == 1 15 | Serializable.quick_init(self, locals()) 16 | self._max_sigma = max_sigma 17 | if min_sigma is None: 18 | min_sigma = max_sigma 19 | self._min_sigma = min_sigma 20 | self._decay_period = decay_period 21 | self._action_space = action_space 22 | 23 | def get_action_from_raw_action(self, action, t=None, **kwargs): 24 | sigma = ( 25 | self._max_sigma - (self._max_sigma - self._min_sigma) * 26 | min(1.0, t * 1.0 / self._decay_period) 27 | ) 28 | return np.clip( 29 | action + np.random.normal(size=len(action)) * sigma, 30 | self._action_space.low, 31 | self._action_space.high, 32 | ) 33 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/ou_strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as nr 3 | 4 | from rlkit.exploration_strategies.base import RawExplorationStrategy 5 | from rlkit.core.serializable import Serializable 6 | 7 | 8 | class OUStrategy(RawExplorationStrategy, Serializable): 9 | """ 10 | This strategy implements the Ornstein-Uhlenbeck process, which adds 11 | time-correlated noise to the actions taken by the deterministic policy. 12 | The OU process satisfies the following stochastic differential equation: 13 | dxt = theta*(mu - xt)*dt + sigma*dWt 14 | where Wt denotes the Wiener process 15 | 16 | Based on the rllab implementation. 
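For `GaussianStrategy` (in `gaussian_strategy.py` above), the noise scale anneals linearly from `max_sigma` down to `min_sigma` over `decay_period` steps of the counter `t`. A minimal sketch with arbitrary illustrative hyperparameters:

```
import numpy as np
from gym.spaces import Box

from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy

action_space = Box(-1.0 * np.ones(2), np.ones(2))
es = GaussianStrategy(
    action_space=action_space,
    max_sigma=0.3,
    min_sigma=0.05,
    decay_period=1000,
)

policy_action = np.zeros(2)
# Early in training (t=0) the noise scale is max_sigma; once t >= decay_period
# it has annealed linearly down to min_sigma.
noisy_early = es.get_action_from_raw_action(policy_action, t=0)
noisy_late = es.get_action_from_raw_action(policy_action, t=1000)
print(noisy_early, noisy_late)  # both are clipped to the action-space bounds
```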
17 | """ 18 | 19 | def __init__( 20 | self, 21 | action_space, 22 | mu=0, 23 | theta=0.15, 24 | max_sigma=0.3, 25 | min_sigma=None, 26 | decay_period=100000, 27 | ): 28 | Serializable.quick_init(self, locals()) 29 | if min_sigma is None: 30 | min_sigma = max_sigma 31 | self.mu = mu 32 | self.theta = theta 33 | self.sigma = max_sigma 34 | self._max_sigma = max_sigma 35 | if min_sigma is None: 36 | min_sigma = max_sigma 37 | self._min_sigma = min_sigma 38 | self._decay_period = decay_period 39 | self.dim = np.prod(action_space.low.shape) 40 | self.low = action_space.low 41 | self.high = action_space.high 42 | self.state = np.ones(self.dim) * self.mu 43 | self.reset() 44 | 45 | def reset(self): 46 | self.state = np.ones(self.dim) * self.mu 47 | 48 | def evolve_state(self): 49 | x = self.state 50 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x)) 51 | self.state = x + dx 52 | return self.state 53 | 54 | def get_action_from_raw_action(self, action, t=0, **kwargs): 55 | ou_state = self.evolve_state() 56 | self.sigma = ( 57 | self._max_sigma 58 | - (self._max_sigma - self._min_sigma) 59 | * min(1.0, t * 1.0 / self._decay_period) 60 | ) 61 | return np.clip(action + ou_state, self.low, self.high) 62 | -------------------------------------------------------------------------------- /rlkit/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains 'launchers', which are self-contained functions that take 3 | one dictionary and run a full experiment. The dictionary configures the 4 | experiment. 5 | 6 | It is important that the functions are completely self-contained (i.e. they 7 | import their own modules) so that they can be serialized. 8 | """ 9 | -------------------------------------------------------------------------------- /rlkit/launchers/config_template.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy this file to config.py and modify as needed. 3 | """ 4 | import os 5 | from os.path import join 6 | import rlkit 7 | 8 | def get_infra_settings(mode, instance_type): 9 | """ 10 | Fill this out with settings for each EC2 instance type. 11 | :param mode: 12 | :param instance_type: 13 | :return: 14 | """ 15 | if mode == "ec2" and instance_type == "c5.18xlarge": 16 | return dict( 17 | num_gpus=0, 18 | num_parallel_processes=16, 19 | gpu_mode=False 20 | ) 21 | 22 | """ 23 | `doodad.mount.MountLocal` by default ignores directories called "data" 24 | If you're going to rename this directory and use EC2, then change 25 | `doodad.mount.MountLocal.filter_dir` 26 | """ 27 | # The directory of the project, not source 28 | rlkit_project_dir = join(os.path.dirname(rlkit.__file__), os.pardir) 29 | LOCAL_LOG_DIR = join(rlkit_project_dir, 'data') 30 | 31 | """ 32 | ******************************************************************************** 33 | ******************************************************************************** 34 | ******************************************************************************** 35 | 36 | You probably don't need to set all of the configurations below this line, 37 | unless you use AWS, GCP, Slurm, and/or Slurm on a remote server. I recommend 38 | ignoring most of these things and only using them on an as-needed basis. 
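Returning to `ou_strategy.py` above: unlike independent Gaussian noise, the Ornstein-Uhlenbeck process produces temporally correlated noise, which is visible directly from `evolve_state()`. A minimal sketch with illustrative hyperparameters:

```
import numpy as np
from gym.spaces import Box

from rlkit.exploration_strategies.ou_strategy import OUStrategy

action_space = Box(-1.0 * np.ones(2), np.ones(2))
ou = OUStrategy(action_space=action_space, theta=0.15, max_sigma=0.3)

# Each call to evolve_state() takes one Euler step of
#   dx = theta * (mu - x) + sigma * dW,
# so successive noise samples drift smoothly instead of being independent.
ou.reset()
noise_trace = np.stack([ou.evolve_state() for _ in range(5)])
print(noise_trace)  # a correlated random walk around mu, per action dimension

# In practice the strategy is applied through get_action_from_raw_action:
noisy_action = ou.get_action_from_raw_action(np.zeros(2), t=0)
print(noisy_action)
```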
39 | 40 | ******************************************************************************** 41 | ******************************************************************************** 42 | ******************************************************************************** 43 | """ 44 | 45 | """ 46 | General doodad settings. 47 | """ 48 | CODE_DIRS_TO_MOUNT = [ 49 | rlkit_project_dir, 50 | # '/home/user/python/module/one', Add more paths as needed 51 | ] 52 | DIR_AND_MOUNT_POINT_MAPPINGS = [ 53 | dict( 54 | local_dir=join(os.getenv('HOME'), '.mujoco/'), 55 | mount_point='/root/.mujoco', 56 | ), 57 | ] 58 | RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 59 | join(rlkit_project_dir, 'scripts', 'run_experiment_from_doodad.py') 60 | # '/home/user/path/to/rlkit/scripts/run_experiment_from_doodad.py' 61 | ) 62 | """ 63 | AWS Settings 64 | """ 65 | # If not set, default will be chosen by doodad 66 | # AWS_S3_PATH = 's3://bucket/directory 67 | 68 | # The docker image is looked up on dockerhub.com. 69 | DOODAD_DOCKER_IMAGE = "TODO" 70 | INSTANCE_TYPE = 'c4.large' 71 | SPOT_PRICE = 0.03 72 | 73 | GPU_DOODAD_DOCKER_IMAGE = 'TODO' 74 | GPU_INSTANCE_TYPE = 'g2.2xlarge' 75 | GPU_SPOT_PRICE = 0.5 76 | 77 | # You can use AMI images with the docker images already installed. 78 | REGION_TO_GPU_AWS_IMAGE_ID = { 79 | 'us-west-1': "TODO", 80 | 'us-east-1': "TODO", 81 | } 82 | 83 | REGION_TO_GPU_AWS_AVAIL_ZONE = { 84 | 'us-east-1': "us-east-1b", 85 | } 86 | 87 | # This really shouldn't matter and in theory could be whatever 88 | OUTPUT_DIR_FOR_DOODAD_TARGET = '/tmp/doodad-output/' 89 | 90 | 91 | """ 92 | Slurm Settings 93 | """ 94 | SINGULARITY_IMAGE = '/home/PATH/TO/IMAGE.img' 95 | # This assumes you saved mujoco to $HOME/.mujoco 96 | SINGULARITY_PRE_CMDS = [ 97 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mjpro150/bin' 98 | ] 99 | SLURM_CPU_CONFIG = dict( 100 | account_name='TODO', 101 | partition='savio', 102 | nodes=1, 103 | n_tasks=1, 104 | n_gpus=1, 105 | ) 106 | SLURM_GPU_CONFIG = dict( 107 | account_name='TODO', 108 | partition='savio2_1080ti', 109 | nodes=1, 110 | n_tasks=1, 111 | n_gpus=1, 112 | ) 113 | 114 | 115 | """ 116 | Slurm Script Settings 117 | 118 | These are basically the same settings as above, but for the remote machine 119 | where you will be running the generated script. 
120 | """ 121 | SSS_CODE_DIRS_TO_MOUNT = [ 122 | ] 123 | SSS_DIR_AND_MOUNT_POINT_MAPPINGS = [ 124 | dict( 125 | local_dir='/global/home/users/USERNAME/.mujoco', 126 | mount_point='/root/.mujoco', 127 | ), 128 | ] 129 | SSS_LOG_DIR = '/global/scratch/USERNAME/doodad-log' 130 | 131 | SSS_IMAGE = '/global/scratch/USERNAME/TODO.img' 132 | SSS_RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 133 | '/global/home/users/USERNAME/path/to/rlkit/scripts' 134 | '/run_experiment_from_doodad.py' 135 | ) 136 | SSS_PRE_CMDS = [ 137 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/global/home/users/USERNAME' 138 | '/.mujoco/mjpro150/bin' 139 | ] 140 | 141 | 142 | 143 | """ 144 | GCP Settings 145 | """ 146 | GCP_IMAGE_NAME = 'TODO' 147 | GCP_GPU_IMAGE_NAME = 'TODO' 148 | GCP_BUCKET_NAME = 'TODO' 149 | 150 | GCP_DEFAULT_KWARGS = dict( 151 | zone='us-west2-c', 152 | instance_type='n1-standard-4', 153 | image_project='TODO', 154 | terminate=True, 155 | preemptible=True, 156 | gpu_kwargs=dict( 157 | gpu_model='nvidia-tesla-p4', 158 | num_gpu=1, 159 | ) 160 | ) 161 | -------------------------------------------------------------------------------- /rlkit/launchers/state_based_goal_experiments.py: -------------------------------------------------------------------------------- 1 | import gym 2 | # Trigger environment registrations 3 | # noinspection PyUnresolvedReferences 4 | import multiworld.envs.mujoco 5 | # noinspection PyUnresolvedReferences 6 | import multiworld.envs.pygame 7 | import rlkit.samplers.rollout_functions as rf 8 | import rlkit.torch.pytorch_util as ptu 9 | from rlkit.exploration_strategies.base import ( 10 | PolicyWrappedWithExplorationStrategy 11 | ) 12 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy 13 | from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy 14 | from rlkit.exploration_strategies.ou_strategy import OUStrategy 15 | from rlkit.launchers.rig_experiments import get_video_save_func 16 | from rlkit.torch.her.her import HerTd3 17 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 18 | from rlkit.data_management.obs_dict_replay_buffer import ( 19 | ObsDictRelabelingBuffer 20 | ) 21 | 22 | 23 | def her_td3_experiment(variant): 24 | if 'env_id' in variant: 25 | env = gym.make(variant['env_id']) 26 | else: 27 | env = variant['env_class'](**variant['env_kwargs']) 28 | 29 | observation_key = variant['observation_key'] 30 | desired_goal_key = variant['desired_goal_key'] 31 | variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key 32 | variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key 33 | if variant.get('normalize', False): 34 | raise NotImplementedError() 35 | 36 | achieved_goal_key = desired_goal_key.replace("desired", "achieved") 37 | replay_buffer = ObsDictRelabelingBuffer( 38 | env=env, 39 | observation_key=observation_key, 40 | desired_goal_key=desired_goal_key, 41 | achieved_goal_key=achieved_goal_key, 42 | **variant['replay_buffer_kwargs'] 43 | ) 44 | obs_dim = env.observation_space.spaces['observation'].low.size 45 | action_dim = env.action_space.low.size 46 | goal_dim = env.observation_space.spaces['desired_goal'].low.size 47 | exploration_type = variant['exploration_type'] 48 | if exploration_type == 'ou': 49 | es = OUStrategy( 50 | action_space=env.action_space, 51 | **variant['es_kwargs'] 52 | ) 53 | elif exploration_type == 'gaussian': 54 | es = GaussianStrategy( 55 | action_space=env.action_space, 56 | **variant['es_kwargs'], 57 | ) 58 | elif exploration_type == 'epsilon': 59 | es = 
EpsilonGreedy( 60 | action_space=env.action_space, 61 | **variant['es_kwargs'], 62 | ) 63 | else: 64 | raise Exception("Invalid type: " + exploration_type) 65 | qf1 = FlattenMlp( 66 | input_size=obs_dim + action_dim + goal_dim, 67 | output_size=1, 68 | **variant['qf_kwargs'] 69 | ) 70 | qf2 = FlattenMlp( 71 | input_size=obs_dim + action_dim + goal_dim, 72 | output_size=1, 73 | **variant['qf_kwargs'] 74 | ) 75 | policy = TanhMlpPolicy( 76 | input_size=obs_dim + goal_dim, 77 | output_size=action_dim, 78 | **variant['policy_kwargs'] 79 | ) 80 | exploration_policy = PolicyWrappedWithExplorationStrategy( 81 | exploration_strategy=es, 82 | policy=policy, 83 | ) 84 | algorithm = HerTd3( 85 | env, 86 | qf1=qf1, 87 | qf2=qf2, 88 | policy=policy, 89 | exploration_policy=exploration_policy, 90 | replay_buffer=replay_buffer, 91 | **variant['algo_kwargs'] 92 | ) 93 | if variant.get("save_video", False): 94 | rollout_function = rf.create_rollout_function( 95 | rf.multitask_rollout, 96 | max_path_length=algorithm.max_path_length, 97 | observation_key=algorithm.observation_key, 98 | desired_goal_key=algorithm.desired_goal_key, 99 | ) 100 | video_func = get_video_save_func( 101 | rollout_function, 102 | env, 103 | policy, 104 | variant, 105 | ) 106 | algorithm.post_epoch_funcs.append(video_func) 107 | algorithm.to(ptu.device) 108 | algorithm.train() 109 | -------------------------------------------------------------------------------- /rlkit/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/policies/__init__.py -------------------------------------------------------------------------------- /rlkit/policies/argmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch argmax policy 3 | """ 4 | import numpy as np 5 | import rlkit.torch.pytorch_util as ptu 6 | from rlkit.policies.base import SerializablePolicy 7 | from rlkit.torch.core import PyTorchModule 8 | 9 | 10 | class ArgmaxDiscretePolicy(PyTorchModule, SerializablePolicy): 11 | def __init__(self, qf): 12 | self.save_init_params(locals()) 13 | super().__init__() 14 | self.qf = qf 15 | 16 | def get_action(self, obs): 17 | obs = np.expand_dims(obs, axis=0) 18 | obs = ptu.from_numpy(obs).float() 19 | q_values = self.qf(obs).squeeze(0) 20 | q_values_np = ptu.get_numpy(q_values) 21 | return q_values_np.argmax(), {} 22 | -------------------------------------------------------------------------------- /rlkit/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(object, metaclass=abc.ABCMeta): 5 | """ 6 | General policy interface. 7 | """ 8 | @abc.abstractmethod 9 | def get_action(self, observation): 10 | """ 11 | 12 | :param observation: 13 | :return: action, debug_dictionary 14 | """ 15 | pass 16 | 17 | def reset(self): 18 | pass 19 | 20 | 21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 22 | def set_num_steps_total(self, t): 23 | pass 24 | 25 | 26 | class SerializablePolicy(Policy, metaclass=abc.ABCMeta): 27 | """ 28 | Policy that can be serialized. 29 | """ 30 | def get_param_values(self): 31 | return None 32 | 33 | def set_param_values(self, values): 34 | pass 35 | 36 | """ 37 | Parameters should be passed as np arrays in the two functions below. 
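The `her_td3_experiment` launcher in `state_based_goal_experiments.py` above is driven entirely by a `variant` dictionary. The sketch below lists the keys the launcher itself reads; the nested replay-buffer, network, and algorithm kwargs are hedged assumptions about `ObsDictRelabelingBuffer`, `FlattenMlp`, `TanhMlpPolicy`, and `HerTd3`, whose constructors are not reproduced in this excerpt:

```
# Hypothetical configuration; only the top-level keys are guaranteed by the
# launcher code above. Nested kwargs marked "assumed" are illustrative guesses.
variant = dict(
    env_id='FetchReach-v1',
    observation_key='observation',
    desired_goal_key='desired_goal',
    exploration_type='gaussian',
    es_kwargs=dict(max_sigma=0.3),                # passed to GaussianStrategy
    qf_kwargs=dict(hidden_sizes=[256, 256]),      # assumed FlattenMlp kwarg
    policy_kwargs=dict(hidden_sizes=[256, 256]),  # assumed TanhMlpPolicy kwarg
    replay_buffer_kwargs=dict(
        max_size=int(1e6),                        # assumed buffer kwargs
        fraction_goals_rollout_goals=0.2,
        fraction_goals_env_goals=0.0,
    ),
    algo_kwargs=dict(
        # observation_key / desired_goal_key are filled in by the launcher.
        her_kwargs=dict(),
        # Any base/TD3 kwargs expected by HerTd3 would also go in here.
    ),
    save_video=False,
)

# her_td3_experiment(variant)  # see examples/her/her_td3_gym_fetch_reach.py
```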
38 | """ 39 | def get_param_values_np(self): 40 | return None 41 | 42 | def set_param_values_np(self, values): 43 | pass 44 | -------------------------------------------------------------------------------- /rlkit/policies/simple.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.policies.base import SerializablePolicy 4 | 5 | 6 | class RandomPolicy(SerializablePolicy): 7 | """ 8 | Policy that always outputs zero. 9 | """ 10 | 11 | def __init__(self, action_space): 12 | self.action_space = action_space 13 | 14 | def get_action(self, obs): 15 | return self.action_space.sample(), {} 16 | -------------------------------------------------------------------------------- /rlkit/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/samplers/__init__.py -------------------------------------------------------------------------------- /rlkit/samplers/in_place.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.util import rollout 2 | from rlkit.samplers.rollout_functions import multitask_rollout 3 | import numpy as np 4 | 5 | 6 | class InPlacePathSampler(object): 7 | """ 8 | A sampler that does not serialization for sampling. Instead, it just uses 9 | the current policy and environment as-is. 10 | 11 | WARNING: This will affect the environment! So 12 | ``` 13 | sampler = InPlacePathSampler(env, ...) 14 | sampler.obtain_samples # this has side-effects: env will change! 15 | ``` 16 | """ 17 | def __init__(self, env, policy, max_samples, max_path_length, randomize_env=False, alg=None): 18 | self.env = env 19 | self.policy = policy 20 | self.max_path_length = max_path_length 21 | self.max_samples = max_samples 22 | assert max_samples >= max_path_length, "Need max_samples >= max_path_length" 23 | self.randomize_env = randomize_env 24 | self.alg = alg 25 | 26 | def start_worker(self): 27 | pass 28 | 29 | def shutdown_worker(self): 30 | pass 31 | 32 | def obtain_samples(self, rollout_type="multitask"): 33 | paths = [] 34 | n_steps_total = 0 35 | while n_steps_total + self.max_path_length <= self.max_samples: 36 | if self.randomize_env: 37 | self.env, env_name = self.alg.get_new_env() 38 | print(f"Evaluating {env_name}") 39 | if rollout_type == "multitask": 40 | path = multitask_rollout( 41 | self.env, 42 | self.policy, 43 | max_path_length=self.max_path_length, 44 | animated=False, 45 | observation_key='observation', 46 | desired_goal_key='desired_goal', 47 | get_action_kwargs=dict( 48 | return_stacked_softmax=False, 49 | mask=np.ones((1, self.env.unwrapped.num_blocks)), 50 | deterministic=True 51 | ) 52 | ) 53 | else: 54 | path = rollout( 55 | self.env, self.policy, max_path_length=self.max_path_length 56 | ) 57 | paths.append(path) 58 | n_steps_total += len(path['observations']) 59 | return paths 60 | -------------------------------------------------------------------------------- /rlkit/samplers/rollout_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.wrappers.monitor import Monitor 3 | 4 | 5 | def create_rollout_function(rollout_function, **initial_kwargs): 6 | """ 7 | initial_kwargs for 8 | rollout_function=tdm_rollout_function may contain: 9 | init_tau, 10 | decrement_tau, 11 | cycle_tau, 12 | get_action_kwargs, 13 | observation_key, 
14 | desired_goal_key, 15 | rollout_function=multitask_rollout may contain: 16 | observation_key, 17 | desired_goal_key, 18 | """ 19 | def wrapped_rollout_func(*args, **dynamic_kwargs): 20 | combined_args = { 21 | **initial_kwargs, 22 | **dynamic_kwargs 23 | } 24 | return rollout_function(*args, **combined_args) 25 | return wrapped_rollout_func 26 | 27 | 28 | def multitask_rollout( 29 | env, 30 | agent, 31 | max_path_length=np.inf, 32 | animated=False, 33 | observation_key='observation', 34 | desired_goal_key='desired_goal', 35 | get_action_kwargs=None, 36 | max_num_blocks=None, 37 | cur_num_blocks=None, 38 | goal=None, 39 | reset_observation=None, 40 | pad_obs=False 41 | ): 42 | if get_action_kwargs is None: 43 | get_action_kwargs = {} 44 | full_observations = [] 45 | observations = [] 46 | actions = [] 47 | rewards = [] 48 | terminals = [] 49 | agent_infos = [] 50 | env_infos = [] 51 | next_observations = [] 52 | path_length = 0 53 | agent.reset() 54 | o = env.reset() 55 | 56 | if goal is not None: 57 | assert reset_observation is not None 58 | env.unwrapped.goal = goal 59 | env.unwrapped.reset_sim_from_obs(reset_observation) 60 | o = env.unwrapped._get_obs() 61 | 62 | if animated: 63 | env.render() 64 | goal = o[desired_goal_key] 65 | key_sizes = dict(observation=15, 66 | goal=3) 67 | 68 | if max_num_blocks is not None: 69 | lop_state = goal[-3:].copy() 70 | assert (0 == lop_state).all() 71 | goal = goal[:-3].copy() 72 | goal = np.pad(goal, 73 | ((0, int(max_num_blocks - cur_num_blocks) * key_sizes['goal'])), 74 | "constant", constant_values=((0, -999))) 75 | goal = np.concatenate((goal, lop_state)).copy() 76 | 77 | while path_length < max_path_length: 78 | full_observations.append(o) 79 | observations.append(o) 80 | 81 | o = o[observation_key] 82 | if max_num_blocks is not None: 83 | o = np.pad(o, 84 | ((0, int(max_num_blocks - cur_num_blocks) * key_sizes['observation'])), 85 | "constant", constant_values=((0, -999))) 86 | 87 | obs_goal = np.hstack((o, goal)) 88 | 89 | a, agent_info = agent.get_action(obs_goal, **get_action_kwargs) 90 | 91 | next_o_dic, r, d, env_info = env.step(a) 92 | if animated: 93 | env.render() 94 | rewards.append(r) 95 | 96 | terminals.append(d) 97 | actions.append(a) 98 | next_observations.append(next_o_dic) 99 | agent_infos.append(agent_info) 100 | env_infos.append(env_info) 101 | path_length += 1 102 | o = next_o_dic 103 | if d: 104 | break 105 | full_observations.append(o) 106 | actions = np.array(actions) 107 | if len(actions.shape) == 1: 108 | actions = np.expand_dims(actions, 1) 109 | # observations = np.array(observations) 110 | # next_observations = np.array(next_observations) 111 | 112 | assert len(get_action_kwargs['mask'].shape) == 2 113 | 114 | assert isinstance(observations[0], dict) 115 | return dict( 116 | observations=observations, 117 | actions=actions, 118 | rewards=np.array(rewards).reshape(-1, 1), 119 | next_observations=next_observations, 120 | terminals=np.array(terminals).reshape(-1, 1), 121 | agent_infos=agent_infos, 122 | env_infos=env_infos, 123 | goals=np.repeat(goal[None], path_length, 0), 124 | mask=np.broadcast_to(get_action_kwargs['mask'], (actions.shape[0], get_action_kwargs['mask'].shape[1])), 125 | full_observations=full_observations, 126 | ) 127 | 128 | 129 | def rollout(env, agent, max_path_length=np.inf, animated=False): 130 | """ 131 | The following value for the following keys will be a 2D array, with the 132 | first dimension corresponding to the time dimension. 
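`create_rollout_function` in `rollout_functions.py` above simply closes over fixed kwargs and merges them with per-call ones, which is how the launcher builds its video rollout function. A sketch of that pattern; `env`, `policy`, and `num_blocks` are assumed to exist and to be compatible with `multitask_rollout` (dict observations, and a `get_action` that accepts `mask` and `deterministic`):

```
import numpy as np  # used by the illustrative call below
import rlkit.samplers.rollout_functions as rf

# Freeze the kwargs that stay constant across rollouts.
rollout_function = rf.create_rollout_function(
    rf.multitask_rollout,
    max_path_length=50,
    observation_key='observation',
    desired_goal_key='desired_goal',
)

# Per-call kwargs are merged with the frozen ones at call time, e.g.:
# path = rollout_function(
#     env,
#     policy,
#     get_action_kwargs=dict(
#         mask=np.ones((1, num_blocks)),  # multitask_rollout requires a 2D mask
#         deterministic=True,
#     ),
# )
# path['rewards'].shape == (T, 1); path['observations'] is a list of dict obs.
```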
133 | - observations 134 | - actions 135 | - rewards 136 | - next_observations 137 | - terminals 138 | 139 | The next two elements will be lists of dictionaries, with the index into 140 | the list being the index into the time 141 | - agent_infos 142 | - env_infos 143 | 144 | :param env: 145 | :param agent: 146 | :param max_path_length: 147 | :param animated: 148 | :return: 149 | """ 150 | observations = [] 151 | actions = [] 152 | rewards = [] 153 | terminals = [] 154 | agent_infos = [] 155 | env_infos = [] 156 | o = env.reset() 157 | agent.reset() 158 | next_o = None 159 | path_length = 0 160 | if animated: 161 | env.render() 162 | while path_length < max_path_length: 163 | a, agent_info = agent.get_action(o) 164 | next_o, r, d, env_info = env.step(a) 165 | observations.append(o) 166 | rewards.append(r) 167 | terminals.append(d) 168 | actions.append(a) 169 | agent_infos.append(agent_info) 170 | env_infos.append(env_info) 171 | path_length += 1 172 | if d: 173 | break 174 | o = next_o 175 | if animated: 176 | env.render() 177 | 178 | actions = np.array(actions) 179 | if len(actions.shape) == 1: 180 | actions = np.expand_dims(actions, 1) 181 | observations = np.array(observations) 182 | if len(observations.shape) == 1: 183 | observations = np.expand_dims(observations, 1) 184 | next_o = np.array([next_o]) 185 | next_observations = np.vstack( 186 | ( 187 | observations[1:, :], 188 | np.expand_dims(next_o, 0) 189 | ) 190 | ) 191 | return dict( 192 | observations=observations, 193 | actions=actions, 194 | rewards=np.array(rewards).reshape(-1, 1), 195 | next_observations=next_observations, 196 | terminals=np.array(terminals).reshape(-1, 1), 197 | agent_infos=agent_infos, 198 | env_infos=env_infos, 199 | ) -------------------------------------------------------------------------------- /rlkit/samplers/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rollout(env, agent, max_path_length=np.inf, animated=False): 5 | """ 6 | The following value for the following keys will be a 2D array, with the 7 | first dimension corresponding to the time dimension. 
8 | - observations 9 | - actions 10 | - rewards 11 | - next_observations 12 | - terminals 13 | 14 | The next two elements will be lists of dictionaries, with the index into 15 | the list being the index into the time 16 | - agent_infos 17 | - env_infos 18 | 19 | :param env: 20 | :param agent: 21 | :param max_path_length: 22 | :param animated: 23 | :return: 24 | """ 25 | observations = [] 26 | actions = [] 27 | rewards = [] 28 | terminals = [] 29 | agent_infos = [] 30 | env_infos = [] 31 | o = env.reset() 32 | next_o = None 33 | path_length = 0 34 | if animated: 35 | env.render() 36 | while path_length < max_path_length: 37 | a, agent_info = agent.get_action(o) 38 | next_o, r, d, env_info = env.step(a) 39 | observations.append(o) 40 | rewards.append(r) 41 | terminals.append(d) 42 | actions.append(a) 43 | agent_infos.append(agent_info) 44 | env_infos.append(env_info) 45 | path_length += 1 46 | if d: 47 | break 48 | o = next_o 49 | if animated: 50 | env.render() 51 | 52 | actions = np.array(actions) 53 | if len(actions.shape) == 1: 54 | actions = np.expand_dims(actions, 1) 55 | observations = np.array(observations) 56 | if len(observations.shape) == 1: 57 | observations = np.expand_dims(observations, 1) 58 | next_o = np.array([next_o]) 59 | next_observations = np.vstack( 60 | ( 61 | observations[1:, :], 62 | np.expand_dims(next_o, 0) 63 | ) 64 | ) 65 | return dict( 66 | observations=observations, 67 | actions=actions, 68 | rewards=np.array(rewards).reshape(-1, 1), 69 | next_observations=next_observations, 70 | terminals=np.array(terminals).reshape(-1, 1), 71 | agent_infos=agent_infos, 72 | env_infos=env_infos, 73 | ) 74 | 75 | 76 | def split_paths(paths): 77 | """ 78 | Stack multiples obs/actions/etc. from different paths 79 | :param paths: List of paths, where one path is something returned from 80 | the rollout functino above. 81 | :return: Tuple. Every element will have shape batch_size X DIM, including 82 | the rewards and terminal flags. 
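`split_paths` (documented just above in `samplers/util.py`) stacks the per-path 2D arrays along the batch axis. A self-contained sketch using synthetic paths shaped like the dictionaries returned by `rollout()`:

```
import numpy as np
from rlkit.samplers.util import split_paths


def fake_path(T, obs_dim=3, action_dim=2):
    """Synthetic stand-in for the dict returned by rollout() above."""
    return dict(
        observations=np.random.randn(T, obs_dim),
        actions=np.random.randn(T, action_dim),
        rewards=np.random.randn(T, 1),
        next_observations=np.random.randn(T, obs_dim),
        terminals=np.zeros((T, 1)),
    )


paths = [fake_path(5), fake_path(7)]
rewards, terminals, obs, actions, next_obs = split_paths(paths)
print(rewards.shape)   # (12, 1): both paths stacked along the batch axis
print(obs.shape)       # (12, 3)
print(actions.shape)   # (12, 2)
```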
83 | """ 84 | rewards = [path["rewards"].reshape(-1, 1) for path in paths] 85 | terminals = [path["terminals"].reshape(-1, 1) for path in paths] 86 | actions = [path["actions"] for path in paths] 87 | obs = [path["observations"] for path in paths] 88 | next_obs = [path["next_observations"] for path in paths] 89 | rewards = np.vstack(rewards) 90 | terminals = np.vstack(terminals) 91 | obs = np.vstack(obs) 92 | actions = np.vstack(actions) 93 | next_obs = np.vstack(next_obs) 94 | assert len(rewards.shape) == 2 95 | assert len(terminals.shape) == 2 96 | assert len(obs.shape) == 2 97 | assert len(actions.shape) == 2 98 | assert len(next_obs.shape) == 2 99 | return rewards, terminals, obs, actions, next_obs 100 | 101 | 102 | def split_paths_to_dict(paths): 103 | rewards, terminals, obs, actions, next_obs = split_paths(paths) 104 | return dict( 105 | rewards=rewards, 106 | terminals=terminals, 107 | observations=obs, 108 | actions=actions, 109 | next_observations=next_obs, 110 | ) 111 | 112 | 113 | def get_stat_in_paths(paths, dict_name, scalar_name): 114 | if len(paths) == 0: 115 | return np.array([[]]) 116 | 117 | if type(paths[0][dict_name]) == dict: 118 | # Support rllab interface 119 | return [path[dict_name][scalar_name] for path in paths] 120 | 121 | return [ 122 | [info[scalar_name] for info in path[dict_name]] 123 | for path in paths 124 | ] -------------------------------------------------------------------------------- /rlkit/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/core.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | from collections import OrderedDict 4 | 5 | from torch import nn as nn 6 | from torch.autograd import Variable 7 | 8 | from rlkit.torch import pytorch_util as ptu 9 | from rlkit.core.serializable import Serializable 10 | from functools import reduce 11 | import torch 12 | 13 | 14 | class PyTorchModule(nn.Module, Serializable, metaclass=abc.ABCMeta): 15 | 16 | def get_param_values(self): 17 | return self.state_dict() 18 | 19 | def set_param_values(self, param_values): 20 | # new_param_values = param_values.copy() 21 | # try: 22 | self.load_state_dict(param_values) 23 | # except RuntimeError as e: 24 | # import re 25 | # e_str = re.search("(?<=state_dict: ).*(?=\.)", str(e)).group(0) 26 | # e_str.replace(" ", "") 27 | # x_stripped = [x.strip() for x in e_str.split(",")] 28 | # for x in x_stripped: 29 | # new_param_values[x] = nn.Parameter(torch.tensor(1.0)) 30 | 31 | def get_param_values_np(self): 32 | state_dict = self.state_dict() 33 | np_dict = OrderedDict() 34 | for key, tensor in state_dict.items(): 35 | np_dict[key] = ptu.get_numpy(tensor) 36 | return np_dict 37 | 38 | def set_param_values_np(self, param_values): 39 | torch_dict = OrderedDict() 40 | for key, tensor in param_values.items(): 41 | torch_dict[key] = ptu.from_numpy(tensor) 42 | self.load_state_dict(torch_dict) 43 | 44 | def copy(self): 45 | copy = Serializable.clone(self) 46 | ptu.copy_model_params_from_to(self, copy) 47 | return copy 48 | 49 | def save_init_params(self, locals): 50 | """ 51 | Should call this FIRST THING in the __init__ method if you ever want 52 | to serialize or clone this network. 
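`PyTorchModule.eval_np`, defined a little further down in `core.py`, lets you evaluate a network through a plain numpy interface. The sketch below assumes the usual rlkit `FlattenMlp(hidden_sizes=..., input_size=..., output_size=...)` signature from `rlkit/torch/networks.py` and the standard `ptu.set_gpu_mode` helper, neither of which is reproduced in this excerpt:

```
import numpy as np

import rlkit.torch.pytorch_util as ptu
from rlkit.torch.networks import FlattenMlp

ptu.set_gpu_mode(False)  # run this example on CPU

# A small Q-function over a 5-dim observation and 2-dim action (illustrative).
qf = FlattenMlp(hidden_sizes=[32, 32], input_size=5 + 2, output_size=1)

obs = np.random.randn(4, 5).astype(np.float32)
actions = np.random.randn(4, 2).astype(np.float32)

# eval_np converts numpy inputs to torch tensors, runs forward(), and converts
# the output back to numpy, so no explicit from_numpy/get_numpy calls are needed.
q_values = qf.eval_np(obs, actions)
print(q_values.shape)  # (4, 1)
```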
53 | 54 | Usage: 55 | ``` 56 | def __init__(self, ...): 57 | self.init_serialization(locals()) 58 | ... 59 | ``` 60 | :param locals: 61 | :return: 62 | """ 63 | Serializable.quick_init(self, locals) 64 | 65 | def __getstate__(self): 66 | d = Serializable.__getstate__(self) 67 | d["params"] = self.get_param_values() 68 | return d 69 | 70 | def __setstate__(self, d): 71 | Serializable.__setstate__(self, d) 72 | self.set_param_values(d["params"]) 73 | 74 | def regularizable_parameters(self): 75 | """ 76 | Return generator of regularizable parameters. Right now, all non-flat 77 | vectors are assumed to be regularizabled, presumably because only 78 | biases are flat. 79 | 80 | :return: 81 | """ 82 | for param in self.parameters(): 83 | if len(param.size()) > 1: 84 | yield param 85 | 86 | def eval_np(self, *args, **kwargs): 87 | """ 88 | Eval this module with a numpy interface 89 | 90 | Same as a call to __call__ except all Variable input/outputs are 91 | replaced with numpy equivalents. 92 | 93 | Assumes the output is either a single object or a tuple of objects. 94 | """ 95 | torch_args = tuple(torch_ify(x) for x in args) 96 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 97 | outputs = self.__call__(*torch_args, **torch_kwargs) 98 | if isinstance(outputs, tuple): 99 | return tuple(np_ify(x) for x in outputs) # return tuple(np_ify(x) for x in outputs) 100 | return recursive_np_ify(outputs) 101 | else: 102 | return np_ify(outputs) 103 | 104 | 105 | def torch_ify(np_array_or_other): 106 | if isinstance(np_array_or_other, np.ndarray): 107 | return ptu.from_numpy(np_array_or_other) 108 | else: 109 | return np_array_or_other 110 | 111 | 112 | def np_ify(tensor_or_other): 113 | if isinstance(tensor_or_other, Variable): 114 | return ptu.get_numpy(tensor_or_other) 115 | else: 116 | return tensor_or_other 117 | 118 | 119 | def recursive_np_ify(object_holding_tensor): 120 | if isinstance(object_holding_tensor, torch.Tensor): 121 | return np_ify(object_holding_tensor) 122 | elif isinstance(object_holding_tensor, dict): 123 | return {k: np_ify(v) for k, v in object_holding_tensor.items()} 124 | elif isinstance(object_holding_tensor, tuple): 125 | return tuple([recursive_np_ify(el) for el in object_holding_tensor]) 126 | 127 | 128 | def rgetattr(obj, attr, *args): 129 | """See https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects""" 130 | def _getattr(obj, attr): 131 | if obj is None: 132 | return None 133 | return getattr(obj, attr, *args) 134 | return reduce(_getattr, [obj] + attr.split('.')) 135 | 136 | 137 | def rsetattr(obj, attr, val): 138 | pre, _, post = attr.rpartition('.') 139 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) -------------------------------------------------------------------------------- /rlkit/torch/data_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/data_management/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rlkit.torch.pytorch_util as ptu 3 | import numpy as np 4 | 5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer 6 | from rlkit.torch.relational.relational_util import fetch_preprocessing, invert_fetch_preprocessing 7 | from 
rlkit.torch.core import PyTorchModule 8 | 9 | 10 | class TorchNormalizer(Normalizer): 11 | """ 12 | Update with np array, but de/normalize pytorch Tensors. 13 | """ 14 | def normalize(self, v, clip_range=None): 15 | if not self.synchronized: 16 | self.synchronize() 17 | if clip_range is None: 18 | clip_range = self.default_clip_range 19 | mean = ptu.from_numpy(self.mean) 20 | std = ptu.from_numpy(self.std) 21 | if v.dim() == 2: 22 | # Unsqueeze along the batch use automatic broadcasting 23 | mean = mean.unsqueeze(0) 24 | std = std.unsqueeze(0) 25 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 26 | 27 | def denormalize(self, v): 28 | if not self.synchronized: 29 | self.synchronize() 30 | mean = ptu.from_numpy(self.mean) 31 | std = ptu.from_numpy(self.std) 32 | if v.dim() == 2: 33 | mean = mean.unsqueeze(0) 34 | std = std.unsqueeze(0) 35 | return mean + v * std 36 | 37 | 38 | class TorchFixedNormalizer(FixedNormalizer): 39 | def normalize(self, v, clip_range=None): 40 | if clip_range is None: 41 | clip_range = self.default_clip_range 42 | mean = ptu.from_numpy(self.mean) 43 | std = ptu.from_numpy(self.std) 44 | if v.dim() == 2: 45 | # Unsqueeze along the batch use automatic broadcasting 46 | mean = mean.unsqueeze(0) 47 | std = std.unsqueeze(0) 48 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 49 | 50 | def normalize_scale(self, v): 51 | """ 52 | Only normalize the scale. Do not subtract the mean. 53 | """ 54 | std = ptu.from_numpy(self.std) 55 | if v.dim() == 2: 56 | std = std.unsqueeze(0) 57 | return v / std 58 | 59 | def denormalize(self, v): 60 | mean = ptu.from_numpy(self.mean) 61 | std = ptu.from_numpy(self.std) 62 | if v.dim() == 2: 63 | mean = mean.unsqueeze(0) 64 | std = std.unsqueeze(0) 65 | return mean + v * std 66 | 67 | def denormalize_scale(self, v): 68 | """ 69 | Only denormalize the scale. Do not add the mean. 70 | """ 71 | std = ptu.from_numpy(self.std) 72 | if v.dim() == 2: 73 | std = std.unsqueeze(0) 74 | return v * std 75 | 76 | 77 | class CompositeNormalizer: 78 | """ 79 | Useful for normalizing different data types e.g. when using the same normalizer for the Q function and the policy function 80 | """ 81 | def __init__(self, 82 | obs_dim, 83 | action_dim, 84 | reshape_blocks=False, 85 | fetch_kwargs=dict(), 86 | **kwargs): 87 | # self.save_init_params(locals()) 88 | self.observation_dim = obs_dim 89 | self.action_dim = action_dim 90 | self.obs_normalizer = TorchNormalizer(self.observation_dim, **kwargs) 91 | self.action_normalizer = TorchNormalizer(self.action_dim) 92 | self.reshape_blocks = reshape_blocks 93 | self.kwargs = kwargs 94 | self.fetch_kwargs = fetch_kwargs 95 | 96 | def normalize_all( 97 | self, 98 | flat_obs, 99 | actions): 100 | """ 101 | 102 | :param flat_obs: 103 | :param actions: 104 | :return: 105 | """ 106 | if flat_obs is not None: 107 | flat_obs = self.obs_normalizer.normalize(flat_obs) 108 | if actions is not None: 109 | actions = self.action_normalizer.normalize(actions) 110 | return flat_obs, actions 111 | 112 | def update(self, data_type, v, mask=None): 113 | """ 114 | Takes in tensor and updates numpy array 115 | :param data_type: 116 | :param v: 117 | :return: 118 | """ 119 | if data_type == "obs": 120 | # Reshape_blocks: takes flat, turns batch, normalizes batch, updates the obs_normalizer... 
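# With reshape_blocks, the flat observation is rearranged into one row per object/goal block
# (robot state concatenated with that block's features); rows belonging to padded blocks are
# dropped via the mask, so the running mean/std below are updated per block rather than per flat vector.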
121 | if self.reshape_blocks: 122 | batched_robot_state, batched_objects_and_goals = fetch_preprocessing(v, mask=mask, return_combined_state=False, **self.fetch_kwargs) 123 | N, nB, nR = batched_robot_state.size() 124 | v = torch.cat((batched_robot_state, batched_objects_and_goals), dim=-1).view(N * nB, -1) 125 | if mask is not None: 126 | v = v[mask.view(N * nB).to(dtype=torch.bool)] 127 | 128 | # if self.lop_state_dim: 129 | # v = v.narrow(-1, -3, 3) 130 | self.obs_normalizer.update(ptu.get_numpy(v)) 131 | elif data_type == "actions": 132 | self.action_normalizer.update(ptu.get_numpy(v)) 133 | else: 134 | raise("data_type not set") -------------------------------------------------------------------------------- /rlkit/torch/ddpg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/ddpg/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Distribution, Normal 3 | import rlkit.torch.pytorch_util as ptu 4 | 5 | 6 | class TanhNormal(Distribution): 7 | """ 8 | Represent distribution of X where 9 | X ~ tanh(Z) 10 | Z ~ N(mean, std) 11 | 12 | Note: this is not very numerically stable. 13 | """ 14 | def __init__(self, normal_mean, normal_std, epsilon=1e-6): 15 | """ 16 | :param normal_mean: Mean of the normal distribution 17 | :param normal_std: Std of the normal distribution 18 | :param epsilon: Numerical stability epsilon when computing log-prob. 19 | """ 20 | self.normal_mean = normal_mean 21 | self.normal_std = normal_std 22 | self.normal = Normal(normal_mean, normal_std) 23 | self.epsilon = epsilon 24 | 25 | def sample_n(self, n, return_pre_tanh_value=False): 26 | z = self.normal.sample_n(n) 27 | if return_pre_tanh_value: 28 | return torch.tanh(z), z 29 | else: 30 | return torch.tanh(z) 31 | 32 | def log_prob(self, value, pre_tanh_value=None): 33 | """ 34 | 35 | :param value: some value, x 36 | :param pre_tanh_value: arctanh(x) 37 | :return: 38 | """ 39 | if pre_tanh_value is None: 40 | pre_tanh_value = torch.log( 41 | (1+value) / (1-value) 42 | ) / 2 43 | return self.normal.log_prob(pre_tanh_value) - torch.log( 44 | 1 - value * value + self.epsilon 45 | ) 46 | 47 | def sample(self, return_pretanh_value=False): 48 | """ 49 | Gradients will and should *not* pass through this operation. 50 | 51 | See https://github.com/pytorch/pytorch/issues/4620 for discussion. 52 | """ 53 | z = self.normal.sample().detach() 54 | 55 | if return_pretanh_value: 56 | return torch.tanh(z), z 57 | else: 58 | return torch.tanh(z) 59 | 60 | def rsample(self, return_pretanh_value=False): 61 | """ 62 | Sampling in the reparameterization case. 
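Here z = normal_mean + normal_std * eps, with eps drawn from a unit Normal, so gradients flow back through normal_mean and normal_std.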
63 | """ 64 | z = ( 65 | self.normal_mean + 66 | self.normal_std * 67 | Normal( 68 | ptu.zeros(self.normal_mean.size(), torch_device=ptu.get_device()), 69 | ptu.ones(self.normal_std.size(), torch_device=ptu.get_device()) 70 | ).sample() 71 | ) 72 | z.requires_grad_() 73 | 74 | if return_pretanh_value: 75 | return torch.tanh(z), z 76 | else: 77 | return torch.tanh(z) 78 | 79 | -------------------------------------------------------------------------------- /rlkit/torch/dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/dqn/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/dqn/double_dqn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import rlkit.torch.pytorch_util as ptu 7 | from rlkit.core.eval_util import create_stats_ordered_dict 8 | from rlkit.torch.dqn.dqn import DQN 9 | 10 | 11 | class DoubleDQN(DQN): 12 | def _do_training(self): 13 | batch = self.get_batch(training=True) 14 | rewards = batch['rewards'] 15 | terminals = batch['terminals'] 16 | obs = batch['observations'] 17 | actions = batch['actions'] 18 | next_obs = batch['next_observations'] 19 | 20 | """ 21 | Compute loss 22 | """ 23 | 24 | best_action_idxs = self.qf(next_obs).max( 25 | 1, keepdim=True 26 | )[1] 27 | target_q_values = self.target_qf(next_obs).gather( 28 | 1, best_action_idxs 29 | ).detach() 30 | y_target = rewards + (1. - terminals) * self.discount * target_q_values 31 | y_target = y_target.detach() 32 | # actions is a one-hot vector 33 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) 34 | qf_loss = self.qf_criterion(y_pred, y_target) 35 | 36 | """ 37 | Update networks 38 | """ 39 | self.qf_optimizer.zero_grad() 40 | qf_loss.backward() 41 | self.qf_optimizer.step() 42 | self._update_target_network() 43 | 44 | """ 45 | Save some statistics for eval using just one batch. 46 | """ 47 | if self.need_to_update_eval_statistics: 48 | self.need_to_update_eval_statistics = False 49 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 50 | self.eval_statistics.update(create_stats_ordered_dict( 51 | 'Y Predictions', 52 | ptu.get_numpy(y_pred), 53 | )) 54 | -------------------------------------------------------------------------------- /rlkit/torch/dqn/dqn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | 8 | import rlkit.torch.pytorch_util as ptu 9 | from rlkit.exploration_strategies.base import ( 10 | PolicyWrappedWithExplorationStrategy 11 | ) 12 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy 13 | from rlkit.core.eval_util import create_stats_ordered_dict 14 | from rlkit.policies.argmax import ArgmaxDiscretePolicy 15 | from rlkit.torch.torch_rl_algorithm import TorchRLAlgorithm 16 | 17 | 18 | class DQN(TorchRLAlgorithm): 19 | def __init__( 20 | self, 21 | env, 22 | qf, 23 | policy=None, 24 | learning_rate=1e-3, 25 | use_hard_updates=False, 26 | hard_update_period=1000, 27 | tau=0.001, 28 | epsilon=0.1, 29 | qf_criterion=None, 30 | **kwargs 31 | ): 32 | """ 33 | 34 | :param env: Env. 35 | :param qf: QFunction. 
Maps from state to action Q-values. 36 | :param learning_rate: Learning rate for qf. Adam is used. 37 | :param use_hard_updates: Use a hard rather than soft update. 38 | :param hard_update_period: How many gradient steps before copying the 39 | parameters over. Used if `use_hard_updates` is True. 40 | :param tau: Soft target tau to update target QF. Used if 41 | `use_hard_updates` is False. 42 | :param epsilon: Probability of taking a random action. 43 | :param kwargs: kwargs to pass onto TorchRLAlgorithm 44 | """ 45 | exploration_strategy = EpsilonGreedy( 46 | action_space=env.action_space, 47 | prob_random_action=epsilon, 48 | ) 49 | self.policy = policy or ArgmaxDiscretePolicy(qf) 50 | exploration_policy = PolicyWrappedWithExplorationStrategy( 51 | exploration_strategy=exploration_strategy, 52 | policy=self.policy, 53 | ) 54 | super().__init__( 55 | env, exploration_policy, eval_policy=self.policy, **kwargs 56 | ) 57 | self.qf = qf 58 | self.target_qf = self.qf.copy() 59 | self.learning_rate = learning_rate 60 | self.use_hard_updates = use_hard_updates 61 | self.hard_update_period = hard_update_period 62 | self.tau = tau 63 | self.qf_optimizer = optim.Adam( 64 | self.qf.parameters(), 65 | lr=self.learning_rate, 66 | ) 67 | self.qf_criterion = qf_criterion or nn.MSELoss() 68 | 69 | 70 | def _do_training(self): 71 | batch = self.get_batch() 72 | rewards = batch['rewards'] 73 | terminals = batch['terminals'] 74 | obs = batch['observations'] 75 | actions = batch['actions'] 76 | next_obs = batch['next_observations'] 77 | 78 | """ 79 | Compute loss 80 | """ 81 | 82 | target_q_values = self.target_qf(next_obs).detach().max( 83 | 1, keepdim=True 84 | )[0] 85 | y_target = rewards + (1. - terminals) * self.discount * target_q_values 86 | y_target = y_target.detach() 87 | # actions is a one-hot vector 88 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) 89 | qf_loss = self.qf_criterion(y_pred, y_target) 90 | 91 | """ 92 | Update networks 93 | """ 94 | self.qf_optimizer.zero_grad() 95 | qf_loss.backward() 96 | self.qf_optimizer.step() 97 | self._update_target_network() 98 | 99 | """ 100 | Save some statistics for eval using just one batch. 
101 | """ 102 | if self.need_to_update_eval_statistics: 103 | self.need_to_update_eval_statistics = False 104 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 105 | self.eval_statistics.update(create_stats_ordered_dict( 106 | 'Y Predictions', 107 | ptu.get_numpy(y_pred), 108 | )) 109 | 110 | def _update_target_network(self): 111 | if self.use_hard_updates: 112 | if self._n_train_steps_total % self.hard_update_period == 0: 113 | ptu.copy_model_params_from_to(self.qf, self.target_qf) 114 | else: 115 | ptu.soft_update_from_to(self.qf, self.target_qf, self.tau) 116 | 117 | def offline_evaluate(self, epoch): 118 | raise NotImplementedError() 119 | 120 | def get_epoch_snapshot(self, epoch): 121 | snapshot = super().get_epoch_snapshot(epoch) 122 | snapshot.update( 123 | exploration_policy=self.exploration_policy, 124 | policy=self.policy, 125 | ) 126 | return snapshot 127 | 128 | @property 129 | def networks(self): 130 | return [ 131 | self.qf, 132 | self.target_qf, 133 | ] 134 | -------------------------------------------------------------------------------- /rlkit/torch/her/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/her/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/her/her_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 4 | 5 | 6 | class RelabelingReplayBuffer(EnvReplayBuffer): 7 | """ 8 | Save goals from the same trajectory into the replay buffer. 9 | Only add_path is implemented. 10 | Implementation details: 11 | - Every sample from [0, self._size] will be valid. 12 | """ 13 | def __init__( 14 | self, 15 | max_size, 16 | env, 17 | fraction_goals_are_rollout_goals=1.0, # default, no HER 18 | fraction_resampled_goals_are_env_goals=0.0, # this many goals are just sampled from environment directly 19 | ): 20 | """ 21 | :param resampling_strategy: How to resample states from the rest of 22 | the trajectory? 
23 | - 'uniform': Sample them uniformly 24 | - 'truncated_geometric': Used a truncated geometric distribution 25 | """ 26 | super().__init__(max_size, env) 27 | self._goals = np.zeros((max_size, self.env.goal_dim)) 28 | self._num_steps_left = np.zeros((max_size, 1)) 29 | self.fraction_goals_are_rollout_goals = fraction_goals_are_rollout_goals 30 | self.fraction_resampled_goals_are_env_goals = fraction_resampled_goals_are_env_goals 31 | 32 | # Let j be any index in self._idx_to_future_obs_idx[i] 33 | # Then self._next_obs[j] is a valid next observation for observation i 34 | self._idx_to_future_obs_idx = [None] * max_size 35 | 36 | def add_sample(self, observation, action, reward, terminal, 37 | next_observation, **kwargs): 38 | raise NotImplementedError("Only use add_path") 39 | 40 | def add_path(self, path): 41 | obs = path["observations"] 42 | actions = path["actions"] 43 | rewards = path["rewards"] 44 | next_obs = path["next_observations"] 45 | terminals = path["terminals"] 46 | goals = path["goals"] 47 | num_steps_left = path["rewards"].copy() # path["num_steps_left"] # irrelevant for non-TDM 48 | path_len = len(rewards) 49 | 50 | actions = flatten_n(actions) 51 | obs = flatten_n(obs) 52 | next_obs = flatten_n(next_obs) 53 | 54 | if self._top + path_len >= self._max_replay_buffer_size: 55 | num_pre_wrap_steps = self._max_replay_buffer_size - self._top 56 | # numpy slice 57 | pre_wrap_buffer_slice = np.s_[ 58 | self._top:self._top + num_pre_wrap_steps, : 59 | ] 60 | pre_wrap_path_slice = np.s_[0:num_pre_wrap_steps, :] 61 | 62 | num_post_wrap_steps = path_len - num_pre_wrap_steps 63 | post_wrap_buffer_slice = slice(0, num_post_wrap_steps) 64 | post_wrap_path_slice = slice(num_pre_wrap_steps, path_len) 65 | for buffer_slice, path_slice in [ 66 | (pre_wrap_buffer_slice, pre_wrap_path_slice), 67 | (post_wrap_buffer_slice, post_wrap_path_slice), 68 | ]: 69 | self._observations[buffer_slice] = obs[path_slice] 70 | self._actions[buffer_slice] = actions[path_slice] 71 | self._rewards[buffer_slice] = rewards[path_slice] 72 | self._next_obs[buffer_slice] = next_obs[path_slice] 73 | self._terminals[buffer_slice] = terminals[path_slice] 74 | self._goals[buffer_slice] = goals[path_slice] 75 | self._num_steps_left[buffer_slice] = num_steps_left[path_slice] 76 | # Pointers from before the wrap 77 | for i in range(self._top, self._max_replay_buffer_size): 78 | self._idx_to_future_obs_idx[i] = np.hstack(( 79 | # Pre-wrap indices 80 | np.arange(i, self._max_replay_buffer_size), 81 | # Post-wrap indices 82 | np.arange(0, num_post_wrap_steps) 83 | )) 84 | # Pointers after the wrap 85 | for i in range(0, num_post_wrap_steps): 86 | self._idx_to_future_obs_idx[i] = np.arange( 87 | i, 88 | num_post_wrap_steps, 89 | ) 90 | else: 91 | slc = np.s_[self._top:self._top + path_len, :] 92 | self._observations[slc] = obs 93 | self._actions[slc] = actions 94 | self._rewards[slc] = rewards 95 | self._next_obs[slc] = next_obs 96 | self._terminals[slc] = terminals 97 | self._goals[slc] = goals 98 | self._num_steps_left[slc] = num_steps_left 99 | for i in range(self._top, self._top + path_len): 100 | self._idx_to_future_obs_idx[i] = np.arange( 101 | i, self._top + path_len 102 | ) 103 | self._top = (self._top + path_len) % self._max_replay_buffer_size 104 | self._size = min(self._size + path_len, self._max_replay_buffer_size) 105 | 106 | def _sample_indices(self, batch_size): 107 | return np.random.randint(0, self._size, batch_size) 108 | 109 | def random_batch(self, batch_size): 110 | indices = 
self._sample_indices(batch_size) 111 | next_obs_idxs = [] 112 | for i in indices: 113 | possible_next_obs_idxs = self._idx_to_future_obs_idx[i] 114 | # This is generally faster than random.choice. Makes you wonder what 115 | # random.choice is doing 116 | num_options = len(possible_next_obs_idxs) 117 | if num_options == 1: 118 | next_obs_i = 0 119 | else: 120 | next_obs_i = int(np.random.randint(0, num_options)) 121 | next_obs_idxs.append(possible_next_obs_idxs[next_obs_i]) 122 | next_obs_idxs = np.array(next_obs_idxs) 123 | resampled_goals = self.env.convert_obs_to_goals( 124 | self._next_obs[next_obs_idxs] 125 | ) 126 | num_goals_are_from_rollout = int( 127 | batch_size * self.fraction_goals_are_rollout_goals 128 | ) 129 | if num_goals_are_from_rollout > 0: 130 | resampled_goals[:num_goals_are_from_rollout] = self._goals[ 131 | indices[:num_goals_are_from_rollout] 132 | ] 133 | # recompute rewards 134 | new_obs = self._observations[indices] 135 | new_next_obs = self._next_obs[indices] 136 | new_actions = self._actions[indices] 137 | new_rewards = self._rewards[indices].copy() # needs to be recomputed 138 | random_numbers = np.random.rand(batch_size) 139 | for i in range(batch_size): 140 | if random_numbers[i] < self.fraction_resampled_goals_are_env_goals: 141 | resampled_goals[i, :] = self.env.sample_goal_for_rollout() 142 | 143 | new_reward = self.env.compute_her_reward_np( 144 | new_obs[i, :], 145 | new_actions[i, :], 146 | new_next_obs[i, :], 147 | resampled_goals[i, :], 148 | ) 149 | new_rewards[i] = new_reward 150 | 151 | batch = dict( 152 | observations=new_obs, 153 | actions=new_actions, 154 | rewards=new_rewards, 155 | terminals=self._terminals[indices], 156 | next_observations=new_next_obs, 157 | goals_used_for_rollout=self._goals[indices], 158 | resampled_goals=resampled_goals, 159 | num_steps_left=self._num_steps_left[indices], 160 | indices=np.array(indices).reshape(-1, 1), 161 | goals=resampled_goals, 162 | ) 163 | return batch 164 | 165 | def flatten_n(xs): 166 | xs = np.asarray(xs) 167 | return xs.reshape((xs.shape[0], -1)) 168 | 169 | 170 | def flatten_env_info(env_infos, env_info_keys): 171 | # Turns list of env_info dicts into env_info dict of 2D np arrays 172 | return { 173 | key: flatten_n( 174 | [env_infos[i][key] for i in range(len(env_infos))] 175 | ) 176 | for key in env_info_keys 177 | } 178 | -------------------------------------------------------------------------------- /rlkit/torch/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class HuberLoss(nn.Module): 9 | def __init__(self, delta=1): 10 | super().__init__() 11 | self.huber_loss_delta1 = nn.SmoothL1Loss() 12 | self.delta = delta 13 | 14 | def forward(self, x, x_hat): 15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta) 16 | return loss * self.delta * self.delta 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | """ 21 | Simple 1D LayerNorm. 
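Normalizes over the last dimension; `center` adds a learned shift and `scale` (off by default) a learned gain.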
22 | """ 23 | 24 | def __init__(self, features, center=True, scale=False, eps=1e-6): 25 | super().__init__() 26 | self.center = center 27 | self.scale = scale 28 | self.eps = eps 29 | if self.scale: 30 | self.scale_param = nn.Parameter(torch.ones(features)) 31 | else: 32 | self.scale_param = None 33 | if self.center: 34 | self.center_param = nn.Parameter(torch.zeros(features)) 35 | else: 36 | self.center_param = None 37 | 38 | def forward(self, x): 39 | mean = x.mean(-1, keepdim=True) 40 | std = x.std(-1, keepdim=True) 41 | output = (x - mean) / (std + self.eps) 42 | if self.scale: 43 | output = output * self.scale_param 44 | if self.center: 45 | output = output + self.center_param 46 | return output 47 | -------------------------------------------------------------------------------- /rlkit/torch/optim/mpi_adam.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/openai/baselines/blob/master/baselines/common/mpi_adam.py 2 | 3 | import rlkit.torch.optim.util as U 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | import math 7 | import numpy as np 8 | import rlkit.torch.pytorch_util as ptu 9 | from rlkit.core.serializable import Serializable 10 | try: 11 | from mpi4py import MPI 12 | except ImportError: 13 | MPI = None 14 | 15 | 16 | class MpiAdam(Optimizer): 17 | def __init__(self, 18 | params, 19 | lr=1e-3, 20 | beta1=0.9, 21 | beta2=0.999, 22 | epsilon=1e-08, 23 | scale_grad_by_procs=True, 24 | comm=None, 25 | gpu_id=0): 26 | # Serializable.quick_init(self, 27 | # locals()) 28 | super().__init__(params, dict()) 29 | self.lr = lr 30 | self.beta1 = beta1 31 | self.beta2 = beta2 32 | self.epsilon = epsilon 33 | self.scale_grad_by_procs = scale_grad_by_procs 34 | total_params = sum([U.num_elements(param) for param in U.get_flat_params(self.param_groups)]) 35 | if ptu.get_mode() == "gpu_opt": 36 | assert gpu_id is not None 37 | self.m = torch.zeros(total_params, dtype=torch.float32).to(device=F"cuda:{gpu_id}") 38 | self.v = torch.zeros(total_params, dtype=torch.float32).to(device=F"cuda:{gpu_id}") 39 | elif not ptu.get_mode(): #CPU is false 40 | self.m = torch.zeros(total_params, dtype=torch.float32) 41 | self.v = torch.zeros(total_params, dtype=torch.float32) 42 | else: 43 | print(ptu.get_mode()) 44 | raise NotImplementedError 45 | self.t = 0 46 | self.set_params_from_flat = U.SetFromFlat(self.param_groups) 47 | self.get_params_as_flat = U.GetParamsFlat(self.param_groups) 48 | self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm 49 | 50 | def __getstate__(self): 51 | # d = Serializable.__getstate__(self) 52 | # d = dict() 53 | d = super().__getstate__() 54 | d['lr'] = self.lr 55 | d['beta1'] = self.beta1 56 | d['beta2'] = self.beta2 57 | d['epsilon'] = self.epsilon 58 | d['scale_grad_by_procs'] = self.scale_grad_by_procs 59 | d["m"] = self.m.clone() 60 | d["v"] = self.v.clone() 61 | d["t"] = self.t 62 | return d 63 | 64 | def __setstate__(self, d): 65 | # Serializable.__setstate__(self, d) 66 | super().__setstate__(d) 67 | if "lr" in d.keys(): 68 | self.lr = d['lr'] 69 | else: 70 | self.lr = 3E-4 71 | self.beta1 = d['beta1'] 72 | self.beta2 = d['beta2'] 73 | self.epsilon = d['epsilon'] 74 | self.scale_grad_by_procs = d['scale_grad_by_procs'] 75 | self.m = d["m"] 76 | self.v = d["v"] 77 | self.t = d["t"] 78 | 79 | def reset_state(self, gpu_id=0): 80 | self.m = torch.zeros_like(self.m, dtype=torch.float32).to(device=F"cuda:{gpu_id}") 81 | self.v = torch.zeros_like(self.v, 
dtype=torch.float32).to(device=F"cuda:{gpu_id}") 82 | 83 | def reconnect_params(self, params): 84 | super().__init__(params, dict()) # This does not alter the optimizer state m or v 85 | self.reinit_flat_operators() 86 | 87 | def reinit_flat_operators(self): 88 | self.set_params_from_flat = U.SetFromFlat(self.param_groups) 89 | self.get_params_as_flat = U.GetParamsFlat(self.param_groups) 90 | 91 | def step(self, closure=None): 92 | """ 93 | Aggregate and reduce gradients across all threads 94 | :param closure: 95 | :return: 96 | """ 97 | # self.param_groups updated on the GPU, stepped, then moved back to its own thread 98 | localg = U.get_flattened_grads(self.param_groups) 99 | if self.t % 100 == 0: 100 | self.check_synced() 101 | if localg.device.type == "cpu": 102 | localg = localg.detach().numpy() 103 | else: 104 | localg = localg.cpu().detach().numpy() 105 | if self.comm is not None: 106 | globalg = np.zeros_like(localg) 107 | self.comm.Allreduce(localg, globalg, op=MPI.SUM) 108 | if self.scale_grad_by_procs: 109 | globalg /= self.comm.Get_size() 110 | if localg.shape[0] > 1 and self.comm.Get_size() > 1: 111 | assert not (localg == globalg).all() 112 | globalg = ptu.from_numpy(globalg, device=torch.device(ptu.get_device())) 113 | else: 114 | globalg = ptu.from_numpy(localg, device=torch.device(ptu.get_device())) 115 | 116 | self.t += 1 117 | a = self.lr * math.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t) 118 | self.m = self.beta1 * self.m + (1 - self.beta1) * globalg 119 | self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) 120 | step_update = (- a) * self.m / (torch.sqrt(self.v) + self.epsilon) 121 | # print("before: ") 122 | # print(self.get_params_as_flat()) 123 | self.set_params_from_flat((self.get_flat_params() + step_update).to(device=torch.device("cpu"))) 124 | # print("after, in mpi adam: ") 125 | # print(self.get_params_as_flat()) 126 | 127 | def sync(self): 128 | if self.comm is None: 129 | return 130 | theta = ptu.get_numpy(self.get_params_as_flat()) 131 | self.comm.Bcast(theta, root=0) 132 | self.set_params_from_flat(ptu.from_numpy(theta)) 133 | 134 | def check_synced(self): 135 | # If this fails on iteration 0, remember to call SYNC for each optimizer!!! 136 | if self.comm is None: 137 | return 138 | if self.comm.Get_rank() == 0: # this is root 139 | theta = ptu.get_numpy(self.get_params_as_flat()) 140 | self.comm.Bcast(theta, root=0) 141 | else: 142 | thetalocal = ptu.get_numpy(self.get_params_as_flat()) 143 | thetaroot = np.empty_like(thetalocal) 144 | self.comm.Bcast(thetaroot, root=0) 145 | assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal) 146 | 147 | def to(self, device=None): 148 | if device is None: 149 | device = ptu.device 150 | self.m = self.m.to(device=device) 151 | self.v = self.v.to(device=device) 152 | 153 | def get_flat_params(self): 154 | """ 155 | Get params from a CPU thread 156 | :return: 157 | """ 158 | return torch.cat([param.view([U.num_elements(param)]) for param in U.get_flat_params(self.param_groups)], dim=0) -------------------------------------------------------------------------------- /rlkit/torch/optim/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import rlkit.torch.pytorch_util as ptu 4 | 5 | 6 | def get_flattened_grads(parameter_groups): 7 | """Flattens a variables and their gradients. 
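Parameters whose gradient is still None contribute a zero vector of the matching size.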
8 | """ 9 | parameter_list = get_flat_params(parameter_groups) 10 | grads = [param.grad for param in parameter_list] 11 | # return torch.cat([grad.view([num_elements(grad)]) if grad is not None else 0 for grad in grads], dim=0) 12 | return torch.cat([grads[i].view([num_elements(grads[i])]) if grads[i] is not None else ptu.zeros([num_elements(parameter_list[i])]) for i in range(len(grads))], dim=0) 13 | 14 | 15 | def var_shape(x): 16 | return x.shape 17 | 18 | 19 | def num_elements(x): 20 | return intprod(var_shape(x)) 21 | 22 | 23 | def intprod(x): 24 | return int(np.prod(x)) 25 | 26 | 27 | def get_flat_params(parameter_groups): 28 | """ 29 | 30 | :param parameter_groups: 31 | :return: List of parameters pulled out of parameter groups 32 | """ 33 | parameter_list = [] 34 | for parameter_group in parameter_groups: 35 | parameter_list += parameter_group['params'] 36 | 37 | return parameter_list 38 | 39 | 40 | class SetFromFlat(object): 41 | def __init__(self, 42 | parameter_groups): 43 | self.parameter_list = get_flat_params(parameter_groups) 44 | self.shapes = list(map(var_shape, self.parameter_list)) 45 | self.total_size = np.sum([intprod(shape) for shape in self.shapes]) 46 | 47 | def __call__(self, flattened_parameters): 48 | """ 49 | 50 | :param flattened_parameters: type -> Torch Tensor 51 | :return: 52 | """ 53 | # before = flattened_parameters.detach().clone() 54 | # Update worker parameters with flattened_parameters weights broadcasted from root process 55 | start = 0 56 | for (shape, param) in zip(self.shapes, self.parameter_list): 57 | size = intprod(shape) 58 | param.data.copy_(flattened_parameters[start:start+size].view(shape)) 59 | start += size 60 | 61 | # assert not (before == self.parameter_list).all() 62 | assert start == self.total_size, (start, self.total_size) 63 | 64 | 65 | class GetParamsFlat(object): 66 | def __init__(self, 67 | parameter_groups): 68 | self.parameter_list = get_flat_params(parameter_groups) 69 | 70 | def __call__(self): 71 | return torch.cat([param.view([num_elements(param)]) for param in self.parameter_list], dim=0) -------------------------------------------------------------------------------- /rlkit/torch/relational/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from rlkit.policies.base import ExplorationPolicy 4 | import torch 5 | from rlkit.torch.networks import Mlp 6 | from rlkit.torch.core import PyTorchModule 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.torch.sac.policies import FlattenTanhGaussianPolicy, CompositeNormalizedTanhGaussianPolicy 9 | from rlkit.torch.relational.relational_util import fetch_preprocessing 10 | import numpy as np 11 | import numpy 12 | from torch.nn import Parameter 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | from rlkit.torch.pytorch_util import shuffle_and_mask 15 | from rlkit.torch.networks import Mlp 16 | from rlkit.torch.relational.modules import * 17 | 18 | 19 | class GraphPropagation(PyTorchModule): 20 | """ 21 | Input: state 22 | Output: context vector 23 | """ 24 | 25 | def __init__(self, 26 | num_relational_blocks=1, 27 | num_query_heads=1, 28 | graph_module_kwargs=None, 29 | layer_norm=False, 30 | activation_fnx=F.leaky_relu, 31 | graph_module=AttentiveGraphToGraph, 32 | post_residual_activation=True, 33 | recurrent_graph=False, 34 | **kwargs 35 | ): 36 | """ 37 | 38 | :param embedding_dim: 39 | :param lstm_cell_class: 40 | :param 
lstm_num_layers: 41 | :param graph_module_kwargs: 42 | :param style: OSIL or relational inductive bias. 43 | """ 44 | self.save_init_params(locals()) 45 | super().__init__() 46 | 47 | # Instance settings 48 | 49 | self.num_query_heads = num_query_heads 50 | self.num_relational_blocks = num_relational_blocks 51 | assert graph_module_kwargs, graph_module_kwargs 52 | self.embedding_dim = graph_module_kwargs['embedding_dim'] 53 | 54 | if recurrent_graph: 55 | rg = graph_module(**graph_module_kwargs) 56 | self.graph_module_list = nn.ModuleList( 57 | [rg for i in range(num_relational_blocks)]) 58 | else: 59 | self.graph_module_list = nn.ModuleList( 60 | [graph_module(**graph_module_kwargs) for i in range(num_relational_blocks)]) 61 | 62 | # Layer norm takes in N x nB x nE and normalizes 63 | if layer_norm: 64 | self.layer_norms = nn.ModuleList([nn.LayerNorm(self.embedding_dim) for i in range(num_relational_blocks)]) 65 | 66 | # What's key here is we never use the num_objects in the init, 67 | # which means we can change it as we like for later. 68 | 69 | """ 70 | ReNN Arguments 71 | """ 72 | self.layer_norm = layer_norm 73 | self.activation_fnx = activation_fnx 74 | 75 | def forward(self, vertices, mask=None, *kwargs): 76 | """ 77 | 78 | :param shared_state: state that should be broadcasted along nB dimension. N * (nR + nB * nF) 79 | :param object_and_goal_state: individual objects 80 | :return: 81 | """ 82 | output = vertices 83 | 84 | for i in range(self.num_relational_blocks): 85 | new_output = self.graph_module_list[i](output, mask) 86 | new_output = output + new_output 87 | 88 | output = self.activation_fnx(new_output) # Diff from 7/22 89 | # Apply layer normalization 90 | if self.layer_norm: 91 | output = self.layer_norms[i](output) 92 | return output 93 | 94 | 95 | class ValueReNN(PyTorchModule): 96 | def __init__(self, 97 | graph_propagation, 98 | readout, 99 | input_module=FetchInputPreprocessing, 100 | input_module_kwargs=None, 101 | state_preprocessing_fnx=fetch_preprocessing, 102 | *args, 103 | value_mlp_kwargs=None, 104 | composite_normalizer=None, 105 | **kwargs): 106 | self.save_init_params(locals()) 107 | super().__init__() 108 | self.input_module = input_module(**input_module_kwargs) 109 | self.graph_propagation = graph_propagation 110 | self.readout = readout 111 | self.composite_normalizer = composite_normalizer 112 | 113 | def forward(self, 114 | obs, 115 | mask=None, 116 | return_stacked_softmax=False): 117 | vertices = self.input_module(obs, mask=mask) 118 | new_vertices = self.graph_propagation.forward(vertices, mask=mask) 119 | pooled_output = self.readout(new_vertices, mask=mask) 120 | return pooled_output 121 | 122 | 123 | class QValueReNN(PyTorchModule): 124 | """ 125 | Used for q-value network 126 | """ 127 | 128 | def __init__(self, 129 | graph_propagation, 130 | readout, 131 | input_module=FetchInputPreprocessing, 132 | input_module_kwargs=None, 133 | state_preprocessing_fnx=fetch_preprocessing, 134 | *args, 135 | composite_normalizer=None, 136 | **kwargs): 137 | self.save_init_params(locals()) 138 | super().__init__() 139 | self.graph_propagation = graph_propagation 140 | self.state_preprocessing_fnx = state_preprocessing_fnx 141 | self.readout = readout 142 | self.composite_normalizer = composite_normalizer 143 | self.input_module = input_module(**input_module_kwargs) 144 | 145 | def forward(self, obs, actions, mask=None, return_stacked_softmax=False): 146 | assert mask is not None 147 | vertices = self.input_module(obs, actions=actions, mask=mask) 148 | 
relational_block_embeddings = self.graph_propagation.forward(vertices, mask=mask) 149 | pooled_output = self.readout(relational_block_embeddings, mask=mask) 150 | assert pooled_output.size(-1) == 1 151 | return pooled_output 152 | 153 | 154 | class PolicyReNN(PyTorchModule, ExplorationPolicy): 155 | """ 156 | Used for policy network 157 | """ 158 | 159 | def __init__(self, 160 | graph_propagation, 161 | readout, 162 | *args, 163 | input_module=FetchInputPreprocessing, 164 | input_module_kwargs=None, 165 | mlp_class=FlattenTanhGaussianPolicy, 166 | composite_normalizer=None, 167 | batch_size=None, 168 | **kwargs): 169 | self.save_init_params(locals()) 170 | super().__init__() 171 | self.composite_normalizer = composite_normalizer 172 | 173 | # Internal modules 174 | self.graph_propagation = graph_propagation 175 | self.selection_attention = readout 176 | 177 | self.mlp = mlp_class(**kwargs['mlp_kwargs']) 178 | self.input_module = input_module(**input_module_kwargs) 179 | 180 | def forward(self, 181 | obs, 182 | mask=None, 183 | demo_normalizer=False, 184 | **mlp_kwargs): 185 | assert mask is not None 186 | vertices = self.input_module(obs, mask=mask) 187 | response_embeddings = self.graph_propagation.forward(vertices, mask=mask) 188 | 189 | selected_objects = self.selection_attention( 190 | vertices=response_embeddings, 191 | mask=mask 192 | ) 193 | selected_objects = selected_objects.squeeze(1) 194 | return self.mlp(selected_objects, **mlp_kwargs) 195 | 196 | def get_action(self, 197 | obs_np, 198 | **kwargs): 199 | assert len(obs_np.shape) == 1 200 | actions, agent_info = self.get_actions(obs_np[None], **kwargs) 201 | assert isinstance(actions, np.ndarray) 202 | return actions[0, :], agent_info 203 | 204 | def get_actions(self, 205 | obs_np, 206 | **kwargs): 207 | mlp_outputs = self.eval_np(obs_np, **kwargs) 208 | assert len(mlp_outputs) == 8 209 | actions = mlp_outputs[0] 210 | 211 | agent_info = dict() 212 | return actions, agent_info -------------------------------------------------------------------------------- /rlkit/torch/sac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/sac/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/sac/sac.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch.optim as optim 5 | from torch import nn as nn 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.core.eval_util import create_stats_ordered_dict 9 | from rlkit.torch.torch_rl_algorithm import TorchRLAlgorithm 10 | from rlkit.torch.sac.policies import MakeDeterministic 11 | 12 | 13 | class SoftActorCritic(TorchRLAlgorithm): 14 | def __init__( 15 | self, 16 | env, 17 | policy, 18 | qf, 19 | vf, 20 | 21 | policy_lr=1e-3, 22 | qf_lr=1e-3, 23 | vf_lr=1e-3, 24 | policy_mean_reg_weight=1e-3, 25 | policy_std_reg_weight=1e-3, 26 | policy_pre_activation_weight=0., 27 | optimizer_class=optim.Adam, 28 | 29 | train_policy_with_reparameterization=True, 30 | soft_target_tau=1e-2, 31 | plotter=None, 32 | render_eval_paths=False, 33 | eval_deterministic=True, 34 | 35 | use_automatic_entropy_tuning=True, 36 | target_entropy=None, 37 | **kwargs 38 | ): 39 | if eval_deterministic: 40 | eval_policy = MakeDeterministic(policy) 41 | else: 42 | eval_policy = policy 43 | super().__init__( 
44 | env=env, 45 | exploration_policy=policy, 46 | eval_policy=eval_policy, 47 | **kwargs 48 | ) 49 | self.policy = policy 50 | self.qf = qf 51 | self.vf = vf 52 | self.train_policy_with_reparameterization = ( 53 | train_policy_with_reparameterization 54 | ) 55 | self.soft_target_tau = soft_target_tau 56 | self.policy_mean_reg_weight = policy_mean_reg_weight 57 | self.policy_std_reg_weight = policy_std_reg_weight 58 | self.policy_pre_activation_weight = policy_pre_activation_weight 59 | self.plotter = plotter 60 | self.render_eval_paths = render_eval_paths 61 | self.use_automatic_entropy_tuning = use_automatic_entropy_tuning 62 | if self.use_automatic_entropy_tuning: 63 | if target_entropy: 64 | self.target_entropy = target_entropy 65 | else: 66 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # heuristic value from Tuomas 67 | self.log_alpha = ptu.zeros(1, requires_grad=True) 68 | self.alpha_optimizer = optimizer_class( 69 | [self.log_alpha], 70 | lr=policy_lr, 71 | ) 72 | 73 | self.target_vf = vf.copy() 74 | self.qf_criterion = nn.MSELoss() 75 | self.vf_criterion = nn.MSELoss() 76 | 77 | self.policy_optimizer = optimizer_class( 78 | self.policy.parameters(), 79 | lr=policy_lr, 80 | ) 81 | self.qf_optimizer = optimizer_class( 82 | self.qf.parameters(), 83 | lr=qf_lr, 84 | ) 85 | self.vf_optimizer = optimizer_class( 86 | self.vf.parameters(), 87 | lr=vf_lr, 88 | ) 89 | 90 | def _do_training(self): 91 | batch = self.get_batch() 92 | rewards = batch['rewards'] 93 | terminals = batch['terminals'] 94 | obs = batch['observations'] 95 | actions = batch['actions'] 96 | next_obs = batch['next_observations'] 97 | 98 | q_pred = self.qf(obs, actions) 99 | v_pred = self.vf(obs) 100 | # Make sure policy accounts for squashing functions like tanh correctly! 101 | policy_outputs = self.policy( 102 | obs, 103 | reparameterize=self.train_policy_with_reparameterization, 104 | return_log_prob=True, 105 | ) 106 | new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4] 107 | if self.use_automatic_entropy_tuning: 108 | """ 109 | Alpha Loss 110 | """ 111 | alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() 112 | self.alpha_optimizer.zero_grad() 113 | alpha_loss.backward() 114 | self.alpha_optimizer.step() 115 | alpha = self.log_alpha.exp() 116 | else: 117 | alpha = 1 118 | alpha_loss = 0 119 | 120 | """ 121 | QF Loss 122 | """ 123 | target_v_values = self.target_vf(next_obs) 124 | q_target = rewards + (1. 
- terminals) * self.discount * target_v_values 125 | qf_loss = self.qf_criterion(q_pred, q_target.detach()) 126 | 127 | """ 128 | VF Loss 129 | """ 130 | q_new_actions = self.qf(obs, new_actions) 131 | v_target = q_new_actions - alpha*log_pi 132 | vf_loss = self.vf_criterion(v_pred, v_target.detach()) 133 | 134 | """ 135 | Policy Loss 136 | """ 137 | if self.train_policy_with_reparameterization: 138 | policy_loss = (alpha*log_pi - q_new_actions).mean() 139 | else: 140 | log_policy_target = q_new_actions - v_pred 141 | policy_loss = ( 142 | log_pi * (alpha*log_pi - log_policy_target).detach() 143 | ).mean() 144 | mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean() 145 | std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean() 146 | pre_tanh_value = policy_outputs[-1] 147 | pre_activation_reg_loss = self.policy_pre_activation_weight * ( 148 | (pre_tanh_value**2).sum(dim=1).mean() 149 | ) 150 | policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss 151 | policy_loss = policy_loss + policy_reg_loss 152 | 153 | """ 154 | Update networks 155 | """ 156 | self.qf_optimizer.zero_grad() 157 | qf_loss.backward() 158 | self.qf_optimizer.step() 159 | 160 | self.vf_optimizer.zero_grad() 161 | vf_loss.backward() 162 | self.vf_optimizer.step() 163 | 164 | self.policy_optimizer.zero_grad() 165 | policy_loss.backward() 166 | self.policy_optimizer.step() 167 | 168 | self._update_target_network() 169 | 170 | """ 171 | Save some statistics for eval using just one batch. 172 | """ 173 | if self.need_to_update_eval_statistics: 174 | self.need_to_update_eval_statistics = False 175 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 176 | self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss)) 177 | self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( 178 | policy_loss 179 | )) 180 | self.eval_statistics.update(create_stats_ordered_dict( 181 | 'Q Predictions', 182 | ptu.get_numpy(q_pred), 183 | )) 184 | self.eval_statistics.update(create_stats_ordered_dict( 185 | 'V Predictions', 186 | ptu.get_numpy(v_pred), 187 | )) 188 | self.eval_statistics.update(create_stats_ordered_dict( 189 | 'Log Pis', 190 | ptu.get_numpy(log_pi), 191 | )) 192 | self.eval_statistics.update(create_stats_ordered_dict( 193 | 'Policy mu', 194 | ptu.get_numpy(policy_mean), 195 | )) 196 | self.eval_statistics.update(create_stats_ordered_dict( 197 | 'Policy log std', 198 | ptu.get_numpy(policy_log_std), 199 | )) 200 | if self.use_automatic_entropy_tuning: 201 | self.eval_statistics['Alpha'] = alpha.item() 202 | self.eval_statistics['Alpha Loss'] = alpha_loss.item() 203 | 204 | @property 205 | def networks(self): 206 | return [ 207 | self.policy, 208 | self.qf, 209 | self.vf, 210 | self.target_vf, 211 | ] 212 | 213 | def _update_target_network(self): 214 | ptu.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau) 215 | 216 | def get_epoch_snapshot(self, epoch): 217 | snapshot = super().get_epoch_snapshot(epoch) 218 | snapshot.update( 219 | qf=self.qf, 220 | policy=self.policy, 221 | vf=self.vf, 222 | target_vf=self.target_vf, 223 | ) 224 | return snapshot 225 | -------------------------------------------------------------------------------- /rlkit/torch/td3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/richardrl/rlkit-relational/e01973d0a7e393cb31fbd48e8180ab0d3d8d2a2e/rlkit/torch/td3/__init__.py 
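A minimal, self-contained sketch of the clipped double-Q target that the TD3 implementation below computes (target policy smoothing plus the minimum over two target critics). The function and argument names here are illustrative only and are not part of this repo's API:

```
import torch

def td3_q_target(rewards, terminals, next_obs, target_policy, target_qf1, target_qf2,
                 discount=0.99, noise_std=0.2, noise_clip=0.5):
    # Target policy smoothing: perturb the target action with clipped Gaussian noise.
    next_actions = target_policy(next_obs)
    noise = torch.clamp(torch.randn_like(next_actions) * noise_std, -noise_clip, noise_clip)
    noisy_next_actions = next_actions + noise
    # Clipped double-Q: take the element-wise minimum of the two target critics.
    target_q = torch.min(target_qf1(next_obs, noisy_next_actions),
                         target_qf2(next_obs, noisy_next_actions))
    # Standard Bellman backup, masked by terminals and detached from the graph.
    return (rewards + (1. - terminals) * discount * target_q).detach()
```

This mirrors the target computation inside `_do_training` in the `TD3` class that follows; the class additionally delays actor and target-network updates by `policy_and_target_update_period` gradient steps.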
-------------------------------------------------------------------------------- /rlkit/torch/td3/td3.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | 8 | import rlkit.torch.pytorch_util as ptu 9 | from rlkit.core.eval_util import create_stats_ordered_dict 10 | from rlkit.torch.torch_rl_algorithm import TorchRLAlgorithm 11 | 12 | 13 | class TD3(TorchRLAlgorithm): 14 | """ 15 | Twin Delayed Deep Deterministic policy gradients 16 | 17 | https://arxiv.org/abs/1802.09477 18 | """ 19 | 20 | def __init__( 21 | self, 22 | env, 23 | qf1, 24 | qf2, 25 | policy, 26 | exploration_policy, 27 | 28 | target_policy_noise=0.2, 29 | target_policy_noise_clip=0.5, 30 | 31 | policy_learning_rate=1e-3, 32 | qf_learning_rate=1e-3, 33 | policy_and_target_update_period=2, 34 | tau=0.005, 35 | qf_criterion=None, 36 | optimizer_class=optim.Adam, 37 | 38 | **kwargs 39 | ): 40 | super().__init__( 41 | env, 42 | exploration_policy, 43 | eval_policy=policy, 44 | **kwargs 45 | ) 46 | if qf_criterion is None: 47 | qf_criterion = nn.MSELoss() 48 | self.qf1 = qf1 49 | self.qf2 = qf2 50 | self.policy = policy 51 | 52 | self.target_policy_noise = target_policy_noise 53 | self.target_policy_noise_clip = target_policy_noise_clip 54 | 55 | self.policy_and_target_update_period = policy_and_target_update_period 56 | self.tau = tau 57 | self.qf_criterion = qf_criterion 58 | 59 | self.target_policy = policy.copy() 60 | self.target_qf1 = self.qf1.copy() 61 | self.target_qf2 = self.qf2.copy() 62 | self.qf1_optimizer = optimizer_class( 63 | self.qf1.parameters(), 64 | lr=qf_learning_rate, 65 | ) 66 | self.qf2_optimizer = optimizer_class( 67 | self.qf2.parameters(), 68 | lr=qf_learning_rate, 69 | ) 70 | self.policy_optimizer = optimizer_class( 71 | self.policy.parameters(), 72 | lr=policy_learning_rate, 73 | ) 74 | 75 | def _do_training(self): 76 | batch = self.get_batch() 77 | rewards = batch['rewards'] 78 | terminals = batch['terminals'] 79 | obs = batch['observations'] 80 | actions = batch['actions'] 81 | next_obs = batch['next_observations'] 82 | 83 | """ 84 | Critic operations. 85 | """ 86 | 87 | next_actions = self.target_policy(next_obs) 88 | noise = torch.normal( 89 | torch.zeros_like(next_actions), 90 | self.target_policy_noise, 91 | ) 92 | noise = torch.clamp( 93 | noise, 94 | -self.target_policy_noise_clip, 95 | self.target_policy_noise_clip 96 | ) 97 | noisy_next_actions = next_actions + noise 98 | 99 | target_q1_values = self.target_qf1(next_obs, noisy_next_actions) 100 | target_q2_values = self.target_qf2(next_obs, noisy_next_actions) 101 | target_q_values = torch.min(target_q1_values, target_q2_values) 102 | q_target = rewards + (1. 
- terminals) * self.discount * target_q_values 103 | q_target = q_target.detach() 104 | 105 | q1_pred = self.qf1(obs, actions) 106 | bellman_errors_1 = (q1_pred - q_target) ** 2 107 | qf1_loss = bellman_errors_1.mean() 108 | 109 | q2_pred = self.qf2(obs, actions) 110 | bellman_errors_2 = (q2_pred - q_target) ** 2 111 | qf2_loss = bellman_errors_2.mean() 112 | 113 | """ 114 | Update Networks 115 | """ 116 | self.qf1_optimizer.zero_grad() 117 | qf1_loss.backward() 118 | self.qf1_optimizer.step() 119 | 120 | self.qf2_optimizer.zero_grad() 121 | qf2_loss.backward() 122 | self.qf2_optimizer.step() 123 | 124 | policy_actions = policy_loss = None 125 | if self._n_train_steps_total % self.policy_and_target_update_period == 0: 126 | policy_actions = self.policy(obs) 127 | q_output = self.qf1(obs, policy_actions) 128 | policy_loss = - q_output.mean() 129 | 130 | self.policy_optimizer.zero_grad() 131 | policy_loss.backward() 132 | self.policy_optimizer.step() 133 | 134 | ptu.soft_update_from_to(self.policy, self.target_policy, self.tau) 135 | ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau) 136 | ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau) 137 | 138 | """ 139 | Save some statistics for eval using just one batch. 140 | """ 141 | if self.need_to_update_eval_statistics: 142 | self.need_to_update_eval_statistics = False 143 | if policy_loss is None: 144 | policy_actions = self.policy(obs) 145 | q_output = self.qf1(obs, policy_actions) 146 | policy_loss = - q_output.mean() 147 | 148 | self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) 149 | self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) 150 | self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( 151 | policy_loss 152 | )) 153 | self.eval_statistics.update(create_stats_ordered_dict( 154 | 'Q1 Predictions', 155 | ptu.get_numpy(q1_pred), 156 | )) 157 | self.eval_statistics.update(create_stats_ordered_dict( 158 | 'Q2 Predictions', 159 | ptu.get_numpy(q2_pred), 160 | )) 161 | self.eval_statistics.update(create_stats_ordered_dict( 162 | 'Q Targets', 163 | ptu.get_numpy(q_target), 164 | )) 165 | self.eval_statistics.update(create_stats_ordered_dict( 166 | 'Bellman Errors 1', 167 | ptu.get_numpy(bellman_errors_1), 168 | )) 169 | self.eval_statistics.update(create_stats_ordered_dict( 170 | 'Bellman Errors 2', 171 | ptu.get_numpy(bellman_errors_2), 172 | )) 173 | self.eval_statistics.update(create_stats_ordered_dict( 174 | 'Policy Action', 175 | ptu.get_numpy(policy_actions), 176 | )) 177 | 178 | def get_epoch_snapshot(self, epoch): 179 | snapshot = super().get_epoch_snapshot(epoch) 180 | snapshot.update( 181 | qf1=self.qf1, 182 | qf2=self.qf2, 183 | policy=self.eval_policy, 184 | trained_policy=self.policy, 185 | target_policy=self.target_policy, 186 | exploration_policy=self.exploration_policy, 187 | ) 188 | return snapshot 189 | 190 | @property 191 | def networks(self): 192 | return [ 193 | self.policy, 194 | self.qf1, 195 | self.qf2, 196 | self.target_policy, 197 | self.target_qf1, 198 | self.target_qf2, 199 | ] 200 | -------------------------------------------------------------------------------- /rlkit/torch/torch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Iterable 3 | 4 | import numpy as np 5 | 6 | from rlkit.core.rl_algorithm import RLAlgorithm 7 | from rlkit.torch import pytorch_util as ptu 8 | from rlkit.torch.core import PyTorchModule 9 | 10 | 11 | class TorchRLAlgorithm(RLAlgorithm, 
metaclass=abc.ABCMeta): 12 | def get_batch(self): 13 | batch = self.replay_buffer.random_batch(self.batch_size) 14 | return np_to_pytorch_batch(batch) 15 | 16 | @property 17 | @abc.abstractmethod 18 | def networks(self) -> Iterable[PyTorchModule]: 19 | pass 20 | 21 | def training_mode(self, mode): 22 | for net in self.networks: 23 | net.train(mode) 24 | 25 | def to(self, device=None): 26 | if device is None: 27 | device = ptu.device 28 | for net in self.networks: 29 | net.to(device) 30 | 31 | 32 | def _elem_or_tuple_to_variable(elem_or_tuple): 33 | if isinstance(elem_or_tuple, tuple): 34 | return tuple( 35 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple 36 | ) 37 | return ptu.from_numpy(elem_or_tuple).float() 38 | 39 | 40 | def _filter_batch(np_batch): 41 | for k, v in np_batch.items(): 42 | assert isinstance(v, np.ndarray), "Dict values must be of type ndarray" 43 | 44 | if v.dtype == np.bool: 45 | yield k, v.astype(int) 46 | else: 47 | yield k, v 48 | 49 | 50 | def np_to_pytorch_batch(np_batch): 51 | return { 52 | k: _elem_or_tuple_to_variable(x) 53 | for k, x in _filter_batch(np_batch) 54 | if x.dtype != np.dtype('O') # ignore object (e.g. dictionaries) 55 | } 56 | -------------------------------------------------------------------------------- /rlkit/util/hyperparameter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom hyperparameter functions. 3 | """ 4 | import abc 5 | import copy 6 | import math 7 | import random 8 | import itertools 9 | from typing import List 10 | 11 | import railrl.pythonplusplus as ppp 12 | 13 | 14 | class Hyperparameter(metaclass=abc.ABCMeta): 15 | def __init__(self, name): 16 | self._name = name 17 | 18 | @property 19 | def name(self): 20 | return self._name 21 | 22 | 23 | class RandomHyperparameter(Hyperparameter): 24 | def __init__(self, name): 25 | super().__init__(name) 26 | self._last_value = None 27 | 28 | @abc.abstractmethod 29 | def generate_next_value(self): 30 | """Return a value for the hyperparameter""" 31 | return 32 | 33 | def generate(self): 34 | self._last_value = self.generate_next_value() 35 | return self._last_value 36 | 37 | 38 | class EnumParam(RandomHyperparameter): 39 | def __init__(self, name, possible_values): 40 | super().__init__(name) 41 | self.possible_values = possible_values 42 | 43 | def generate_next_value(self): 44 | return random.choice(self.possible_values) 45 | 46 | 47 | class LogFloatParam(RandomHyperparameter): 48 | """ 49 | Return something ranging from [min_value + offset, max_value + offset], 50 | distributed with a log. 
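Concretely, a value is drawn uniformly between log(min_value) and log(max_value), exponentiated, and then shifted by offset.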
51 | """ 52 | def __init__(self, name, min_value, max_value, *, offset=0): 53 | super(LogFloatParam, self).__init__(name) 54 | self._linear_float_param = LinearFloatParam("log_" + name, 55 | math.log(min_value), 56 | math.log(max_value)) 57 | self.offset = offset 58 | 59 | def generate_next_value(self): 60 | return math.e ** (self._linear_float_param.generate()) + self.offset 61 | 62 | 63 | class LinearFloatParam(RandomHyperparameter): 64 | def __init__(self, name, min_value, max_value): 65 | super(LinearFloatParam, self).__init__(name) 66 | self._min = min_value 67 | self._delta = max_value - min_value 68 | 69 | def generate_next_value(self): 70 | return random.random() * self._delta + self._min 71 | 72 | 73 | class LogIntParam(RandomHyperparameter): 74 | def __init__(self, name, min_value, max_value, *, offset=0): 75 | super().__init__(name) 76 | self._linear_float_param = LinearFloatParam("log_" + name, 77 | math.log(min_value), 78 | math.log(max_value)) 79 | self.offset = offset 80 | 81 | def generate_next_value(self): 82 | return int( 83 | math.e ** (self._linear_float_param.generate()) + self.offset 84 | ) 85 | 86 | 87 | class LinearIntParam(RandomHyperparameter): 88 | def __init__(self, name, min_value, max_value): 89 | super(LinearIntParam, self).__init__(name) 90 | self._min = min_value 91 | self._max = max_value 92 | 93 | def generate_next_value(self): 94 | return random.randint(self._min, self._max) 95 | 96 | 97 | class FixedParam(RandomHyperparameter): 98 | def __init__(self, name, value): 99 | super().__init__(name) 100 | self._value = value 101 | 102 | def generate_next_value(self): 103 | return self._value 104 | 105 | 106 | class Sweeper(object): 107 | pass 108 | 109 | 110 | class RandomHyperparameterSweeper(Sweeper): 111 | def __init__(self, hyperparameters=None, default_kwargs=None): 112 | if default_kwargs is None: 113 | default_kwargs = {} 114 | self._hyperparameters = hyperparameters or [] 115 | self._validate_hyperparameters() 116 | self._default_kwargs = default_kwargs 117 | 118 | def _validate_hyperparameters(self): 119 | names = set() 120 | for hp in self._hyperparameters: 121 | name = hp.name 122 | if name in names: 123 | raise Exception("Hyperparameter '{0}' already added.".format( 124 | name)) 125 | names.add(name) 126 | 127 | def set_default_parameters(self, default_kwargs): 128 | self._default_kwargs = default_kwargs 129 | 130 | def generate_random_hyperparameters(self): 131 | hyperparameters = {} 132 | for hp in self._hyperparameters: 133 | hyperparameters[hp.name] = hp.generate() 134 | hyperparameters = ppp.dot_map_dict_to_nested_dict(hyperparameters) 135 | return ppp.merge_recursive_dicts( 136 | hyperparameters, 137 | copy.deepcopy(self._default_kwargs), 138 | ignore_duplicate_keys_in_second_dict=True, 139 | ) 140 | 141 | def sweep_hyperparameters(self, function, num_configs): 142 | returned_value_and_params = [] 143 | for _ in range(num_configs): 144 | kwargs = self.generate_random_hyperparameters() 145 | score = function(**kwargs) 146 | returned_value_and_params.append((score, kwargs)) 147 | 148 | return returned_value_and_params 149 | 150 | 151 | class DeterministicHyperparameterSweeper(Sweeper): 152 | """ 153 | Do a grid search over hyperparameters based on a predefined set of 154 | hyperparameters. 155 | """ 156 | def __init__(self, hyperparameters, default_parameters=None): 157 | """ 158 | 159 | :param hyperparameters: A dictionary of the form 160 | ``` 161 | { 162 | 'hp_1': [value1, value2, value3], 163 | 'hp_2': [value1, value2, value3], 164 | ... 
165 | } 166 | ``` 167 | This format is like the param_grid in SciKit-Learn: 168 | http://scikit-learn.org/stable/modules/grid_search.html#exhaustive-grid-search 169 | :param default_parameters: Default key-value pairs to add to the 170 | dictionary. 171 | """ 172 | self._hyperparameters = hyperparameters 173 | self._default_kwargs = default_parameters or {} 174 | named_hyperparameters = [] 175 | for name, values in self._hyperparameters.items(): 176 | named_hyperparameters.append( 177 | [(name, v) for v in values] 178 | ) 179 | self._hyperparameters_dicts = [ 180 | ppp.dot_map_dict_to_nested_dict(dict(tuple_list)) 181 | for tuple_list in itertools.product(*named_hyperparameters) 182 | ] 183 | 184 | def iterate_hyperparameters(self): 185 | """ 186 | Iterate over the hyperparameters in a grid-manner. 187 | 188 | :return: List of dictionaries. Each dictionary is a map from name to 189 | hyperpameter. 190 | """ 191 | return [ 192 | ppp.merge_recursive_dicts( 193 | hyperparameters, 194 | copy.deepcopy(self._default_kwargs), 195 | ignore_duplicate_keys_in_second_dict=True, 196 | ) 197 | for hyperparameters in self._hyperparameters_dicts 198 | ] 199 | 200 | 201 | # TODO(vpong): Test this 202 | class DeterministicSweeperCombiner(object): 203 | """ 204 | A simple wrapper to combiner multiple DeterministicHyperParameterSweeper's 205 | """ 206 | def __init__(self, sweepers: List[DeterministicHyperparameterSweeper]): 207 | self._sweepers = sweepers 208 | 209 | def iterate_list_of_hyperparameters(self): 210 | """ 211 | Usage: 212 | 213 | ``` 214 | sweeper1 = DeterministicHyperparameterSweeper(...) 215 | sweeper2 = DeterministicHyperparameterSweeper(...) 216 | combiner = DeterministicSweeperCombiner([sweeper1, sweeper2]) 217 | 218 | for params_1, params_2 in combiner.iterate_list_of_hyperparameters(): 219 | # param_1 = {...} 220 | # param_2 = {...} 221 | ``` 222 | :return: Generator of hyperparameters, in the same order as provided 223 | sweepers. 
224 | """ 225 | return itertools.product( 226 | sweeper.iterate_hyperparameters() 227 | for sweeper in self._sweepers 228 | ) -------------------------------------------------------------------------------- /rlkit/util/io.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | import pickle 4 | 5 | import boto3 6 | 7 | from rlkit.launchers.config import LOCAL_LOG_DIR, AWS_S3_PATH 8 | import os 9 | 10 | PICKLE = 'pickle' 11 | NUMPY = 'numpy' 12 | JOBLIB = 'joblib' 13 | 14 | 15 | def local_path_from_s3_or_local_path(filename): 16 | relative_filename = os.path.join(LOCAL_LOG_DIR, filename) 17 | if os.path.isfile(filename): 18 | return filename 19 | elif os.path.isfile(relative_filename): 20 | return relative_filename 21 | else: 22 | return sync_down(filename) 23 | 24 | 25 | def sync_down(path, check_exists=True): 26 | is_docker = os.path.isfile("/.dockerenv") 27 | if is_docker: 28 | local_path = "/tmp/%s" % (path) 29 | else: 30 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path) 31 | 32 | if check_exists and os.path.isfile(local_path): 33 | return local_path 34 | 35 | local_dir = os.path.dirname(local_path) 36 | os.makedirs(local_dir, exist_ok=True) 37 | 38 | if is_docker: 39 | from doodad.ec2.autoconfig import AUTOCONFIG 40 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key() 41 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret() 42 | 43 | full_s3_path = os.path.join(AWS_S3_PATH, path) 44 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path) 45 | try: 46 | bucket = boto3.resource('s3').Bucket(bucket_name) 47 | bucket.download_file(bucket_relative_path, local_path) 48 | except Exception as e: 49 | local_path = None 50 | print("Failed to sync! 
path: ", path) 51 | print("Exception: ", e) 52 | return local_path 53 | 54 | 55 | def split_s3_full_path(s3_path): 56 | """ 57 | Split "s3://foo/bar/baz" into "foo" and "bar/baz" 58 | """ 59 | bucket_name_and_directories = s3_path.split('//')[1] 60 | bucket_name, *directories = bucket_name_and_directories.split('/') 61 | directory_path = '/'.join(directories) 62 | return bucket_name, directory_path 63 | 64 | 65 | def load_local_or_remote_file(filepath, file_type=None): 66 | local_path = local_path_from_s3_or_local_path(filepath) 67 | if file_type is None: 68 | extension = local_path.split('.')[-1] 69 | if extension == 'npy': 70 | file_type = NUMPY 71 | else: 72 | file_type = PICKLE 73 | else: 74 | file_type = PICKLE 75 | if file_type == NUMPY: 76 | object = np.load(open(local_path, "rb")) 77 | elif file_type == JOBLIB: 78 | object = joblib.load(local_path) 79 | else: 80 | object = pickle.load(open(local_path, "rb")) 81 | print("loaded", local_path) 82 | return object 83 | 84 | 85 | if __name__ == "__main__": 86 | p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl") 87 | print("got", p) -------------------------------------------------------------------------------- /rlkit/util/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | 5 | import numpy as np 6 | import scipy.misc 7 | import skvideo.io 8 | 9 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 10 | 11 | 12 | def dump_video( 13 | env, 14 | policy, 15 | filename, 16 | rollout_function, 17 | rows=3, 18 | columns=6, 19 | pad_length=0, 20 | pad_color=255, 21 | do_timer=True, 22 | horizon=100, 23 | dirname_to_save_images=None, 24 | subdirname="rollouts", 25 | imsize=84, 26 | num_channels=3, 27 | ): 28 | frames = [] 29 | H = 3 * imsize 30 | W = imsize 31 | N = rows * columns 32 | for i in range(N): 33 | start = time.time() 34 | path = rollout_function( 35 | env, 36 | policy, 37 | max_path_length=horizon, 38 | animated=False, 39 | ) 40 | is_vae_env = isinstance(env, VAEWrappedEnv) 41 | l = [] 42 | for d in path['full_observations']: 43 | if is_vae_env: 44 | recon = np.clip( 45 | env._reconstruct_img(d['image_observation']), 0, 1 46 | ) 47 | else: 48 | recon = d['image_observation'] 49 | l.append( 50 | get_image( 51 | d['image_desired_goal'], 52 | d['image_observation'], 53 | recon, 54 | pad_length=pad_length, 55 | pad_color=pad_color, 56 | imsize=imsize, 57 | ) 58 | ) 59 | frames += l 60 | 61 | if dirname_to_save_images: 62 | rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i)) 63 | os.makedirs(rollout_dir, exist_ok=True) 64 | rollout_frames = frames[-101:] 65 | goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0) 66 | scipy.misc.imsave(rollout_dir + "/goal.png", goal_img) 67 | goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0) 68 | scipy.misc.imsave(rollout_dir + "/z_goal.png", goal_img) 69 | for j in range(0, 101, 1): 70 | img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0) 71 | scipy.misc.imsave(rollout_dir + "/" + str(j) + ".png", img) 72 | if do_timer: 73 | print(i, time.time() - start) 74 | 75 | frames = np.array(frames, dtype=np.uint8).reshape( 76 | (N, horizon + 1, H + 2 * pad_length, W + 2 * pad_length, num_channels)) 77 | f1 = [] 78 | for k1 in range(columns): 79 | f2 = [] 80 | for k2 in range(rows): 81 | k = k1 * rows + k2 82 | f2.append(frames[k:k + 1, :, :, :, :]. 
83 | reshape((horizon + 1, H + 2 * pad_length, 84 | W + 2 * pad_length, num_channels))) 85 | f1.append(np.concatenate(f2, axis=1)) 86 | outputdata = np.concatenate(f1, axis=2) 87 | skvideo.io.vwrite(filename, outputdata) 88 | print("Saved video to ", filename) 89 | 90 | 91 | def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255): 92 | if len(goal.shape) == 1: 93 | goal = goal.reshape(-1, imsize, imsize).transpose() 94 | obs = obs.reshape(-1, imsize, imsize).transpose() 95 | recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose() 96 | img = np.concatenate((goal, obs, recon_obs)) 97 | img = np.uint8(255 * img) 98 | if pad_length > 0: 99 | img = add_border(img, pad_length, pad_color) 100 | return img 101 | 102 | 103 | def add_border(img, pad_length, pad_color, imsize=84): 104 | H = 3 * imsize 105 | W = imsize 106 | img = img.reshape((3 * imsize, imsize, -1)) 107 | img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]), 108 | dtype=np.uint8) * pad_color 109 | img2[pad_length:-pad_length, pad_length:-pad_length, :] = img 110 | return img2 111 | -------------------------------------------------------------------------------- /scripts/download_s3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | import rlkit.launchers.config as config 4 | 5 | cmd = F"aws s3 sync --exact-timestamp --exclude '*' --include '12-02*' {config.AWS_S3_PATH}/ ../../s3_files/" 6 | 7 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True) 8 | print(cmd) 9 | for line in iter(process.stdout.readline, b''): 10 | sys.stdout.buffer.write(line) -------------------------------------------------------------------------------- /scripts/resume_training_with_new_env.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import joblib 4 | import gym 5 | 6 | from rlkit.core import logger 7 | from rlkit.samplers.rollout_functions import multitask_rollout 8 | from rlkit.torch import pytorch_util as ptu 9 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 10 | from rlkit.launchers.launcher_util import run_experiment 11 | import robotics_recorder 12 | from rlkit.data_management.path_builder import PathBuilder 13 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 14 | import torch 15 | import numpy as np 16 | from rlkit.launchers.config import get_infra_settings 17 | 18 | try: 19 | from mpi4py import MPI 20 | except ImportError: 21 | MPI = None 22 | 23 | 24 | def resume_training(variant): 25 | data = pickle.load(open(variant['trained_file'], "rb")) 26 | algorithm = data['algorithm'] 27 | 28 | env = gym.make(variant['env_id']) 29 | algorithm.training_env = pickle.loads(pickle.dumps(env)) 30 | algorithm.env = env 31 | 32 | observation_key = 'observation' 33 | desired_goal_key = 'desired_goal' 34 | achieved_goal_key = desired_goal_key.replace("desired", "achieved") 35 | for key in algorithm.__dict__: 36 | if "optimizer" in key: 37 | getattr(algorithm, key).comm = MPI.COMM_WORLD 38 | network = key.split("_")[0] 39 | if network == "alpha": 40 | # getattr(algorithm, key).reconnect_params(getattr(algorithm, network).parameters()) 41 | algorithm.alpha_optimizer.reconnect_params([algorithm.log_alpha]) 42 | else: 43 | getattr(algorithm, key).reconnect_params(getattr(algorithm, network).parameters()) 44 | 45 | if not hasattr(algorithm.replay_buffer, "_masks"): 46 | algorithm.replay_buffer._masks = 
np.zeros((algorithm.replay_buffer.max_size, num_blocks)) 47 | algorithm.replay_buffer.max_num_blocks = num_blocks 48 | algorithm.replay_buffer.key_sizes = dict(observation=15, 49 | desired_goal=3, 50 | achieved_goal=3) 51 | algorithm.demonstration_policy = None 52 | algorithm.replay_buffer.demonstration_buffer = None 53 | algorithm._old_table_keys = None 54 | 55 | algorithm.to(ptu.device) 56 | algorithm.train() 57 | 58 | 59 | if __name__ == "__main__": 60 | docker_img = "latest" 61 | 62 | filename = "/home/richard/s3_files/11-27-sequentialtransfer-recurrentFalse-stack5-stack6-numrelblocks3-nqh1-dockimglatest-rewardIncremental-stackonlyTrue/11-27-sequentialtransfer_recurrentFalse_stack5_stack6_numrelblocks3_nqh1_dockimglatest_rewardIncremental_stackonlyTrue-1574894953750/11-27-sequentialtransfer_recurrentFalse_stack5_stack6_numrelblocks3_nqh1_dockimglatest_rewardIncremental_stackonlyTrue_2019_11_27_22_56_32_0000--s-23990/itr_250.pkl" 63 | 64 | num_blocks = int(input("Num blocks: ")) 65 | # assert "relational_preloadstack1" in filename 66 | import re 67 | if "nqh" in filename: 68 | nqh = int(re.search("(?<=nqh)(\d+)(?=_)", filename).group(0)) 69 | else: 70 | nqh = 0 71 | if "numrelblocks" in filename: 72 | num_relational_blocks = int(re.search("(?<=numrelblocks)(\d+)", filename).group(0)) 73 | else: 74 | num_relational_blocks = 0 75 | 76 | prev_stackonly = bool(input("Prev stack only: ")) 77 | 78 | variant = dict( 79 | trained_file=filename, 80 | env_id=F"FetchBlockConstruction_{num_blocks}Blocks_IncrementalReward_DictstateObs_42Rendersize_{prev_stackonly}Stackonly_SingletowerCase-v1", 81 | replay_buffer_kwargs=dict( 82 | max_size=100000, 83 | fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper 84 | fraction_goals_env_goals=0.0, 85 | ), 86 | new_replay_buffer=False, 87 | doodad_docker_image = F"richardrl/fbc:{docker_img}", 88 | gpu_doodad_docker_image = F"richardrl/fbc:{docker_img}", 89 | ) 90 | 91 | mode="here_no_doodad" 92 | instance_type = "c5.18xlarge" 93 | num_parallel_processes = get_infra_settings(mode, instance_type)['num_parallel_processes'] 94 | 95 | prefix = input("Prefix: ") 96 | 97 | run_experiment( 98 | resume_training, 99 | exp_prefix=f"resume-{prefix}-numrelblocks{num_relational_blocks}-nqh{nqh}-numblocks{num_blocks}-stackonly{prev_stackonly}_dockimg{docker_img}", # Make sure no spaces.. 
100 | region="us-west-2", 101 | mode=mode, 102 | variant=variant, 103 | gpu_mode=False, 104 | spot_price=5, 105 | snapshot_mode='gap_and_last', 106 | snapshot_gap=100, 107 | num_exps_per_instance=1, 108 | instance_type=instance_type, 109 | python_cmd=F"mpirun --allow-run-as-root -np {num_parallel_processes} python" 110 | ) 111 | -------------------------------------------------------------------------------- /scripts/run_experiment_from_doodad.py: -------------------------------------------------------------------------------- 1 | import doodad as dd 2 | from rlkit.launchers.launcher_util import run_experiment_here 3 | import torch.multiprocessing as mp 4 | 5 | if __name__ == "__main__": 6 | import matplotlib 7 | matplotlib.use('agg') 8 | 9 | mp.set_start_method('forkserver') 10 | args_dict = dd.get_args() 11 | method_call = args_dict['method_call'] 12 | run_experiment_kwargs = args_dict['run_experiment_kwargs'] 13 | output_dir = args_dict['output_dir'] 14 | run_mode = args_dict.get('mode', None) 15 | if run_mode and run_mode in ['slurm_singularity', 'sss']: 16 | import os 17 | run_experiment_kwargs['variant']['slurm-job-id'] = os.environ.get( 18 | 'SLURM_JOB_ID', None 19 | ) 20 | if run_mode and run_mode == 'ec2': 21 | try: 22 | import urllib.request 23 | instance_id = urllib.request.urlopen( 24 | 'http://169.254.169.254/latest/meta-data/instance-id' 25 | ).read().decode() 26 | run_experiment_kwargs['variant']['EC2_instance_id'] = instance_id 27 | except Exception as e: 28 | print("Could not get instance ID. Error was...") 29 | print(e) 30 | if run_mode and (run_mode == 'ec2' or run_mode == 'gcp'): 31 | # Do this in case base_log_dir was already set 32 | run_experiment_kwargs['base_log_dir'] = output_dir 33 | run_experiment_here( 34 | method_call, 35 | include_exp_prefix_sub_dir=False, 36 | **run_experiment_kwargs 37 | ) 38 | else: 39 | run_experiment_here( 40 | method_call, 41 | log_dir=output_dir, 42 | **run_experiment_kwargs 43 | ) -------------------------------------------------------------------------------- /scripts/sim_goal_conditioned_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | from rlkit.core import logger 5 | from rlkit.samplers.rollout_functions import multitask_rollout 6 | from rlkit.torch import pytorch_util as ptu 7 | import numpy as np 8 | from gym.wrappers.monitor import Monitor 9 | import gym 10 | 11 | 12 | def is_solved(path, num_blocks): 13 | num_succeeded = 0 14 | goal_threshold = .05 15 | for block_id in range(num_blocks): 16 | if np.linalg.norm(path['full_observations'][-1]['achieved_goal'][block_id * 3:(block_id + 1) * 3] - path['full_observations'][-1]['desired_goal'][block_id * 3:(block_id + 1) * 3]) < goal_threshold: 17 | num_succeeded += 1 18 | return num_succeeded == num_blocks 19 | 20 | 21 | def get_final_subgoaldist(env, path): 22 | if isinstance(env, Monitor): 23 | return sum(env.env.unwrapped.subgoal_distances(path['full_observations'][-1]['achieved_goal'], path['full_observations'][-1]['desired_goal'])) 24 | else: 25 | return sum(env.unwrapped.subgoal_distances(path['full_observations'][-1]['achieved_goal'], path['full_observations'][-1]['desired_goal'])) 26 | 27 | 28 | def simulate_policy(args): 29 | # import torch 30 | # torch.manual_seed(6199) 31 | if args.pause: 32 | import ipdb; ipdb.set_trace() 33 | data = pickle.load(open(args.file, "rb")) 34 | policy = data['algorithm'].policy 35 | 36 | num_blocks = 6 37 | stack_only = True 38 | 39 | 40 | # env = 
data['env'] 41 | env = gym.make(F"FetchBlockConstruction_{num_blocks}Blocks_IncrementalReward_DictstateObs_42Rendersize_{stack_only}Stackonly_AllCase-v1") 42 | 43 | env = Monitor(env, force=True, directory="videos/", video_callable=lambda x:x) 44 | 45 | print("Policy and environment loaded") 46 | if args.gpu: 47 | ptu.set_gpu_mode(True) 48 | policy.to(ptu.device) 49 | if args.enable_render or hasattr(env, 'enable_render'): 50 | # some environments need to be reconfigured for visualization 51 | env.enable_render() 52 | policy.train(False) 53 | failures = [] 54 | successes = [] 55 | for path_idx in range(100): 56 | path = multitask_rollout( 57 | env, 58 | policy, 59 | max_path_length=num_blocks*50, 60 | animated=not args.hide, 61 | observation_key='observation', 62 | desired_goal_key='desired_goal', 63 | get_action_kwargs=dict( 64 | mask=np.ones((1, num_blocks)), 65 | deterministic=True 66 | ), 67 | ) 68 | 69 | if not is_solved(path, num_blocks): 70 | failures.append(path) 71 | print(F"Failed {path_idx}") 72 | else: 73 | print(F"Succeeded {path_idx}") 74 | successes.append(path) 75 | # if hasattr(env, "log_diagnostics"): 76 | # env.log_diagnostics(paths) 77 | # if hasattr(env, "get_diagnostics"): 78 | # for k, v in env.get_diagnostics(paths).items(): 79 | # logger.record_tabular(k, v) 80 | # logger.dump_tabular() 81 | print(f"Success rate {len(successes)/(len(successes) + len(failures))}") 82 | from rlkit.core.eval_util import get_generic_path_information 83 | path_info = get_generic_path_information(successes + failures, num_blocks=num_blocks) 84 | print(path_info) 85 | 86 | if __name__ == "__main__": 87 | 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('file', type=str, 90 | help='path to the snapshot file') 91 | parser.add_argument('--H', type=int, default=np.inf, 92 | help='Max length of rollout') 93 | parser.add_argument('--speedup', type=float, default=10, 94 | help='Speedup') 95 | parser.add_argument('--mode', default='video_env', type=str, 96 | help='env mode') 97 | parser.add_argument('--gpu', action='store_true') 98 | parser.add_argument('--pause', action='store_true') 99 | parser.add_argument('--enable_render', action='store_true') 100 | parser.add_argument('--multitaskpause', action='store_true') 101 | parser.add_argument('--hide', action='store_true') 102 | args = parser.parse_args() 103 | 104 | simulate_policy(args) 105 | -------------------------------------------------------------------------------- /scripts/sim_policy.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.util import rollout 2 | from rlkit.torch.core import PyTorchModule 3 | from rlkit.torch.pytorch_util import set_gpu_mode 4 | import argparse 5 | import joblib 6 | import uuid 7 | from rlkit.core import logger 8 | 9 | filename = str(uuid.uuid4()) 10 | 11 | 12 | def simulate_policy(args): 13 | data = joblib.load(args.file) 14 | policy = data['policy'] 15 | env = data['env'] 16 | print("Policy loaded") 17 | if args.gpu: 18 | set_gpu_mode(True) 19 | policy.cuda() 20 | if isinstance(policy, PyTorchModule): 21 | policy.train(False) 22 | while True: 23 | path = rollout( 24 | env, 25 | policy, 26 | max_path_length=args.H, 27 | animated=True, 28 | ) 29 | if hasattr(env, "log_diagnostics"): 30 | env.log_diagnostics([path]) 31 | logger.dump_tabular() 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('file', type=str, 37 | help='path to the snapshot file') 38 | parser.add_argument('--H', 
type=int, default=300, 39 | help='Max length of rollout') 40 | parser.add_argument('--gpu', action='store_true') 41 | args = parser.parse_args() 42 | 43 | simulate_policy(args) 44 | -------------------------------------------------------------------------------- /scripts/sim_tdm_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import joblib 5 | from pathlib import Path 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.core.eval_util import get_generic_path_information 9 | from rlkit.torch.tdm.sampling import multitask_rollout 10 | from rlkit.core import logger 11 | if __name__ == "__main__": 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('file', type=str, help='path to the snapshot file') 15 | parser.add_argument('--H', type=int, default=300, 16 | help='Max length of rollout') 17 | parser.add_argument('--nrolls', type=int, default=1, 18 | help='Number of rollout per eval') 19 | parser.add_argument('--mtau', type=float, help='Max tau value') 20 | parser.add_argument('--gpu', action='store_true') 21 | parser.add_argument('--hide', action='store_true') 22 | args = parser.parse_args() 23 | 24 | data = joblib.load(args.file) 25 | if args.mtau is None: 26 | # Load max tau from variant.json file 27 | variant_path = Path(args.file).parents[0] / 'variant.json' 28 | variant = json.load(variant_path.open()) 29 | try: 30 | max_tau = variant['tdm_kwargs']['max_tau'] 31 | print("Max tau read from variant: {}".format(max_tau)) 32 | except KeyError: 33 | print("Defaulting max tau to 0.") 34 | max_tau = 0 35 | else: 36 | max_tau = args.mtau 37 | 38 | env = data['env'] 39 | policy = data['policy'] 40 | policy.train(False) 41 | 42 | if args.gpu: 43 | ptu.set_gpu_mode(True) 44 | policy.cuda() 45 | 46 | while True: 47 | paths = [] 48 | for _ in range(args.nrolls): 49 | goal = env.sample_goal_for_rollout() 50 | path = multitask_rollout( 51 | env, 52 | policy, 53 | init_tau=max_tau, 54 | goal=goal, 55 | max_path_length=args.H, 56 | animated=not args.hide, 57 | cycle_tau=True, 58 | decrement_tau=True, 59 | ) 60 | paths.append(path) 61 | env.log_diagnostics(paths) 62 | for key, value in get_generic_path_information(paths).items(): 63 | logger.record_tabular(key, value) 64 | logger.dump_tabular() 65 | --------------------------------------------------------------------------------