├── .gitignore ├── LICENSE ├── README.md ├── docs ├── HER.md ├── RIG.md ├── SMAC.md ├── SkewFit.md ├── TDMs.md ├── goal_based_envs.md └── images │ ├── FetchReach-v1_HER-TD3.png │ ├── SawyerReachXYZEnv-v0_HER-TD3.png │ ├── her_dqn.png │ ├── her_td3_sawyer_reacher.png │ ├── skewfit_door.png │ ├── skewfit_pickup.png │ └── skewfit_pusher.png ├── environment ├── docker │ ├── Dockerfile │ └── vendor │ │ ├── 10_nvidia.json │ │ ├── Xdummy │ │ └── Xdummy-entrypoint ├── linux-cpu-env.yml ├── linux-gpu-env.yml ├── mac-env.yml └── val_v1 │ ├── environment.yml │ └── requirements.txt ├── examples ├── __init__.py ├── awac │ ├── README.md │ ├── hand │ │ └── awac1.py │ └── mujoco │ │ └── awac1.py ├── ddpg.py ├── doodad │ ├── ec2_example.py │ └── gcp_example.py ├── dqn_and_double_dqn.py ├── her │ ├── her_dqn_gridworld.py │ ├── her_sac_gym_fetch_reach.py │ └── her_td3_multiworld_sawyer_reach.py ├── iql │ ├── README.md │ ├── antmaze_finetune.py │ └── mujoco_finetune.py ├── sac.py ├── simplegym │ ├── sac.py │ └── sac_hc.py ├── skewfit │ ├── sawyer_door.py │ ├── sawyer_pickup.py │ └── sawyer_push.py ├── smac │ ├── ant.py │ ├── ant_tasks.joblib │ ├── cheetah.py │ ├── cheetah_tasks.joblib │ ├── generate_ant_data.py │ └── generate_cheetah_data.py └── td3.py ├── rlkit ├── __init__.py ├── core │ ├── __init__.py │ ├── batch_rl_algorithm.py │ ├── eval_util.py │ ├── logging.py │ ├── loss.py │ ├── meta_rl_algorithm.py │ ├── online_rl_algorithm.py │ ├── rl_algorithm.py │ ├── serializable.py │ ├── simple_offline_rl_algorithm.py │ ├── tabulate.py │ ├── timer.py │ └── trainer.py ├── data_management │ ├── __init__.py │ ├── env_replay_buffer.py │ ├── meta_learning_replay_buffer.py │ ├── multitask_replay_buffer.py │ ├── normalizer.py │ ├── obs_dict_replay_buffer.py │ ├── online_vae_replay_buffer.py │ ├── path_builder.py │ ├── replay_buffer.py │ ├── shared_obs_dict_replay_buffer.py │ ├── simple_replay_buffer.py │ └── split_buffer.py ├── demos │ ├── __init__.py │ ├── collect_demo.py │ ├── play_demo.py │ ├── source │ │ ├── __init__.py │ │ ├── demo_source.py │ │ ├── dict_to_mdp_path_loader.py │ │ ├── hand_demo_source.py │ │ ├── hdf5_path_loader.py │ │ ├── mdp_path_loader.py │ │ └── path_loader.py │ └── spacemouse │ │ ├── README.md │ │ ├── __init__.py │ │ ├── input_client.py │ │ └── input_server.py ├── envs │ ├── __init__.py │ ├── ant.py │ ├── assets │ │ ├── low_gear_ratio_ant.xml │ │ └── reacher_7dof.xml │ ├── env_utils.py │ ├── goal_generation │ │ └── pickup_goal_dataset.py │ ├── make_env.py │ ├── mujoco_env.py │ ├── mujoco_image_env.py │ ├── pearl_envs │ │ ├── __init__.py │ │ ├── ant.py │ │ ├── ant_dir.py │ │ ├── ant_goal.py │ │ ├── ant_multitask_base.py │ │ ├── ant_normal.py │ │ ├── assets │ │ │ ├── ant.xml │ │ │ └── low_gear_ratio_ant.xml │ │ ├── half_cheetah.py │ │ ├── half_cheetah_dir.py │ │ ├── half_cheetah_vel.py │ │ ├── hopper_rand_params_wrapper.py │ │ ├── humanoid_dir.py │ │ ├── mujoco_env.py │ │ ├── point_robot.py │ │ ├── rand_param_envs │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── hopper_rand_params.py │ │ │ ├── pr2_env_reach.py │ │ │ └── walker2d_rand_params.py │ │ ├── walker_rand_params_wrapper.py │ │ └── wrappers.py │ ├── proxy_env.py │ ├── vae_wrapper.py │ ├── wrappers.py │ └── wrappers │ │ ├── __init__.py │ │ ├── discretize_env.py │ │ ├── history_env.py │ │ ├── image_mujoco_env.py │ │ ├── image_mujoco_env_with_obs.py │ │ ├── normalized_box_env.py │ │ ├── reward_wrapper_env.py │ │ └── stack_observation_env.py ├── exploration_strategies │ ├── __init__.py │ ├── base.py │ ├── epsilon_greedy.py │ ├── 
gaussian_and_epsilon_strategy.py │ ├── gaussian_strategy.py │ └── ou_strategy.py ├── launchers │ ├── __init__.py │ ├── conf.py │ ├── experiments │ │ └── awac │ │ │ ├── __init__.py │ │ │ ├── awac_encoder_rl.py │ │ │ ├── awac_gcrl.py │ │ │ ├── awac_rl.py │ │ │ └── finetune_rl.py │ ├── launcher_util.py │ └── skewfit_experiments.py ├── policies │ ├── __init__.py │ ├── argmax.py │ ├── base.py │ └── simple.py ├── pythonplusplus.py ├── samplers │ ├── __init__.py │ ├── data_collector │ │ ├── __init__.py │ │ ├── base.py │ │ ├── contextual_path_collector.py │ │ ├── joint_path_collector.py │ │ ├── path_collector.py │ │ ├── step_collector.py │ │ └── vae_env.py │ ├── in_place.py │ ├── rollout_functions.py │ └── util.py ├── testing │ ├── __init__.py │ ├── csv_util.py │ ├── debug_util.py │ ├── np_test_case.py │ ├── stub_classes.py │ ├── testing_utils.py │ └── tf_test_case.py ├── torch │ ├── __init__.py │ ├── conv_networks.py │ ├── core.py │ ├── data.py │ ├── data_management │ │ ├── __init__.py │ │ └── normalizer.py │ ├── ddpg │ │ ├── __init__.py │ │ └── ddpg.py │ ├── distributions.py │ ├── dqn │ │ ├── __init__.py │ │ ├── double_dqn.py │ │ └── dqn.py │ ├── her │ │ ├── __init__.py │ │ └── her.py │ ├── lvm │ │ ├── __init__.py │ │ ├── bear_vae.py │ │ └── latent_variable_model.py │ ├── modules.py │ ├── networks │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── cnn.py │ │ ├── custom.py │ │ ├── dcnn.py │ │ ├── feat_point_mlp.py │ │ ├── image_state.py │ │ ├── linear_transform.py │ │ ├── mlp.py │ │ ├── normalization.py │ │ ├── pretrained_cnn.py │ │ ├── stochastic │ │ │ ├── __init__.py │ │ │ └── distribution_generator.py │ │ └── two_headed_mlp.py │ ├── pytorch_util.py │ ├── sac │ │ ├── __init__.py │ │ ├── awac_trainer.py │ │ ├── iql_trainer.py │ │ ├── policies │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gaussian_policy.py │ │ │ ├── lvm_policy.py │ │ │ └── policy_from_q.py │ │ └── sac.py │ ├── skewfit │ │ ├── online_vae_algorithm.py │ │ └── video_gen.py │ ├── smac │ │ ├── agent.py │ │ ├── base_config.py │ │ ├── diagnostics.py │ │ ├── launcher.py │ │ ├── launcher_util.py │ │ ├── networks.py │ │ ├── pearl.py │ │ ├── pearl_launcher.py │ │ ├── sampler.py │ │ └── smac.py │ ├── td3 │ │ ├── __init__.py │ │ └── td3.py │ ├── torch_rl_algorithm.py │ └── vae │ │ ├── conv_vae.py │ │ ├── vae_base.py │ │ ├── vae_schedules.py │ │ └── vae_trainer.py ├── util │ ├── hyperparameter.py │ ├── io.py │ ├── ml_util.py │ ├── video.py │ └── wrapper.py └── visualization │ └── plot_util.py ├── scripts ├── run_experiment_from_doodad.py ├── run_goal_conditioned_policy.py └── run_policy.py ├── setup.py └── tests └── regression ├── iql ├── halfcheetah_offline_progress.csv ├── halfcheetah_online_progress.csv ├── test_iql_offline.py └── test_iql_online.py └── simplegym ├── test_sac.py └── test_sac_progress.csv /.gitignore: -------------------------------------------------------------------------------- 1 | */*/mjkey.txt 2 | **/.DS_STORE 3 | **/*.pyc 4 | **/*.swp 5 | rlkit/launchers/config.py 6 | rlkit/launchers/conf_private.py 7 | MANIFEST 8 | *.egg-info 9 | \.idea/ 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vitchyr Pong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/HER.md: -------------------------------------------------------------------------------- 1 | # Hindsight Experience Replay 2 | Some notes on the implementation of 3 | [Hindsight Experience Replay](https://arxiv.org/abs/1707.01495). 4 | ## Expected Results 5 | If you run the [Fetch example](examples/her/her_sac_gym_fetch_reach.py), then 6 | you should get results like this: 7 | ![Fetch HER results](images/FetchReach-v1_HER-TD3.png) 8 | 9 | If you run the [GridWorld example](examples/her/her_dqn_gridworld.py), 10 | then you should get results like this: 11 | ![HER Gridworld results](images/her_dqn.png) 12 | 13 | Note that these examples use HER combined with DQN and SAC, and not DDPG. 14 | 15 | These plots are generated using [viskit](https://github.com/vitchyr/viskit). 16 | 17 | ## Goal-based environments and `ObsDictRelabelingBuffer` 18 | [See here.](goal_based_envs.md) 19 | 20 | ## Implementation Difference 21 | This HER implementation is slightly different from the one presented in the paper. 22 | Rather than relabeling goals when saving data to the replay buffer, the goals 23 | are relabeled when sampling from the replay buffer. 24 | 25 | 26 | In other words, HER in the paper does this: 27 | 28 | Data collection 29 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$. 30 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$. 31 | For i = 1, ..., K: 32 | Sample $g_i$ using the future strategy. 33 | Recompute rewards $r_i = f(s', g_i)$. 34 | Save $(s, a, r_i, s', g_i)$ into replay buffer $\mathcal B$. 35 | Train time 36 | 1. Sample $(s, a, r, s', g)$ from the replay buffer. 37 | 2. Train Q function on $(s, a, r, s', g)$. 38 | 39 | The implementation here does: 40 | 41 | Data collection 42 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$. 43 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$. 44 | Train time 45 | 1. Sample $(s, a, r, s', g)$ from the replay buffer. 46 | 2a. With probability 1/(K+1): 47 | Train Q function on $(s, a, r, s', g)$. 48 | 2b. With probability 1 - 1/(K+1): 49 | Sample $g'$ using the future strategy. 50 | Recompute rewards $r' = f(s', g')$. 51 | Train Q function on $(s, a, r', s', g')$. 52 | 53 | Both implementations effectively do the same thing: with probability 1/(K+1), 54 | you train the policy on the goal used during rollout. Otherwise, train the 55 | policy on a resampled goal.
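To make the relabel-at-sampling-time scheme concrete, here is a minimal, hypothetical sketch of how a batch could be relabeled when it is drawn. This is not rlkit's actual `ObsDictRelabelingBuffer` code (which stores observation dictionaries and supports several resampling strategies); the flat array fields such as `future_achieved_goals` and the `compute_reward` argument are illustrative assumptions standing in for $f(s', g)$ above.

```python
import numpy as np


def sample_relabeled_batch(buffer, batch_size, k, compute_reward):
    """Sample a batch and relabel each goal with probability 1 - 1/(K+1).

    Illustrative only: `buffer` is assumed to be a dict of aligned numpy
    arrays plus an integer 'size'; `compute_reward` plays the role of
    f(s', g) above.
    """
    idxs = np.random.randint(0, buffer['size'], batch_size)
    obs = buffer['observations'][idxs]
    actions = buffer['actions'][idxs]
    next_obs = buffer['next_observations'][idxs]
    goals = buffer['goals'][idxs].copy()
    rewards = buffer['rewards'][idxs].copy()

    # With probability 1 - 1/(K+1), replace the rollout goal with a goal
    # achieved later in the same trajectory ("future" strategy) and
    # recompute the reward, exactly as in step 2b above.
    relabel = np.random.rand(batch_size) > 1.0 / (k + 1)
    future_goals = buffer['future_achieved_goals'][idxs]  # hypothetical field
    goals[relabel] = future_goals[relabel]
    rewards[relabel] = compute_reward(next_obs[relabel], goals[relabel])

    return obs, actions, rewards, next_obs, goals
```

Compared to relabeling at data-collection time, this keeps the buffer at one entry per environment transition and lets K be changed without re-collecting data, at the cost of recomputing some rewards during training.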
56 | 57 | -------------------------------------------------------------------------------- /docs/RIG.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning with Imagined Goals 2 | Implementation of reinforcement learning with imagined goals (RIG). 3 | To find out 4 | more, see any of the following links: 5 | * arXiv: https://arxiv.org/abs/1807.04742 6 | * Website: https://sites.google.com/site/visualrlwithimaginedgoals/ 7 | * Blog Post: https://bair.berkeley.edu/blog/2018/09/06/rig/ 8 | 9 | To see the original implementation, check out version `v0.1.2` of this repo. 10 | 11 | In versions `0.2+`, RIG is a special case of [Skew-Fit](SkewFit.md) with the 12 | power set to `0`. 13 | 14 | ## Goal-based environments and `ObsDictRelabelingBuffer` 15 | [See here.](goal_based_envs.md) 16 | -------------------------------------------------------------------------------- /docs/SMAC.md: -------------------------------------------------------------------------------- 1 | Requirements that differ from base requirements: 2 | - python 3.6.5 3 | - joblib==0.9.4 4 | - numpy==1.18.5 5 | 6 | Running these examples requires first generating the data, updating the main launch script to point to that generated data, and then launching the SMAC experiments. 7 | 8 | This can be done by first running 9 | ```bash 10 | python examples/smac/generate_{ant|cheetah}_data.py 11 | ``` 12 | which runs [PEARL](https://github.com/katerakelly/oyster) to generate multi-task data. 13 | This script will generate a directory and file of the form 14 | ``` 15 | LOCAL_LOG_DIR///extra_snapshot_itrXYZ.cpkl 16 | ``` 17 | 18 | You can then update the `examples/smac/{ant|cheetah}.py` file, where it says `TODO: update to point to correct file`, to point to this file. 19 | Finally, run the SMAC script 20 | ```bash 21 | python examples/smac/{ant|cheetah}.py 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/SkewFit.md: -------------------------------------------------------------------------------- 1 | # Skew-Fit 2 | Requires [multiworld](https://github.com/vitchyr/multiworld) to be installed: 3 | ``` 4 | pip install git+https://github.com/vitchyr/multiworld.git@f711cdb 5 | ``` 6 | 7 | Implementation of Skew-Fit. For more information: 8 | - [Videos](https://sites.google.com/view/skew-fit) 9 | - [arXiv](https://arxiv.org/abs/1903.03698) 10 | 11 | To reproduce the results, use these library versions, as the performance seems to depend on the library version: 12 | - multiworld: f711cdb (git hash) 13 | - python: 3.5.2 14 | - torch: 0.4.1.post2 15 | - mujoco_py: 1.50.1.59 16 | - gym: 0.10.5 17 | 18 | Here are the results you should expect from each script. 19 | These plots are generated with [viskit](https://github.com/vitchyr/viskit) 20 | with smoothing on. 21 | 22 | Note that [RIG](RIG.md) is a special case of Skew-Fit with `power=0`. 23 | 24 | 25 | [examples/skewfit/sawyer_door.py](../examples/skewfit/sawyer_door.py). 1 Seed: 26 | ![Skew-Fit Sawyer Door results](images/skewfit_door.png) 27 | 28 | [examples/skewfit/sawyer_pickup.py](../examples/skewfit/sawyer_pickup.py). 3 Seeds: 29 | ![Skew-Fit Sawyer Pickup results](images/skewfit_pickup.png) 30 | 31 | [examples/skewfit/sawyer_push.py](../examples/skewfit/sawyer_push.py). 
9 Seeds: 32 | ![Skew-Fit Sawyer Pusher results](images/skewfit_pusher.png) 33 | -------------------------------------------------------------------------------- /docs/TDMs.md: -------------------------------------------------------------------------------- 1 | # Temporal Difference Models (TDMs) 2 | The TDM implementation is a bit different from the other algorithms. One reason for this is that the goals and rewards are retroactively relabelled. Some notable implementation details: 3 | - The networks (policy and QF) take in the goal and tau (the number of time steps left). 4 | - The algorithm relabels the terminals and rewards, so the terminal/reward from the environment are ignored completely. 5 | - TdmNormalizer is used to normalize the observations/states. If you want, you can totally ignore it and set `num_pretrain_path=0`. 6 | - The environments need to be [MultitaskEnv](rlkit/torch/tdm/envs/multitask_env), meaning standard gym environments won't work out of the box. See below for details. 7 | 8 | The example scripts have tuned hyperparameters. Specifically, the following hyperparameters are tuned, as they seem to be the most important ones to tune: 9 | - `num_updates_per_env_step` 10 | - `reward_scale` 11 | - `max_tau` 12 | 13 | 14 | ## Creating your own environment for TDM 15 | A [MultitaskEnv](envs/multitask_env.py) instance needs to implement 3 functions: 16 | 17 | ```python 18 | def goal_dim(self) -> int: 19 | """ 20 | :return: int, dimension of goal vector 21 | """ 22 | pass 23 | 24 | @abc.abstractmethod 25 | def sample_goals(self, batch_size): 26 | pass 27 | 28 | @abc.abstractmethod 29 | def convert_obs_to_goals(self, obs): 30 | pass 31 | ``` 32 | 33 | If you want to see how to make an existing environment multitask, see [GoalXVelHalfCheetah](envs/half_cheetah_env.py), which builds off of Gym's HalfCheetah environments. 34 | 35 | Another useful example might be [GoalXYPosAnt](envs/ant_env.py), which builds off a custom environment. 36 | 37 | One important thing is that the environment should *not* include the goal as part of the state, since the goal will be separately given to the networks. 38 | -------------------------------------------------------------------------------- /docs/goal_based_envs.md: -------------------------------------------------------------------------------- 1 | # Goal-based environments and `ObsDictRelabelingBuffer` 2 | Some algorithms, like HER, are for goal-conditioned environments, like 3 | the [OpenAI Gym GoalEnv](https://blog.openai.com/ingredients-for-robotics-research/) 4 | or the [multiworld MultitaskEnv](https://github.com/vitchyr/multiworld/) 5 | environments. 6 | 7 | These environments are different from normal gym environments in that they 8 | return dictionaries for observations, like so: 9 | 10 | ``` 11 | env = CarEnv() 12 | obs = env.reset() 13 | next_obs, reward, done, info = env.step(action) 14 | print(obs) 15 | 16 | # Output: 17 | # { 18 | # 'observation': ..., 19 | # 'desired_goal': ..., 20 | # 'achieved_goal': ..., 21 | # } 22 | ``` 23 | The `GoalEnv` environments also have a function with signature 24 | ``` 25 | def compute_rewards(achieved_goal, desired_goal): 26 | # achieved_goal and desired_goal are vectors 27 | ``` 28 | while the `MultitaskEnv` has a signature like 29 | ``` 30 | def compute_rewards(observation, action, next_observation): 31 | # observation and next_observation are dictionaries 32 | ``` 33 | To learn more about these environments, check out the URLs above. 
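To make the two `compute_rewards` signatures above concrete, here is a minimal, hypothetical sketch of how a relabeling buffer might call them after swapping in a new goal. The helper name `relabel_reward` and the `goal_env_style` flag are illustrative assumptions rather than rlkit's actual API; only the `'observation'`/`'desired_goal'`/`'achieved_goal'` keys come from the example above.

```python
def relabel_reward(env, obs_dict, action, next_obs_dict, new_goal,
                   goal_env_style=True):
    """Recompute the reward of a stored transition after swapping in `new_goal`.

    Hypothetical helper for illustration only. `obs_dict` and `next_obs_dict`
    are observation dictionaries of the form printed above, with
    'observation', 'desired_goal', and 'achieved_goal' keys; `new_goal` is a
    goal vector.
    """
    # Replace the desired goal inside the next-observation dictionary.
    relabeled_next_obs = dict(next_obs_dict, desired_goal=new_goal)

    if goal_env_style:
        # GoalEnv-style signature: goal vectors in, reward out.
        return env.compute_rewards(relabeled_next_obs['achieved_goal'], new_goal)
    # MultitaskEnv-style signature: observation dictionaries in, reward out.
    return env.compute_rewards(obs_dict, action, relabeled_next_obs)
```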
34 | This means that normal RL algorithms won't even "type check" with these 35 | environments. 36 | 37 | `ObsDictRelabelingBuffer` perform hindsight experience replay with 38 | either types of environments and works by saving specific values in the 39 | observation dictionary. 40 | 41 | -------------------------------------------------------------------------------- /docs/images/FetchReach-v1_HER-TD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/FetchReach-v1_HER-TD3.png -------------------------------------------------------------------------------- /docs/images/SawyerReachXYZEnv-v0_HER-TD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png -------------------------------------------------------------------------------- /docs/images/her_dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/her_dqn.png -------------------------------------------------------------------------------- /docs/images/her_td3_sawyer_reacher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/her_td3_sawyer_reacher.png -------------------------------------------------------------------------------- /docs/images/skewfit_door.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/skewfit_door.png -------------------------------------------------------------------------------- /docs/images/skewfit_pickup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/skewfit_pickup.png -------------------------------------------------------------------------------- /docs/images/skewfit_pusher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/docs/images/skewfit_pusher.png -------------------------------------------------------------------------------- /environment/docker/vendor/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /environment/docker/vendor/Xdummy-entrypoint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import argparse 3 | import os 4 | import sys 5 | import subprocess 6 | 7 | parser = argparse.ArgumentParser() 8 | args, extra_args = parser.parse_known_args() 9 | subprocess.Popen(["nohup", "Xdummy"], stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 10 | os.environ['DISPLAY'] = ':0' 11 | if not extra_args: 12 | sys.argv = ['/bin/bash'] 13 | else: 14 | sys.argv = extra_args 15 | # Explicitly flush right 
before the exec since otherwise things might get 16 | # lost in Python's buffers around stdout/stderr (!). 17 | sys.stdout.flush() 18 | sys.stderr.flush() 19 | os.execvpe(sys.argv[0], sys.argv, os.environ) 20 | 21 | -------------------------------------------------------------------------------- /environment/linux-cpu-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch-cpu=0.4.1 21 | - scipy=1.0.1 22 | - patchelf 23 | - pip: 24 | - cloudpickle==0.5.2 25 | - gym[all]==0.10.5 26 | - gitpython==2.1.7 27 | - gtimer==1.0.0b5 28 | - pygame==1.9.2 29 | - ipdb # technically unnecessary 30 | -------------------------------------------------------------------------------- /environment/linux-gpu-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch=0.4.1 21 | - scipy=1.0.1 22 | - patchelf 23 | - pip: 24 | - cloudpickle==0.5.2 25 | - gym[all]==0.10.5 26 | - gitpython==2.1.7 27 | - gtimer==1.0.0b5 28 | - pygame==1.9.2 29 | - ipdb # technically unnecessary 30 | -------------------------------------------------------------------------------- /environment/mac-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch=0.4.1 21 | - scipy=1.0.1 22 | - pip: 23 | - cloudpickle==0.5.2 24 | - gym[all]==0.10.5 25 | - gitpython==2.1.7 26 | - gtimer==1.0.0b5 27 | - pygame==1.9.2 28 | - ipdb # technically unnecessary 29 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/examples/__init__.py -------------------------------------------------------------------------------- /examples/awac/README.md: -------------------------------------------------------------------------------- 1 | # AWAC 2 | 3 | This directory contains examples to run the implementation of advantage 4 | weighted actor 
critic (AWAC, pronounced "awake"). The paper with more details 5 | is available [here](https://arxiv.org/abs/2006.09359). 6 | 7 | ## Usage Instructions 8 | 9 | Running the dexterous manipulation experiments requires setting up the 10 | environments in this repository: 11 | [https://github.com/aravindr93/hand_dapg](https://github.com/aravindr93/hand_dapg). 12 | You can also use the follwing docker image, which has the required dependencies 13 | set up: anair17/railrl-hand-v3 14 | 15 | For the mj_envs repository, please use: 16 | https://github.com/anair13/mj_envs 17 | 18 | ## Data 19 | 20 | Data can be downloaded from the following links: 21 | 22 | MuJoCo benchmark tasks - https://drive.google.com/file/d/1IMUqUShv7tVqBvowvPkORe50KJ9gqPgR/view?usp=sharing 23 | 24 | Dexterous manipulation - https://drive.google.com/file/d/1yUdJnGgYit94X_AvV6JJP5Y3Lx2JF30Y/view?usp=sharing 25 | 26 | You will then have to update the paths in rlkit/launchers/experiments/awac/awac_rl.py 27 | -------------------------------------------------------------------------------- /examples/ddpg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running PyTorch implementation of DDPG on HalfCheetah. 3 | """ 4 | import copy 5 | 6 | from gym.envs.mujoco import HalfCheetahEnv 7 | 8 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 9 | from rlkit.envs.wrappers import NormalizedBoxEnv 10 | from rlkit.exploration_strategies.base import ( 11 | PolicyWrappedWithExplorationStrategy 12 | ) 13 | from rlkit.exploration_strategies.ou_strategy import OUStrategy 14 | from rlkit.launchers.launcher_util import setup_logger 15 | from rlkit.samplers.data_collector import MdpPathCollector 16 | from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy 17 | from rlkit.torch.ddpg.ddpg import DDPGTrainer 18 | import rlkit.torch.pytorch_util as ptu 19 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 20 | 21 | 22 | def experiment(variant): 23 | eval_env = NormalizedBoxEnv(HalfCheetahEnv()) 24 | expl_env = NormalizedBoxEnv(HalfCheetahEnv()) 25 | # Or for a specific version: 26 | # import gym 27 | # env = NormalizedBoxEnv(gym.make('HalfCheetah-v1')) 28 | obs_dim = eval_env.observation_space.low.size 29 | action_dim = eval_env.action_space.low.size 30 | qf = ConcatMlp( 31 | input_size=obs_dim + action_dim, 32 | output_size=1, 33 | **variant['qf_kwargs'] 34 | ) 35 | policy = TanhMlpPolicy( 36 | input_size=obs_dim, 37 | output_size=action_dim, 38 | **variant['policy_kwargs'] 39 | ) 40 | target_qf = copy.deepcopy(qf) 41 | target_policy = copy.deepcopy(policy) 42 | eval_path_collector = MdpPathCollector(eval_env, policy) 43 | exploration_policy = PolicyWrappedWithExplorationStrategy( 44 | exploration_strategy=OUStrategy(action_space=expl_env.action_space), 45 | policy=policy, 46 | ) 47 | expl_path_collector = MdpPathCollector(expl_env, exploration_policy) 48 | replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env) 49 | trainer = DDPGTrainer( 50 | qf=qf, 51 | target_qf=target_qf, 52 | policy=policy, 53 | target_policy=target_policy, 54 | **variant['trainer_kwargs'] 55 | ) 56 | algorithm = TorchBatchRLAlgorithm( 57 | trainer=trainer, 58 | exploration_env=expl_env, 59 | evaluation_env=eval_env, 60 | exploration_data_collector=expl_path_collector, 61 | evaluation_data_collector=eval_path_collector, 62 | replay_buffer=replay_buffer, 63 | **variant['algorithm_kwargs'] 64 | ) 65 | algorithm.to(ptu.device) 66 | algorithm.train() 67 | 68 | 69 | if 
__name__ == "__main__": 70 | # noinspection PyTypeChecker 71 | variant = dict( 72 | algorithm_kwargs=dict( 73 | num_epochs=1000, 74 | num_eval_steps_per_epoch=1000, 75 | num_trains_per_train_loop=1000, 76 | num_expl_steps_per_train_loop=1000, 77 | min_num_steps_before_training=10000, 78 | max_path_length=1000, 79 | batch_size=128, 80 | ), 81 | trainer_kwargs=dict( 82 | use_soft_update=True, 83 | tau=1e-2, 84 | discount=0.99, 85 | qf_learning_rate=1e-3, 86 | policy_learning_rate=1e-4, 87 | ), 88 | qf_kwargs=dict( 89 | hidden_sizes=[400, 300], 90 | ), 91 | policy_kwargs=dict( 92 | hidden_sizes=[400, 300], 93 | ), 94 | replay_buffer_size=int(1E6), 95 | ) 96 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 97 | setup_logger('name-of-experiment', variant=variant) 98 | experiment(variant) 99 | -------------------------------------------------------------------------------- /examples/doodad/ec2_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running stuff on EC2 3 | """ 4 | import time 5 | 6 | from rlkit.core import logger 7 | from rlkit.launchers.launcher_util import run_experiment 8 | from datetime import datetime 9 | from pytz import timezone 10 | import pytz 11 | 12 | 13 | def example(variant): 14 | import torch; import rlkit.torch.pytorch_util as ptu  # ptu.device is used below 15 | logger.log(torch.__version__) 16 | date_format = '%m/%d/%Y %H:%M:%S %Z' 17 | date = datetime.now(tz=pytz.utc) 18 | logger.log("start") 19 | logger.log('Current date & time is: {}'.format(date.strftime(date_format))) 20 | if torch.cuda.is_available(): 21 | x = torch.randn(3) 22 | logger.log(str(x.to(ptu.device))) 23 | 24 | date = date.astimezone(timezone('US/Pacific')) 25 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 26 | for i in range(variant['num_seconds']): 27 | logger.log("Tick, {}".format(i)) 28 | time.sleep(1) 29 | logger.log("end") 30 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 31 | 32 | logger.log("start mujoco") 33 | from gym.envs.mujoco import HalfCheetahEnv 34 | e = HalfCheetahEnv() 35 | img = e.sim.render(32, 32) 36 | logger.log(str(sum(img))) 37 | logger.log("end mujoco") 38 | 39 | 40 | if __name__ == "__main__": 41 | # noinspection PyTypeChecker 42 | date_format = '%m/%d/%Y %H:%M:%S %Z' 43 | date = datetime.now(tz=pytz.utc) 44 | logger.log("start") 45 | variant = dict( 46 | num_seconds=10, 47 | launch_time=str(date.strftime(date_format)), 48 | ) 49 | run_experiment( 50 | example, 51 | exp_prefix="ec2-test", 52 | mode='ec2', 53 | variant=variant, 54 | # use_gpu=True, # GPUs are much more expensive! 
55 | ) 56 | -------------------------------------------------------------------------------- /examples/doodad/gcp_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running stuff on GCP 3 | """ 4 | import time 5 | 6 | from rlkit.core import logger 7 | from rlkit.launchers.launcher_util import run_experiment 8 | from datetime import datetime 9 | from pytz import timezone 10 | import pytz 11 | 12 | 13 | def example(variant): 14 | import torch 15 | import rlkit.torch.pytorch_util as ptu 16 | print("Starting") 17 | logger.log(torch.__version__) 18 | date_format = '%m/%d/%Y %H:%M:%S %Z' 19 | date = datetime.now(tz=pytz.utc) 20 | logger.log("start") 21 | logger.log('Current date & time is: {}'.format(date.strftime(date_format))) 22 | logger.log("Cuda available: {}".format(torch.cuda.is_available())) 23 | if torch.cuda.is_available(): 24 | x = torch.randn(3) 25 | logger.log(str(x.to(ptu.device))) 26 | 27 | date = date.astimezone(timezone('US/Pacific')) 28 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 29 | for i in range(variant['num_seconds']): 30 | logger.log("Tick, {}".format(i)) 31 | time.sleep(1) 32 | logger.log("end") 33 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 34 | 35 | logger.log("start mujoco") 36 | from gym.envs.mujoco import HalfCheetahEnv 37 | e = HalfCheetahEnv() 38 | img = e.sim.render(32, 32) 39 | logger.log(str(sum(img))) 40 | logger.log("end mujoco") 41 | 42 | logger.record_tabular('Epoch', 1) 43 | logger.dump_tabular() 44 | logger.record_tabular('Epoch', 2) 45 | logger.dump_tabular() 46 | logger.record_tabular('Epoch', 3) 47 | logger.dump_tabular() 48 | print("Done") 49 | 50 | 51 | if __name__ == "__main__": 52 | # noinspection PyTypeChecker 53 | date_format = '%m/%d/%Y %H:%M:%S %Z' 54 | date = datetime.now(tz=pytz.utc) 55 | logger.log("start") 56 | variant = dict( 57 | num_seconds=10, 58 | launch_time=str(date.strftime(date_format)), 59 | ) 60 | run_experiment( 61 | example, 62 | exp_prefix="gcp-test", 63 | mode='gcp', 64 | variant=variant, 65 | # use_gpu=True, # GPUs are much more expensive! 66 | ) 67 | -------------------------------------------------------------------------------- /examples/dqn_and_double_dqn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run DQN on CartPole-v0. 
3 | """ 4 | 5 | import gym 6 | from torch import nn as nn 7 | 8 | from rlkit.exploration_strategies.base import \ 9 | PolicyWrappedWithExplorationStrategy 10 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy 11 | from rlkit.policies.argmax import ArgmaxDiscretePolicy 12 | from rlkit.torch.dqn.dqn import DQNTrainer 13 | from rlkit.torch.networks import Mlp 14 | import rlkit.torch.pytorch_util as ptu 15 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 16 | from rlkit.launchers.launcher_util import setup_logger 17 | from rlkit.samplers.data_collector import MdpPathCollector 18 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 19 | 20 | 21 | def experiment(variant): 22 | expl_env = gym.make('CartPole-v0').env 23 | eval_env = gym.make('CartPole-v0').env 24 | obs_dim = expl_env.observation_space.low.size 25 | action_dim = eval_env.action_space.n 26 | 27 | qf = Mlp( 28 | hidden_sizes=[32, 32], 29 | input_size=obs_dim, 30 | output_size=action_dim, 31 | ) 32 | target_qf = Mlp( 33 | hidden_sizes=[32, 32], 34 | input_size=obs_dim, 35 | output_size=action_dim, 36 | ) 37 | qf_criterion = nn.MSELoss() 38 | eval_policy = ArgmaxDiscretePolicy(qf) 39 | expl_policy = PolicyWrappedWithExplorationStrategy( 40 | EpsilonGreedy(expl_env.action_space), 41 | eval_policy, 42 | ) 43 | eval_path_collector = MdpPathCollector( 44 | eval_env, 45 | eval_policy, 46 | ) 47 | expl_path_collector = MdpPathCollector( 48 | expl_env, 49 | expl_policy, 50 | ) 51 | trainer = DQNTrainer( 52 | qf=qf, 53 | target_qf=target_qf, 54 | qf_criterion=qf_criterion, 55 | **variant['trainer_kwargs'] 56 | ) 57 | replay_buffer = EnvReplayBuffer( 58 | variant['replay_buffer_size'], 59 | expl_env, 60 | ) 61 | algorithm = TorchBatchRLAlgorithm( 62 | trainer=trainer, 63 | exploration_env=expl_env, 64 | evaluation_env=eval_env, 65 | exploration_data_collector=expl_path_collector, 66 | evaluation_data_collector=eval_path_collector, 67 | replay_buffer=replay_buffer, 68 | **variant['algorithm_kwargs'] 69 | ) 70 | algorithm.to(ptu.device) 71 | algorithm.train() 72 | 73 | 74 | if __name__ == "__main__": 75 | # noinspection PyTypeChecker 76 | variant = dict( 77 | algorithm="DQN", 78 | version="normal", 79 | layer_size=256, 80 | replay_buffer_size=int(1E6), 81 | algorithm_kwargs=dict( 82 | num_epochs=3000, 83 | num_eval_steps_per_epoch=5000, 84 | num_trains_per_train_loop=1000, 85 | num_expl_steps_per_train_loop=1000, 86 | min_num_steps_before_training=1000, 87 | max_path_length=1000, 88 | batch_size=256, 89 | ), 90 | trainer_kwargs=dict( 91 | discount=0.99, 92 | learning_rate=3E-4, 93 | ), 94 | ) 95 | setup_logger('dqn-CartPole', variant=variant) 96 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 97 | experiment(variant) 98 | -------------------------------------------------------------------------------- /examples/iql/README.md: -------------------------------------------------------------------------------- 1 | # Implicit Q-Learning 2 | 3 | This repository contains a PyTorch re-implementation of [Offline Reinforcement Learning with Implicit Q-Learning](https://arxiv.org/abs/2110.06169) by [Ilya Kostrikov](https://kostrikov.xyz), [Ashvin Nair](https://ashvin.me/), and [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/). 4 | 5 | For the official repository, please use: https://github.com/ikostrikov/implicit_q_learning 6 | 7 | This code can be used for offline RL, or for offline RL followed by online finetuning. 
Negative epochs are offline RL, positive epochs are online (the agent is actively collecting data and adding it to the replay buffer). 8 | 9 | If you use this code for your research, please consider citing the paper: 10 | ``` 11 | @article{kostrikov2021iql, 12 | title={Offline Reinforcement Learning with Implicit Q-Learning}, 13 | author={Ilya Kostrikov and Ashvin Nair and Sergey Levine}, 14 | year={2021}, 15 | archivePrefix={arXiv}, 16 | primaryClass={cs.LG} 17 | } 18 | ``` 19 | 20 | ## Tests 21 | 22 | To run quick versions of these experiments to test if the code matches exactly as the results below, you can run the tests in `tests/regression/iql` 23 | 24 | ## Mujoco results 25 | ![Mujoco results](https://i.ibb.co/6Pd8KT7/download-79.png) 26 | 27 | ## Antmaze results 28 | ![Ant-maze results](https://i.ibb.co/HrTMY2P/download-77.png) 29 | -------------------------------------------------------------------------------- /examples/iql/antmaze_finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | AWR + SAC from demo experiment 3 | """ 4 | 5 | from rlkit.demos.source.hdf5_path_loader import HDF5PathLoader 6 | from rlkit.launchers.experiments.awac.finetune_rl import experiment, process_args 7 | 8 | from rlkit.launchers.launcher_util import run_experiment 9 | 10 | from rlkit.torch.sac.policies import GaussianPolicy 11 | from rlkit.torch.sac.iql_trainer import IQLTrainer 12 | 13 | import random 14 | 15 | import d4rl 16 | 17 | variant = dict( 18 | algo_kwargs=dict( 19 | start_epoch=-1000, # offline epochs 20 | num_epochs=1001, # online epochs 21 | batch_size=256, 22 | num_eval_steps_per_epoch=1000, 23 | num_trains_per_train_loop=1000, 24 | num_expl_steps_per_train_loop=1000, 25 | min_num_steps_before_training=1000, 26 | ), 27 | max_path_length=1000, 28 | replay_buffer_size=int(2E6), 29 | layer_size=256, 30 | policy_class=GaussianPolicy, 31 | policy_kwargs=dict( 32 | hidden_sizes=[256, 256, ], 33 | max_log_std=0, 34 | min_log_std=-6, 35 | std_architecture="values", 36 | ), 37 | qf_kwargs=dict( 38 | hidden_sizes=[256, 256, ], 39 | ), 40 | 41 | algorithm="SAC", 42 | version="normal", 43 | collection_mode='batch', 44 | trainer_class=IQLTrainer, 45 | trainer_kwargs=dict( 46 | discount=0.99, 47 | policy_lr=3E-4, 48 | qf_lr=3E-4, 49 | reward_scale=1, 50 | 51 | policy_weight_decay=0, 52 | q_weight_decay=0, 53 | 54 | reward_transform_kwargs=dict(m=1, b=-1), 55 | terminal_transform_kwargs=None, 56 | 57 | beta=0.1, 58 | quantile=0.9, 59 | clip_score=100, 60 | ), 61 | launcher_config=dict( 62 | num_exps_per_instance=1, 63 | region='us-west-2', 64 | ), 65 | 66 | path_loader_class=HDF5PathLoader, 67 | path_loader_kwargs=dict(), 68 | add_env_demos=False, 69 | add_env_offpolicy_data=False, 70 | 71 | load_demos=False, 72 | load_env_dataset_demos=True, 73 | 74 | normalize_env=False, 75 | env_id='antmaze-umaze-v0', 76 | 77 | seed=random.randint(0, 100000), 78 | ) 79 | 80 | def main(): 81 | run_experiment(experiment, 82 | variant=variant, 83 | exp_prefix='iql-antmaze-umaze-v0', 84 | mode="here_no_doodad", 85 | unpack_variant=False 86 | ) 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /examples/iql/mujoco_finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | AWR + SAC from demo experiment 3 | """ 4 | 5 | from rlkit.demos.source.hdf5_path_loader import HDF5PathLoader 6 | from rlkit.launchers.experiments.awac.finetune_rl import 
experiment, process_args 7 | 8 | from rlkit.launchers.launcher_util import run_experiment 9 | 10 | from rlkit.torch.sac.policies import GaussianPolicy 11 | from rlkit.torch.sac.iql_trainer import IQLTrainer 12 | 13 | import random 14 | 15 | import d4rl 16 | 17 | variant = dict( 18 | algo_kwargs=dict( 19 | start_epoch=-1000, # offline epochs 20 | num_epochs=1001, # online epochs 21 | batch_size=256, 22 | num_eval_steps_per_epoch=1000, 23 | num_trains_per_train_loop=1000, 24 | num_expl_steps_per_train_loop=1000, 25 | min_num_steps_before_training=1000, 26 | ), 27 | max_path_length=1000, 28 | replay_buffer_size=int(2E6), 29 | layer_size=256, 30 | policy_class=GaussianPolicy, 31 | policy_kwargs=dict( 32 | hidden_sizes=[256, 256, ], 33 | max_log_std=0, 34 | min_log_std=-6, 35 | std_architecture="values", 36 | ), 37 | qf_kwargs=dict( 38 | hidden_sizes=[256, 256, ], 39 | ), 40 | 41 | algorithm="SAC", 42 | version="normal", 43 | collection_mode='batch', 44 | trainer_class=IQLTrainer, 45 | trainer_kwargs=dict( 46 | discount=0.99, 47 | policy_lr=3E-4, 48 | qf_lr=3E-4, 49 | reward_scale=1, 50 | soft_target_tau=0.005, 51 | 52 | policy_weight_decay=0, 53 | q_weight_decay=0, 54 | 55 | reward_transform_kwargs=None, 56 | terminal_transform_kwargs=None, 57 | 58 | beta=1.0 / 3, 59 | quantile=0.7, 60 | clip_score=100, 61 | ), 62 | launcher_config=dict( 63 | num_exps_per_instance=1, 64 | region='us-west-2', 65 | ), 66 | 67 | path_loader_class=HDF5PathLoader, 68 | path_loader_kwargs=dict(), 69 | add_env_demos=False, 70 | add_env_offpolicy_data=False, 71 | 72 | load_demos=False, 73 | load_env_dataset_demos=True, 74 | 75 | normalize_env=False, 76 | env_id='halfcheetah-medium-v2', 77 | normalize_rewards_by_return_range=True, 78 | 79 | seed=random.randint(0, 100000), 80 | ) 81 | 82 | def main(): 83 | run_experiment(experiment, 84 | variant=variant, 85 | exp_prefix='iql-halfcheetah-medium-v2', 86 | mode="here_no_doodad", 87 | unpack_variant=False, 88 | use_gpu=False, 89 | ) 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /examples/sac.py: -------------------------------------------------------------------------------- 1 | from gym.envs.mujoco import HalfCheetahEnv 2 | 3 | import rlkit.torch.pytorch_util as ptu 4 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 5 | from rlkit.envs.wrappers import NormalizedBoxEnv 6 | from rlkit.launchers.launcher_util import setup_logger 7 | from rlkit.samplers.data_collector import MdpPathCollector 8 | from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic 9 | from rlkit.torch.sac.sac import SACTrainer 10 | from rlkit.torch.networks import ConcatMlp 11 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 12 | 13 | 14 | def experiment(variant): 15 | expl_env = NormalizedBoxEnv(HalfCheetahEnv()) 16 | eval_env = NormalizedBoxEnv(HalfCheetahEnv()) 17 | obs_dim = expl_env.observation_space.low.size 18 | action_dim = eval_env.action_space.low.size 19 | 20 | M = variant['layer_size'] 21 | qf1 = ConcatMlp( 22 | input_size=obs_dim + action_dim, 23 | output_size=1, 24 | hidden_sizes=[M, M], 25 | ) 26 | qf2 = ConcatMlp( 27 | input_size=obs_dim + action_dim, 28 | output_size=1, 29 | hidden_sizes=[M, M], 30 | ) 31 | target_qf1 = ConcatMlp( 32 | input_size=obs_dim + action_dim, 33 | output_size=1, 34 | hidden_sizes=[M, M], 35 | ) 36 | target_qf2 = ConcatMlp( 37 | input_size=obs_dim + action_dim, 38 | output_size=1, 39 | hidden_sizes=[M, M], 40 | ) 41 | 
policy = TanhGaussianPolicy( 42 | obs_dim=obs_dim, 43 | action_dim=action_dim, 44 | hidden_sizes=[M, M], 45 | ) 46 | eval_policy = MakeDeterministic(policy) 47 | eval_path_collector = MdpPathCollector( 48 | eval_env, 49 | eval_policy, 50 | ) 51 | expl_path_collector = MdpPathCollector( 52 | expl_env, 53 | policy, 54 | ) 55 | replay_buffer = EnvReplayBuffer( 56 | variant['replay_buffer_size'], 57 | expl_env, 58 | ) 59 | trainer = SACTrainer( 60 | env=eval_env, 61 | policy=policy, 62 | qf1=qf1, 63 | qf2=qf2, 64 | target_qf1=target_qf1, 65 | target_qf2=target_qf2, 66 | **variant['trainer_kwargs'] 67 | ) 68 | algorithm = TorchBatchRLAlgorithm( 69 | trainer=trainer, 70 | exploration_env=expl_env, 71 | evaluation_env=eval_env, 72 | exploration_data_collector=expl_path_collector, 73 | evaluation_data_collector=eval_path_collector, 74 | replay_buffer=replay_buffer, 75 | **variant['algorithm_kwargs'] 76 | ) 77 | algorithm.to(ptu.device) 78 | algorithm.train() 79 | 80 | 81 | 82 | 83 | if __name__ == "__main__": 84 | # noinspection PyTypeChecker 85 | variant = dict( 86 | algorithm="SAC", 87 | version="normal", 88 | layer_size=256, 89 | replay_buffer_size=int(1E6), 90 | algorithm_kwargs=dict( 91 | num_epochs=3000, 92 | num_eval_steps_per_epoch=5000, 93 | num_trains_per_train_loop=1000, 94 | num_expl_steps_per_train_loop=1000, 95 | min_num_steps_before_training=1000, 96 | max_path_length=1000, 97 | batch_size=256, 98 | ), 99 | trainer_kwargs=dict( 100 | discount=0.99, 101 | soft_target_tau=5e-3, 102 | target_update_period=1, 103 | policy_lr=3E-4, 104 | qf_lr=3E-4, 105 | reward_scale=1, 106 | use_automatic_entropy_tuning=True, 107 | ), 108 | ) 109 | setup_logger('name-of-experiment', variant=variant) 110 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 111 | experiment(variant) 112 | -------------------------------------------------------------------------------- /examples/smac/ant.py: -------------------------------------------------------------------------------- 1 | from rlkit.torch.smac.base_config import DEFAULT_CONFIG 2 | from rlkit.torch.smac.launcher import smac_experiment 3 | import rlkit.util.hyperparameter as hyp 4 | 5 | 6 | # @click.command() 7 | # @click.option('--debug', is_flag=True, default=False) 8 | # @click.option('--dry', is_flag=True, default=False) 9 | # @click.option('--suffix', default=None) 10 | # @click.option('--nseeds', default=1) 11 | # @click.option('--mode', default='here_no_doodad') 12 | # def main(debug, dry, suffix, nseeds, mode): 13 | def main(): 14 | debug = True 15 | dry = False 16 | mode = 'here_no_doodad' 17 | suffix = '' 18 | nseeds = 1 19 | gpu = True 20 | 21 | path_parts = __file__.split('/') 22 | suffix = '' if suffix is None else '--{}'.format(suffix) 23 | exp_name = 'pearl-awac-{}--{}{}'.format( 24 | path_parts[-2].replace('_', '-'), 25 | path_parts[-1].split('.')[0].replace('_', '-'), 26 | suffix, 27 | ) 28 | 29 | if debug or dry: 30 | exp_name = 'dev--' + exp_name 31 | mode = 'here_no_doodad' 32 | nseeds = 1 33 | 34 | variant = DEFAULT_CONFIG.copy() 35 | variant["env_name"] = "ant-dir" 36 | variant["env_params"]["direction_in_degrees"] = True 37 | search_space = { 38 | 'load_buffer_kwargs.pretrain_buffer_path': [ 39 | "results/.../extra_snapshot_itr100.cpkl" # TODO: update to point to correct file 40 | ], 41 | 'saved_tasks_path': [ 42 | "examples/smac/ant_tasks.joblib", # TODO: update to point to correct file 43 | ], 44 | 'load_buffer_kwargs.start_idx': [ 45 | -1200, 46 | ], 47 | 'seed': list(range(nseeds)), 48 | } 49 | from 
rlkit.launchers.launcher_util import run_experiment 50 | sweeper = hyp.DeterministicHyperparameterSweeper( 51 | search_space, default_parameters=variant, 52 | ) 53 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): 54 | variant['exp_id'] = exp_id 55 | run_experiment( 56 | smac_experiment, 57 | unpack_variant=True, 58 | exp_prefix=exp_name, 59 | mode=mode, 60 | variant=variant, 61 | use_gpu=gpu, 62 | ) 63 | 64 | print(exp_name) 65 | 66 | 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | 72 | -------------------------------------------------------------------------------- /examples/smac/ant_tasks.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/examples/smac/ant_tasks.joblib -------------------------------------------------------------------------------- /examples/smac/cheetah.py: -------------------------------------------------------------------------------- 1 | from rlkit.launchers.launcher_util import run_experiment 2 | from rlkit.torch.smac.launcher import smac_experiment 3 | from rlkit.torch.smac.base_config import DEFAULT_CONFIG 4 | import rlkit.util.hyperparameter as hyp 5 | 6 | 7 | # @click.command() 8 | # @click.option('--debug', is_flag=True, default=False) 9 | # @click.option('--dry', is_flag=True, default=False) 10 | # @click.option('--suffix', default=None) 11 | # @click.option('--nseeds', default=1) 12 | # @click.option('--mode', default='here_no_doodad') 13 | # def main(debug, dry, suffix, nseeds, mode): 14 | def main(): 15 | debug = True 16 | dry = False 17 | mode = 'here_no_doodad' 18 | suffix = '' 19 | nseeds = 1 20 | gpu=True 21 | 22 | path_parts = __file__.split('/') 23 | suffix = '' if suffix is None else '--{}'.format(suffix) 24 | exp_name = 'pearl-awac-{}--{}{}'.format( 25 | path_parts[-2].replace('_', '-'), 26 | path_parts[-1].split('.')[0].replace('_', '-'), 27 | suffix, 28 | ) 29 | 30 | if debug or dry: 31 | exp_name = 'dev--' + exp_name 32 | mode = 'here_no_doodad' 33 | nseeds = 1 34 | 35 | print(exp_name) 36 | 37 | variant = DEFAULT_CONFIG.copy() 38 | variant["env_name"] = "cheetah-vel" 39 | search_space = { 40 | 'load_buffer_kwargs.pretrain_buffer_path': [ 41 | "results/.../extra_snapshot_itr100.cpkl" # TODO: update to point to correct file 42 | ], 43 | 'saved_tasks_path': [ 44 | "examples/smac/cheetah_tasks.joblib", # TODO: update to point to correct file 45 | ], 46 | 'load_macaw_buffer_kwargs.rl_buffer_start_end_idxs': [ 47 | [(0, 1200)], 48 | ], 49 | 'load_macaw_buffer_kwargs.encoder_buffer_start_end_idxs': [ 50 | [(-400, None)], 51 | ], 52 | 'load_macaw_buffer_kwargs.encoder_buffer_matches_rl_buffer': [ 53 | False, 54 | ], 55 | 'algo_kwargs.use_rl_buffer_for_enc_buffer': [ 56 | False, 57 | ], 58 | 'algo_kwargs.train_encoder_decoder_in_unsupervised_phase': [ 59 | False, 60 | ], 61 | 'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [ 62 | False, 63 | ], 64 | 'algo_kwargs.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase': [ 65 | True, 66 | ], 67 | 'pretrain_offline_algo_kwargs.logging_period': [ 68 | 25000, 69 | ], 70 | 'algo_kwargs.num_iterations': [ 71 | 51, 72 | ], 73 | 'seed': list(range(nseeds)), 74 | } 75 | sweeper = hyp.DeterministicHyperparameterSweeper( 76 | search_space, default_parameters=variant, 77 | ) 78 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): 79 | variant['exp_id'] = exp_id 80 | run_experiment( 81 | smac_experiment, 82 | 
unpack_variant=True, 83 | exp_prefix=exp_name, 84 | mode=mode, 85 | variant=variant, 86 | use_gpu=gpu, 87 | ) 88 | 89 | print(exp_name) 90 | 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | 97 | -------------------------------------------------------------------------------- /examples/smac/cheetah_tasks.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/examples/smac/cheetah_tasks.joblib -------------------------------------------------------------------------------- /examples/smac/generate_ant_data.py: -------------------------------------------------------------------------------- 1 | import rlkit.util.hyperparameter as hyp 2 | from rlkit.launchers.launcher_util import run_experiment 3 | from rlkit.torch.smac.base_config import DEFAULT_PEARL_CONFIG 4 | from rlkit.torch.smac.pearl_launcher import pearl_experiment 5 | from rlkit.util.io import load_local_or_remote_file 6 | 7 | 8 | # @click.command() 9 | # @click.option('--debug', is_flag=True, default=False) 10 | # @click.option('--dry', is_flag=True, default=False) 11 | # @click.option('--suffix', default=None) 12 | # @click.option('--nseeds', default=1) 13 | # @click.option('--mode', default='here_no_doodad') 14 | # def main(debug, dry, suffix, nseeds, mode): 15 | def main(): 16 | debug = True 17 | dry = False 18 | mode = 'here_no_doodad' 19 | suffix = '' 20 | nseeds = 1 21 | gpu = True 22 | 23 | path_parts = __file__.split('/') 24 | suffix = '' if suffix is None else '--{}'.format(suffix) 25 | exp_name = 'pearl-awac-{}--{}{}'.format( 26 | path_parts[-2].replace('_', '-'), 27 | path_parts[-1].split('.')[0].replace('_', '-'), 28 | suffix, 29 | ) 30 | 31 | if debug or dry: 32 | exp_name = 'dev--' + exp_name 33 | mode = 'here_no_doodad' 34 | nseeds = 1 35 | 36 | if dry: 37 | mode = 'here_no_doodad' 38 | 39 | print(exp_name) 40 | 41 | task_data = load_local_or_remote_file( 42 | "examples/smac/ant_tasks.joblib", # TODO: update to point to correct file 43 | file_type='joblib') 44 | tasks = task_data['tasks'] 45 | search_space = { 46 | 'seed': list(range(nseeds)), 47 | } 48 | variant = DEFAULT_PEARL_CONFIG.copy() 49 | variant["env_name"] = "ant-dir" 50 | # variant["train_task_idxs"] = list(range(100)) 51 | # variant["eval_task_idxs"] = list(range(100, 120)) 52 | variant["env_params"]["fixed_tasks"] = [t['goal'] for t in tasks] 53 | variant["env_params"]["direction_in_degrees"] = True 54 | variant["trainer_kwargs"]["train_context_decoder"] = True 55 | variant["trainer_kwargs"]["backprop_q_loss_into_encoder"] = True 56 | variant["saved_tasks_path"] = "examples/smac/ant_tasks.joblib" # TODO: update to point to correct file 57 | 58 | sweeper = hyp.DeterministicHyperparameterSweeper( 59 | search_space, default_parameters=variant, 60 | ) 61 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): 62 | variant['exp_id'] = exp_id 63 | run_experiment( 64 | pearl_experiment, 65 | unpack_variant=True, 66 | exp_prefix=exp_name, 67 | mode=mode, 68 | variant=variant, 69 | time_in_mins=3 * 24 * 60 - 1, 70 | use_gpu=gpu, 71 | ) 72 | 73 | if __name__ == "__main__": 74 | main() 75 | 76 | -------------------------------------------------------------------------------- /examples/smac/generate_cheetah_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | PEARL Experiment 3 | """ 4 | 5 | import rlkit.util.hyperparameter as hyp 6 | from 
rlkit.launchers.launcher_util import run_experiment 7 | from rlkit.torch.smac.base_config import DEFAULT_PEARL_CONFIG 8 | 9 | from rlkit.torch.smac.pearl_launcher import pearl_experiment 10 | from rlkit.util.io import load_local_or_remote_file 11 | 12 | 13 | # @click.command() 14 | # @click.option('--debug', is_flag=True, default=False) 15 | # @click.option('--dry', is_flag=True, default=False) 16 | # @click.option('--suffix', default=None) 17 | # @click.option('--nseeds', default=1) 18 | # @click.option('--mode', default='here_no_doodad') 19 | # def main(debug, dry, suffix, nseeds, mode): 20 | def main(): 21 | debug = True 22 | dry = False 23 | mode = 'here_no_doodad' 24 | suffix = '' 25 | nseeds = 1 26 | gpu = True 27 | 28 | path_parts = __file__.split('/') 29 | suffix = '' if suffix is None else '--{}'.format(suffix) 30 | exp_name = 'pearl-awac-{}--{}{}'.format( 31 | path_parts[-2].replace('_', '-'), 32 | path_parts[-1].split('.')[0].replace('_', '-'), 33 | suffix, 34 | ) 35 | 36 | if debug or dry: 37 | exp_name = 'dev--' + exp_name 38 | mode = 'here_no_doodad' 39 | nseeds = 1 40 | 41 | if dry: 42 | mode = 'here_no_doodad' 43 | 44 | print(exp_name) 45 | 46 | search_space = { 47 | 'seed': list(range(nseeds)), 48 | } 49 | variant = DEFAULT_PEARL_CONFIG.copy() 50 | variant["env_name"] = "cheetah-vel" 51 | variant['trainer_kwargs']["train_context_decoder"] = True 52 | variant["saved_tasks_path"] = "examples/smac/cheetah_tasks.joblib" # TODO: update to point to correct file 53 | 54 | sweeper = hyp.DeterministicHyperparameterSweeper( 55 | search_space, default_parameters=variant, 56 | ) 57 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): 58 | variant['exp_id'] = exp_id 59 | run_experiment( 60 | pearl_experiment, 61 | unpack_variant=True, 62 | exp_prefix=exp_name, 63 | mode=mode, 64 | variant=variant, 65 | time_in_mins=3 * 24 * 60 - 1, 66 | use_gpu=gpu, 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | 73 | -------------------------------------------------------------------------------- /rlkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/__init__.py -------------------------------------------------------------------------------- /rlkit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout rlkit. 3 | """ 4 | from rlkit.core.logging import logger 5 | 6 | __all__ = ['logger'] 7 | 8 | -------------------------------------------------------------------------------- /rlkit/core/loss.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | 5 | LossStatistics = OrderedDict 6 | 7 | 8 | class LossFunction(object, metaclass=abc.ABCMeta): 9 | @abc.abstractmethod 10 | def compute_loss(self, batch, skip_statistics=False, **kwargs): 11 | """Returns loss and statistics given a batch of data. 12 | batch : Data to compute loss of 13 | skip_statistics: Whether statistics should be calculated. If True, then 14 | an empty dict is returned for the statistics. 15 | 16 | Returns: (loss, stats) tuple. 
17 | """ 18 | pass 19 | -------------------------------------------------------------------------------- /rlkit/core/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | class Serializable(object): 12 | 13 | def __init__(self, *args, **kwargs): 14 | self.__args = args 15 | self.__kwargs = kwargs 16 | 17 | def quick_init(self, locals_): 18 | if getattr(self, "_serializable_initialized", False): 19 | return 20 | if sys.version_info >= (3, 0): 21 | spec = inspect.getfullargspec(self.__init__) 22 | # Exclude the first "self" parameter 23 | if spec.varkw: 24 | kwargs = locals_[spec.varkw].copy() 25 | else: 26 | kwargs = dict() 27 | if spec.kwonlyargs: 28 | for key in spec.kwonlyargs: 29 | kwargs[key] = locals_[key] 30 | else: 31 | spec = inspect.getargspec(self.__init__) 32 | if spec.keywords: 33 | kwargs = locals_[spec.keywords] 34 | else: 35 | kwargs = dict() 36 | if spec.varargs: 37 | varargs = locals_[spec.varargs] 38 | else: 39 | varargs = tuple() 40 | in_order_args = [locals_[arg] for arg in spec.args][1:] 41 | self.__args = tuple(in_order_args) + varargs 42 | self.__kwargs = kwargs 43 | setattr(self, "_serializable_initialized", True) 44 | 45 | def __getstate__(self): 46 | return {"__args": self.__args, "__kwargs": self.__kwargs} 47 | 48 | def __setstate__(self, d): 49 | # convert all __args to keyword-based arguments 50 | if sys.version_info >= (3, 0): 51 | spec = inspect.getfullargspec(self.__init__) 52 | else: 53 | spec = inspect.getargspec(self.__init__) 54 | in_order_args = spec.args[1:] 55 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 56 | self.__dict__.update(out.__dict__) 57 | 58 | @classmethod 59 | def clone(cls, obj, **kwargs): 60 | assert isinstance(obj, Serializable) 61 | d = obj.__getstate__() 62 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 63 | out = type(obj).__new__(type(obj)) 64 | out.__setstate__(d) 65 | return out 66 | -------------------------------------------------------------------------------- /rlkit/core/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from collections import defaultdict 4 | 5 | 6 | class Timer: 7 | def __init__(self, return_global_times=False): 8 | self.stamps = None 9 | self.epoch_start_time = None 10 | self.global_start_time = time.time() 11 | self._return_global_times = return_global_times 12 | 13 | self.reset() 14 | 15 | def reset(self): 16 | self.stamps = defaultdict(lambda: 0) 17 | self.start_times = {} 18 | self.epoch_start_time = time.time() 19 | 20 | def start_timer(self, name, unique=True): 21 | if unique: 22 | assert name not in self.start_times.keys() 23 | self.start_times[name] = time.time() 24 | 25 | def stop_timer(self, name): 26 | assert name in self.start_times.keys() 27 | start_time = self.start_times[name] 28 | end_time = time.time() 29 | self.stamps[name] += (end_time - start_time) 30 | 31 | def get_times(self): 32 | global_times = {} 33 | cur_time = time.time() 34 | global_times['epoch_time'] = (cur_time - self.epoch_start_time) 35 | if self._return_global_times: 36 | global_times['global_time'] = (cur_time - self.global_start_time) 37 | return { 38 | **self.stamps.copy(), 39 | **global_times, 40 | } 41 | 42 | @property 43 | def return_global_times(self): 44 | return self._return_global_times 45 | 46 | @return_global_times.setter 47 | def 
return_global_times(self, value): 48 | self._return_global_times = value 49 | 50 | 51 | timer = Timer() 52 | -------------------------------------------------------------------------------- /rlkit/core/trainer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Trainer(object, metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def train(self, data): 7 | pass 8 | 9 | def end_epoch(self, epoch): 10 | pass 11 | 12 | def get_snapshot(self): 13 | return {} 14 | 15 | def get_diagnostics(self): 16 | return {} 17 | -------------------------------------------------------------------------------- /rlkit/data_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/data_management/__init__.py -------------------------------------------------------------------------------- /rlkit/data_management/env_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from gym.spaces import Discrete 2 | 3 | from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer 4 | from rlkit.envs.env_utils import get_dim 5 | import numpy as np 6 | 7 | 8 | class EnvReplayBuffer(SimpleReplayBuffer): 9 | def __init__( 10 | self, 11 | max_replay_buffer_size, 12 | env, 13 | env_info_sizes=None 14 | ): 15 | """ 16 | :param max_replay_buffer_size: 17 | :param env: 18 | """ 19 | self.env = env 20 | self._ob_space = env.observation_space 21 | self._action_space = env.action_space 22 | 23 | if env_info_sizes is None: 24 | if hasattr(env, 'info_sizes'): 25 | env_info_sizes = env.info_sizes 26 | else: 27 | env_info_sizes = dict() 28 | 29 | super().__init__( 30 | max_replay_buffer_size=max_replay_buffer_size, 31 | observation_dim=get_dim(self._ob_space), 32 | action_dim=get_dim(self._action_space), 33 | env_info_sizes=env_info_sizes 34 | ) 35 | 36 | def add_sample(self, observation, action, reward, terminal, 37 | next_observation, **kwargs): 38 | if isinstance(self._action_space, Discrete): 39 | new_action = np.zeros(self._action_dim) 40 | new_action[action] = 1 41 | else: 42 | new_action = action 43 | return super().add_sample( 44 | observation=observation, 45 | action=new_action, 46 | reward=reward, 47 | next_observation=next_observation, 48 | terminal=terminal, 49 | **kwargs 50 | ) 51 | -------------------------------------------------------------------------------- /rlkit/data_management/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PathBuilder(dict): 5 | """ 6 | Usage: 7 | ``` 8 | path_builder = PathBuilder() 9 | path.add_sample( 10 | observations=1, 11 | actions=2, 12 | next_observations=3, 13 | ... 14 | ) 15 | path.add_sample( 16 | observations=4, 17 | actions=5, 18 | next_observations=6, 19 | ... 20 | ) 21 | 22 | path = path_builder.get_all_stacked() 23 | 24 | path['observations'] 25 | # output: [1, 4] 26 | path['actions'] 27 | # output: [2, 5] 28 | ``` 29 | 30 | Note that the key should be "actions" and not "action" since the 31 | resulting dictionary will have those keys. 
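Note: the method defined below is add_all (this class does not define
add_sample), so the calls in the example above correspond to
path_builder.add_all(observations=1, actions=2, next_observations=3, ...).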
32 | """ 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self._path_length = 0 37 | 38 | def add_all(self, **key_to_value): 39 | for k, v in key_to_value.items(): 40 | if k not in self: 41 | self[k] = [v] 42 | else: 43 | self[k].append(v) 44 | self._path_length += 1 45 | 46 | def get_all_stacked(self): 47 | output_dict = dict() 48 | for k, v in self.items(): 49 | output_dict[k] = stack_list(v) 50 | return output_dict 51 | 52 | def __len__(self): 53 | return self._path_length 54 | 55 | 56 | def stack_list(lst): 57 | if isinstance(lst[0], dict): 58 | return lst 59 | else: 60 | return np.array(lst) 61 | -------------------------------------------------------------------------------- /rlkit/data_management/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayBuffer(object, metaclass=abc.ABCMeta): 5 | """ 6 | A class used to save and replay data. 7 | """ 8 | 9 | @abc.abstractmethod 10 | def add_sample(self, observation, action, reward, next_observation, 11 | terminal, **kwargs): 12 | """ 13 | Add a transition tuple. 14 | """ 15 | pass 16 | 17 | @abc.abstractmethod 18 | def terminate_episode(self): 19 | """ 20 | Let the replay buffer know that the episode has terminated in case some 21 | special book-keeping has to happen. 22 | :return: 23 | """ 24 | pass 25 | 26 | @abc.abstractmethod 27 | def num_steps_can_sample(self, **kwargs): 28 | """ 29 | :return: # of unique items that can be sampled. 30 | """ 31 | pass 32 | 33 | def add_path(self, path): 34 | """ 35 | Add a path to the replay buffer. 36 | 37 | This default implementation naively goes through every step, but you 38 | may want to optimize this. 39 | 40 | NOTE: You should NOT call "terminate_episode" after calling add_path. 41 | It's assumed that this function handles the episode termination. 42 | 43 | :param path: Dict like one outputted by rlkit.samplers.util.rollout 44 | """ 45 | for i, ( 46 | obs, 47 | action, 48 | reward, 49 | next_obs, 50 | terminal, 51 | agent_info, 52 | env_info 53 | ) in enumerate(zip( 54 | path["observations"], 55 | path["actions"], 56 | path["rewards"], 57 | path["next_observations"], 58 | path["terminals"], 59 | path["agent_infos"], 60 | path["env_infos"], 61 | )): 62 | self.add_sample( 63 | observation=obs, 64 | action=action, 65 | reward=reward, 66 | next_observation=next_obs, 67 | terminal=terminal, 68 | agent_info=agent_info, 69 | env_info=env_info, 70 | ) 71 | self.terminate_episode() 72 | 73 | def add_paths(self, paths): 74 | for path in paths: 75 | self.add_path(path) 76 | 77 | @abc.abstractmethod 78 | def random_batch(self, batch_size): 79 | """ 80 | Return a batch of size `batch_size`. 81 | :param batch_size: 82 | :return: 83 | """ 84 | pass 85 | 86 | def get_diagnostics(self): 87 | return {} 88 | 89 | def get_snapshot(self): 90 | return {} 91 | 92 | def end_epoch(self, epoch): 93 | return 94 | 95 | -------------------------------------------------------------------------------- /rlkit/data_management/split_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from rlkit.data_management.replay_buffer import ReplayBuffer 4 | 5 | 6 | class SplitReplayBuffer(ReplayBuffer): 7 | """ 8 | Split the data into a training and validation set. 
9 | """ 10 | def __init__( 11 | self, 12 | train_replay_buffer: ReplayBuffer, 13 | validation_replay_buffer: ReplayBuffer, 14 | fraction_paths_in_train, 15 | ): 16 | self.train_replay_buffer = train_replay_buffer 17 | self.validation_replay_buffer = validation_replay_buffer 18 | self.fraction_paths_in_train = fraction_paths_in_train 19 | self.replay_buffer = self.train_replay_buffer 20 | 21 | def add_sample(self, *args, **kwargs): 22 | self.replay_buffer.add_sample(*args, **kwargs) 23 | 24 | def add_path(self, path): 25 | self.replay_buffer.add_path(path) 26 | self._randomly_set_replay_buffer() 27 | 28 | def num_steps_can_sample(self): 29 | return min( 30 | self.train_replay_buffer.num_steps_can_sample(), 31 | self.validation_replay_buffer.num_steps_can_sample(), 32 | ) 33 | 34 | def terminate_episode(self, *args, **kwargs): 35 | self.replay_buffer.terminate_episode(*args, **kwargs) 36 | self._randomly_set_replay_buffer() 37 | 38 | def _randomly_set_replay_buffer(self): 39 | if random.random() <= self.fraction_paths_in_train: 40 | self.replay_buffer = self.train_replay_buffer 41 | else: 42 | self.replay_buffer = self.validation_replay_buffer 43 | 44 | def get_replay_buffer(self, training=True): 45 | if training: 46 | return self.train_replay_buffer 47 | else: 48 | return self.validation_replay_buffer 49 | 50 | def random_batch(self, batch_size): 51 | return self.train_replay_buffer.random_batch(batch_size) 52 | 53 | def __getattr__(self, attrname): 54 | return getattr(self.replay_buffer, attrname) 55 | 56 | def __getstate__(self): 57 | # Do not save self.replay_buffer since it's a duplicate and seems to 58 | # cause joblib recursion issues. 59 | return dict( 60 | train_replay_buffer=self.train_replay_buffer, 61 | validation_replay_buffer=self.validation_replay_buffer, 62 | fraction_paths_in_train=self.fraction_paths_in_train, 63 | ) 64 | 65 | def __setstate__(self, d): 66 | self.train_replay_buffer = d['train_replay_buffer'] 67 | self.validation_replay_buffer = d['validation_replay_buffer'] 68 | self.fraction_paths_in_train = d['fraction_paths_in_train'] 69 | self.replay_buffer = self.train_replay_buffer 70 | -------------------------------------------------------------------------------- /rlkit/demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/demos/__init__.py -------------------------------------------------------------------------------- /rlkit/demos/play_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sys 4 | # print(sys.path) 5 | sys.path.remove("/opt/ros/kinetic/lib/python2.7/dist-packages") 6 | 7 | import cv2 8 | import sys 9 | import pickle 10 | 11 | def play_demos(path): 12 | data = pickle.load(open(path, "rb")) 13 | # data = np.load(path, allow_pickle=True) 14 | 15 | for traj in data: 16 | obs = traj["observations"] 17 | 18 | for o in obs: 19 | img = o["image_observation"].reshape(3, 500, 300)[:, 60:, :240].transpose() 20 | img = img[:, :, ::-1] 21 | cv2.imshow('window', img) 22 | cv2.waitKey(100) 23 | 24 | if __name__ == '__main__': 25 | demo_path = sys.argv[1] 26 | play_demos(demo_path) 27 | -------------------------------------------------------------------------------- /rlkit/demos/source/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/demos/source/__init__.py -------------------------------------------------------------------------------- /rlkit/demos/source/demo_source.py: -------------------------------------------------------------------------------- 1 | class DemoSource: 2 | def load_paths(self): 3 | """Should return a list of paths in PathBuilder format""" 4 | return [{}, ] 5 | -------------------------------------------------------------------------------- /rlkit/demos/source/hand_demo_source.py: -------------------------------------------------------------------------------- 1 | from rlkit.demos.source.demo_source import DemoSource 2 | import pickle 3 | 4 | from rlkit.data_management.path_builder import PathBuilder 5 | 6 | from rlkit.util.io import load_local_or_remote_file 7 | 8 | class HandDemoSource(DemoSource): 9 | def __init__(self, filename): 10 | self.data = load_local_or_remote_file(filename) 11 | 12 | def load_paths(self): 13 | paths = [] 14 | for i in range(len(self.data)): 15 | p = self.data[i] 16 | H = len(p["observations"]) - 1 17 | 18 | path_builder = PathBuilder() 19 | 20 | for t in range(H): 21 | p["observations"][t] 22 | 23 | ob = path["observations"][t, :] 24 | action = path["actions"][t, :] 25 | reward = path["rewards"][t] 26 | next_ob = path["observations"][t+1, :] 27 | terminal = 0 28 | agent_info = {} # todo (need to unwrap each key) 29 | env_info = {} # todo (need to unwrap each key) 30 | 31 | path_builder.add_all( 32 | observations=ob, 33 | actions=action, 34 | rewards=reward, 35 | next_observations=next_ob, 36 | terminals=terminal, 37 | agent_infos=agent_info, 38 | env_infos=env_info, 39 | ) 40 | 41 | path = path_builder.get_all_stacked() 42 | paths.append(path) 43 | return paths 44 | -------------------------------------------------------------------------------- /rlkit/demos/source/hdf5_path_loader.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | import torch.nn.functional as F 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.core.eval_util import create_stats_ordered_dict 11 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 12 | 13 | from rlkit.util.io import load_local_or_remote_file 14 | 15 | import random 16 | from rlkit.torch.core import np_to_pytorch_batch 17 | from rlkit.data_management.path_builder import PathBuilder 18 | 19 | # import matplotlib 20 | # matplotlib.use('TkAgg') 21 | # import matplotlib.pyplot as plt 22 | 23 | from rlkit.core import logger 24 | 25 | import glob 26 | 27 | def load_hdf5(dataset, replay_buffer): 28 | _obs = dataset['observations'] 29 | N = _obs.shape[0] 30 | assert replay_buffer._max_replay_buffer_size >= N, "dataset does not fit in replay buffer" 31 | 32 | _actions = dataset['actions'] 33 | _next_obs = dataset['next_observations'] 34 | _rew = dataset['rewards'][:N] 35 | _done = dataset['terminals'][:N] 36 | 37 | replay_buffer._observations[:N] = _obs[:N] 38 | replay_buffer._next_obs[:N] = _next_obs[:N] 39 | replay_buffer._actions[:N] = _actions[:N] 40 | replay_buffer._rewards[:N] = np.expand_dims(_rew, 1)[:N] 41 | replay_buffer._terminals[:N] = np.expand_dims(_done, 1)[:N] 42 | replay_buffer._size = N-1 43 | replay_buffer._top = replay_buffer._size 44 | 45 | class HDF5PathLoader: 46 | """ 47 | Path loader for that loads obs-dict demonstrations 
48 | into a Trainer with EnvReplayBuffer 49 | """ 50 | 51 | def __init__( 52 | self, 53 | trainer, 54 | replay_buffer, 55 | demo_train_buffer, 56 | demo_test_buffer, 57 | demo_paths=[], # list of dicts 58 | demo_train_split=0.9, 59 | demo_data_split=1, 60 | add_demos_to_replay_buffer=True, 61 | bc_num_pretrain_steps=0, 62 | bc_batch_size=64, 63 | bc_weight=1.0, 64 | rl_weight=1.0, 65 | q_num_pretrain_steps=0, 66 | weight_decay=0, 67 | eval_policy=None, 68 | recompute_reward=False, 69 | env_info_key=None, 70 | obs_key=None, 71 | load_terminals=True, 72 | 73 | **kwargs 74 | ): 75 | self.trainer = trainer 76 | 77 | self.add_demos_to_replay_buffer = add_demos_to_replay_buffer 78 | self.demo_train_split = demo_train_split 79 | self.demo_data_split = demo_data_split 80 | self.replay_buffer = replay_buffer 81 | self.demo_train_buffer = demo_train_buffer 82 | self.demo_test_buffer = demo_test_buffer 83 | 84 | self.demo_paths = demo_paths 85 | 86 | self.bc_num_pretrain_steps = bc_num_pretrain_steps 87 | self.q_num_pretrain_steps = q_num_pretrain_steps 88 | self.demo_trajectory_rewards = [] 89 | 90 | self.env_info_key = env_info_key 91 | self.obs_key = obs_key 92 | self.recompute_reward = recompute_reward 93 | self.load_terminals = load_terminals 94 | 95 | self.trainer.replay_buffer = self.replay_buffer 96 | self.trainer.demo_train_buffer = self.demo_train_buffer 97 | self.trainer.demo_test_buffer = self.demo_test_buffer 98 | 99 | def load_demos(self, dataset): 100 | # Off policy 101 | load_hdf5(dataset, self.replay_buffer) 102 | 103 | def get_batch_from_buffer(self, replay_buffer): 104 | batch = replay_buffer.random_batch(self.bc_batch_size) 105 | batch = np_to_pytorch_batch(batch) 106 | return batch 107 | -------------------------------------------------------------------------------- /rlkit/demos/source/path_loader.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | import torch.nn.functional as F 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.core.eval_util import create_stats_ordered_dict 11 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 12 | 13 | from rlkit.util.io import load_local_or_remote_file 14 | 15 | import random 16 | from rlkit.torch.core import np_to_pytorch_batch 17 | from rlkit.data_management.path_builder import PathBuilder 18 | 19 | # import matplotlib 20 | # matplotlib.use('TkAgg') 21 | # import matplotlib.pyplot as plt 22 | 23 | from rlkit.core import logger 24 | 25 | import glob 26 | 27 | class PathLoader: 28 | """ 29 | Loads demonstrations and/or off-policy data into a Trainer 30 | """ 31 | 32 | def load_demos(self, ): 33 | pass 34 | -------------------------------------------------------------------------------- /rlkit/demos/spacemouse/README.md: -------------------------------------------------------------------------------- 1 | # Spacemouse Demonstrations 2 | 3 | This code allows an agent (including physical robots) to be controlled by a 7 degree-of-freedom 3Dconnexion Spacemouse device. 4 | 5 | ## Setup Instructions 6 | 7 | ### Spacemouse (on a Mac) 8 | 9 | 1. Clone robosuite ([https://github.com/anair13/robosuite](https://github.com/anair13/robosuite)) and add it to the python path 10 | 2. Run `pip install hidapi` and install Spacemouse drivers 11 | 2. 
Ensure you can run the following file (and 12 | see input values from the spacemouse): `robosuite/devices/spacemouse.py` 13 | 4. Follow the example in `railrl/demos/collect_demo.py` to collect demonstrations 14 | 15 | ### Server 16 | We haven't been able to install Spacemouse drivers for Linux but instead we use a Spacemouse on a Mac ("client") and send messages over a network to a Linux machine ("server"). 17 | 18 | #### Setup 19 | On the client, run the setup above. On the server, run: 20 | 1. Run `pip install Pyro4` 21 | 2. Make sure the hostname in `railrl/demos/spacemouse/config.py` is correct (in the example I use gauss1.banatao.berkeley.edu). This hostname needs to be visible (eg. you can ping it) from both the client and server 22 | 23 | #### Run 24 | 25 | 1. On the server, start the nameserver: 26 | ```export PYRO_SERIALIZERS_ACCEPTED=serpent,json,marshal,pickle 27 | python -m Pyro4.naming -n euler1.dyn.berkeley.edu 28 | ``` 29 | 2. On the server, run a script that uses the `SpaceMouseExpert` imported from `railrl/demos/spacemouse/input_server.py` such as ```python experiments/ashvin/iros2019/collect_demos_spacemouse.py``` 30 | 2. On the client, run ```python railrl/demos/spacemouse/input_client.py``` 31 | -------------------------------------------------------------------------------- /rlkit/demos/spacemouse/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/demos/spacemouse/__init__.py -------------------------------------------------------------------------------- /rlkit/demos/spacemouse/input_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | Should be run on a machine connected to a spacemouse 3 | """ 4 | 5 | from robosuite.devices import SpaceMouse 6 | import time 7 | import Pyro4 8 | from rlkit.launchers import config 9 | 10 | Pyro4.config.SERIALIZERS_ACCEPTED = set(['pickle','json', 'marshal', 'serpent']) 11 | Pyro4.config.SERIALIZER='pickle' 12 | 13 | nameserver = Pyro4.locateNS(host=config.SPACEMOUSE_HOSTNAME) 14 | uri = nameserver.lookup("example.greeting") 15 | device_state = Pyro4.Proxy(uri) 16 | device = SpaceMouse() 17 | while True: 18 | state = device.get_controller_state() 19 | print(state) 20 | time.sleep(0.1) 21 | device_state.set_state(state) 22 | -------------------------------------------------------------------------------- /rlkit/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/envs/__init__.py -------------------------------------------------------------------------------- /rlkit/envs/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.mujoco_env import MujocoEnv 4 | 5 | 6 | class AntEnv(MujocoEnv): 7 | def __init__(self, use_low_gear_ratio=True): 8 | self.init_serialization(locals()) 9 | if use_low_gear_ratio: 10 | xml_path = 'low_gear_ratio_ant.xml' 11 | else: 12 | xml_path = 'normal_gear_ratio_ant.xml' 13 | super().__init__( 14 | xml_path, 15 | frame_skip=5, 16 | automatically_set_obs_and_action_space=True, 17 | ) 18 | 19 | def step(self, a): 20 | torso_xyz_before = self.get_body_com("torso") 21 | self.do_simulation(a, self.frame_skip) 22 | torso_xyz_after = self.get_body_com("torso") 23 | torso_velocity = torso_xyz_after - 
torso_xyz_before 24 | forward_reward = torso_velocity[0]/self.dt 25 | ctrl_cost = .5 * np.square(a).sum() 26 | contact_cost = 0.5 * 1e-3 * np.sum( 27 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 28 | survive_reward = 1.0 29 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 30 | state = self.state_vector() 31 | notdone = np.isfinite(state).all() \ 32 | and state[2] >= 0.2 and state[2] <= 1.0 33 | done = not notdone 34 | ob = self._get_obs() 35 | return ob, reward, done, dict( 36 | reward_forward=forward_reward, 37 | reward_ctrl=-ctrl_cost, 38 | reward_contact=-contact_cost, 39 | reward_survive=survive_reward, 40 | torso_velocity=torso_velocity, 41 | ) 42 | 43 | def _get_obs(self): 44 | return np.concatenate([ 45 | self.sim.data.qpos.flat[2:], 46 | self.sim.data.qvel.flat, 47 | ]) 48 | 49 | def reset_model(self): 50 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 51 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 52 | self.set_state(qpos, qvel) 53 | return self._get_obs() 54 | 55 | def viewer_setup(self): 56 | self.viewer.cam.distance = self.model.stat.extent * 0.5 57 | -------------------------------------------------------------------------------- /rlkit/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.spaces import Box, Discrete, Tuple 4 | 5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 6 | 7 | 8 | def get_asset_full_path(file_name): 9 | return os.path.join(ENV_ASSET_DIR, file_name) 10 | 11 | 12 | def get_dim(space): 13 | if isinstance(space, Box): 14 | return space.low.size 15 | elif isinstance(space, Discrete): 16 | return space.n 17 | elif isinstance(space, Tuple): 18 | return sum(get_dim(subspace) for subspace in space.spaces) 19 | elif hasattr(space, 'flat_dim'): 20 | return space.flat_dim 21 | else: 22 | raise TypeError("Unknown space: {}".format(space)) 23 | 24 | 25 | def mode(env, mode_type): 26 | try: 27 | getattr(env, mode_type)() 28 | except AttributeError: 29 | pass 30 | -------------------------------------------------------------------------------- /rlkit/envs/make_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file provides a more uniform interface to gym.make(env_id) that handles 3 | imports and normalization 4 | """ 5 | 6 | import gym 7 | 8 | from rlkit.envs.wrappers import NormalizedBoxEnv 9 | 10 | DAPG_ENVS = [ 11 | 'pen-v0', 'pen-sparse-v0', 'pen-notermination-v0', 'pen-binary-v0', 'pen-binary-old-v0', 12 | 'door-v0', 'door-sparse-v0', 'door-binary-v0', 'door-binary-old-v0', 13 | 'relocate-v0', 'relocate-sparse-v0', 'relocate-binary-v0', 'relocate-binary-old-v0', 14 | 'hammer-v0', 'hammer-sparse-v0', 'hammer-binary-v0', 15 | ] 16 | 17 | D4RL_ENVS = [ 18 | "maze2d-open-v0", "maze2d-umaze-v0", "maze2d-medium-v0", "maze2d-large-v0", 19 | "maze2d-open-dense-v0", "maze2d-umaze-dense-v0", "maze2d-medium-dense-v0", "maze2d-large-dense-v0", 20 | "antmaze-umaze-v0", "antmaze-umaze-diverse-v0", "antmaze-medium-diverse-v0", 21 | "antmaze-medium-play-v0", "antmaze-large-diverse-v0", "antmaze-large-play-v0", 22 | "pen-human-v0", "pen-cloned-v0", "pen-expert-v0", "hammer-human-v0", "hammer-cloned-v0", "hammer-expert-v0", 23 | "door-human-v0", "door-cloned-v0", "door-expert-v0", "relocate-human-v0", "relocate-cloned-v0", "relocate-expert-v0", 24 | "halfcheetah-random-v0", "halfcheetah-medium-v0", "halfcheetah-expert-v0", 
"halfcheetah-mixed-v0", "halfcheetah-medium-expert-v0", 25 | "walker2d-random-v0", "walker2d-medium-v0", "walker2d-expert-v0", "walker2d-mixed-v0", "walker2d-medium-expert-v0", 26 | "hopper-random-v0", "hopper-medium-v0", "hopper-expert-v0", "hopper-mixed-v0", "hopper-medium-expert-v0" 27 | ] 28 | 29 | def make(env_id=None, env_class=None, env_kwargs=None, normalize_env=True): 30 | assert env_id or env_class 31 | if env_class: 32 | env = env_class(**env_kwargs) 33 | elif env_id in DAPG_ENVS: 34 | import mj_envs 35 | assert normalize_env == False 36 | env = gym.make(env_id) 37 | elif env_id in D4RL_ENVS: 38 | import d4rl 39 | assert normalize_env == False 40 | env = gym.make(env_id) 41 | elif env_id: 42 | env = gym.make(env_id) 43 | env = env.env # unwrap TimeLimit 44 | 45 | if normalize_env: 46 | env = NormalizedBoxEnv(env) 47 | 48 | return env 49 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from rlkit.core.serializable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | def __init__( 20 | self, 21 | model_path, 22 | frame_skip=1, 23 | model_path_is_local=True, 24 | automatically_set_obs_and_action_space=False, 25 | ): 26 | if model_path_is_local: 27 | model_path = get_asset_xml(model_path) 28 | if automatically_set_obs_and_action_space: 29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 30 | else: 31 | """ 32 | Code below is copy/pasted from MujocoEnv's __init__ function. 
33 | """ 34 | if model_path.startswith("/"): 35 | fullpath = model_path 36 | else: 37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 38 | if not path.exists(fullpath): 39 | raise IOError("File %s does not exist" % fullpath) 40 | self.frame_skip = frame_skip 41 | self.model = mujoco_py.MjModel(fullpath) 42 | self.data = self.model.data 43 | self.viewer = None 44 | 45 | self.metadata = { 46 | 'render.modes': ['human', 'rgb_array'], 47 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 48 | } 49 | 50 | self.init_qpos = self.model.data.qpos.ravel().copy() 51 | self.init_qvel = self.model.data.qvel.ravel().copy() 52 | self._seed() 53 | 54 | def init_serialization(self, locals): 55 | Serializable.quick_init(self, locals) 56 | 57 | def log_diagnostics(self, paths): 58 | pass 59 | 60 | 61 | def get_asset_xml(xml_name): 62 | return os.path.join(ENV_ASSET_DIR, xml_name) 63 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.ant_normal import AntNormal 2 | from rlkit.envs.pearl_envs.ant_dir import AntDirEnv 3 | from rlkit.envs.pearl_envs.ant_goal import AntGoalEnv 4 | from rlkit.envs.pearl_envs.half_cheetah_dir import HalfCheetahDirEnv 5 | from rlkit.envs.pearl_envs.half_cheetah_vel import HalfCheetahVelEnv 6 | from rlkit.envs.pearl_envs.hopper_rand_params_wrapper import \ 7 | HopperRandParamsWrappedEnv 8 | from rlkit.envs.pearl_envs.humanoid_dir import HumanoidDirEnv 9 | from rlkit.envs.pearl_envs.point_robot import PointEnv, SparsePointEnv 10 | from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import \ 11 | Walker2DRandParamsEnv 12 | from rlkit.envs.pearl_envs.walker_rand_params_wrapper import \ 13 | WalkerRandParamsWrappedEnv 14 | 15 | ENVS = {} 16 | 17 | 18 | def register_env(name): 19 | """Registers a env by name for instantiation in rlkit.""" 20 | 21 | def register_env_fn(fn): 22 | if name in ENVS: 23 | raise ValueError("Cannot register duplicate env {}".format(name)) 24 | if not callable(fn): 25 | raise TypeError("env {} must be callable".format(name)) 26 | ENVS[name] = fn 27 | return fn 28 | 29 | return register_env_fn 30 | 31 | 32 | def _register_env(name, fn): 33 | """Registers a env by name for instantiation in rlkit.""" 34 | if name in ENVS: 35 | raise ValueError("Cannot register duplicate env {}".format(name)) 36 | if not callable(fn): 37 | raise TypeError("env {} must be callable".format(name)) 38 | ENVS[name] = fn 39 | 40 | 41 | def register_pearl_envs(): 42 | _register_env('sparse-point-robot', SparsePointEnv) 43 | _register_env('ant-normal', AntNormal) 44 | _register_env('ant-dir', AntDirEnv) 45 | _register_env('ant-goal', AntGoalEnv) 46 | _register_env('cheetah-dir', HalfCheetahDirEnv) 47 | _register_env('cheetah-vel', HalfCheetahVelEnv) 48 | _register_env('humanoid-dir', HumanoidDirEnv) 49 | _register_env('point-robot', PointEnv) 50 | _register_env('walker-rand-params', WalkerRandParamsWrappedEnv) 51 | _register_env('hopper-rand-params', HopperRandParamsWrappedEnv) 52 | 53 | # automatically import any envs in the envs/ directory 54 | # for file in os.listdir(os.path.dirname(__file__)): 55 | # if file.endswith('.py') and not file.startswith('_'): 56 | # module = file[:file.find('.py')] 57 | # importlib.import_module('rlkit.envs.pearl_envs.' 
+ module) 58 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .mujoco_env import MujocoEnv 4 | 5 | 6 | class AntEnv(MujocoEnv): 7 | def __init__(self, use_low_gear_ratio=False): 8 | # self.init_serialization(locals()) 9 | if use_low_gear_ratio: 10 | xml_path = 'low_gear_ratio_ant.xml' 11 | else: 12 | xml_path = 'ant.xml' 13 | super().__init__( 14 | xml_path, 15 | frame_skip=5, 16 | automatically_set_obs_and_action_space=True, 17 | ) 18 | 19 | def step(self, a): 20 | torso_xyz_before = self.get_body_com("torso") 21 | self.do_simulation(a, self.frame_skip) 22 | torso_xyz_after = self.get_body_com("torso") 23 | torso_velocity = torso_xyz_after - torso_xyz_before 24 | forward_reward = torso_velocity[0]/self.dt 25 | ctrl_cost = 0. # .5 * np.square(a).sum() 26 | contact_cost = 0.5 * 1e-3 * np.sum( 27 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 28 | survive_reward = 0. # 1.0 29 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 30 | state = self.state_vector() 31 | notdone = np.isfinite(state).all() \ 32 | and state[2] >= 0.2 and state[2] <= 1.0 33 | done = not notdone 34 | ob = self._get_obs() 35 | return ob, reward, done, dict( 36 | reward_forward=forward_reward, 37 | reward_ctrl=-ctrl_cost, 38 | reward_contact=-contact_cost, 39 | reward_survive=survive_reward, 40 | torso_velocity=torso_velocity, 41 | ) 42 | 43 | def _get_obs(self): 44 | # this is gym ant obs, should use rllab? 45 | # if position is needed, override this in subclasses 46 | return np.concatenate([ 47 | self.sim.data.qpos.flat[2:], 48 | self.sim.data.qvel.flat, 49 | ]) 50 | 51 | def reset_model(self): 52 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 53 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 54 | self.set_state(qpos, qvel) 55 | return self._get_obs() 56 | 57 | def viewer_setup(self): 58 | try: 59 | from multiworld.envs.mujoco.cameras import create_camera_init 60 | self.camera_init = create_camera_init( 61 | lookat=(0, 0, 0), 62 | distance=10, 63 | elevation=-45, 64 | azimuth=90, 65 | trackbodyid=self.sim.model.body_name2id('torso'), 66 | ) 67 | self.camera_init(self.viewer.cam) 68 | except ImportError as e: 69 | pass 70 | # -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv 4 | 5 | 6 | class AntDirEnv(MultitaskAntEnv): 7 | 8 | def __init__( 9 | self, 10 | task=None, 11 | n_tasks=2, 12 | fixed_tasks=None, 13 | forward_backward=False, 14 | direction_in_degrees=False, 15 | **kwargs 16 | ): 17 | if task is None: 18 | task = {} 19 | self.fixed_tasks = fixed_tasks 20 | self.direction_in_degrees = direction_in_degrees 21 | self.quick_init(locals()) 22 | self.forward_backward = forward_backward 23 | super(AntDirEnv, self).__init__(task, n_tasks, **kwargs) 24 | 25 | def step(self, action): 26 | torso_xyz_before = np.array(self.get_body_com("torso")) 27 | 28 | if self.direction_in_degrees: 29 | goal = self._goal / 180 * np.pi 30 | else: 31 | goal = self._goal 32 | direct = (np.cos(goal), np.sin(goal)) 33 | 34 | self.do_simulation(action, self.frame_skip) 35 | torso_xyz_after = np.array(self.get_body_com("torso")) 36 | 
torso_velocity = torso_xyz_after - torso_xyz_before 37 | forward_reward = np.dot((torso_velocity[:2]/self.dt), direct) 38 | 39 | ctrl_cost = .5 * np.square(action).sum() 40 | contact_cost = 0.5 * 1e-3 * np.sum( 41 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 42 | survive_reward = 1.0 43 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 44 | state = self.state_vector() 45 | notdone = np.isfinite(state).all() \ 46 | and state[2] >= 0.2 and state[2] <= 1.0 47 | done = not notdone 48 | ob = self._get_obs() 49 | return ob, reward, done, dict( 50 | reward_forward=forward_reward, 51 | reward_ctrl=-ctrl_cost, 52 | reward_contact=-contact_cost, 53 | reward_survive=survive_reward, 54 | torso_velocity=torso_velocity, 55 | torso_xy=self.sim.data.qpos.flat[:2].copy(), 56 | ) 57 | 58 | def sample_tasks(self, num_tasks): 59 | if self.forward_backward: 60 | assert num_tasks == 2 61 | if self.direction_in_degrees: 62 | directions = np.array([0., 180]) 63 | else: 64 | directions = np.array([0., np.pi]) 65 | elif self.fixed_tasks: 66 | directions = np.array(self.fixed_tasks) 67 | else: 68 | if self.direction_in_degrees: 69 | directions = np.random.uniform(0., 360, size=(num_tasks,)) 70 | else: 71 | directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,)) 72 | tasks = [{'goal': desired_dir} for desired_dir in directions] 73 | return tasks 74 | 75 | def task_to_vec(self, task): 76 | direction = task['goal'] 77 | if self.direction_in_degrees: 78 | normalized_direction = direction / 360 79 | else: 80 | normalized_direction = direction / (2*np.pi) 81 | return np.array([normalized_direction]) 82 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv 4 | 5 | 6 | # Copy task structure from https://github.com/jonasrothfuss/ProMP/blob/master/meta_policy_search/envs/mujoco_envs/ant_rand_goal.py 7 | class AntGoalEnv(MultitaskAntEnv): 8 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True, **kwargs): 9 | self.quick_init(locals()) 10 | super(AntGoalEnv, self).__init__(task, n_tasks, **kwargs) 11 | 12 | def step(self, action): 13 | self.do_simulation(action, self.frame_skip) 14 | xposafter = np.array(self.get_body_com("torso")) 15 | 16 | goal_reward = -np.sum(np.abs(xposafter[:2] - self._goal)) # make it happy, not suicidal 17 | 18 | ctrl_cost = .1 * np.square(action).sum() 19 | contact_cost = 0.5 * 1e-3 * np.sum( 20 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 21 | survive_reward = 0.0 22 | reward = goal_reward - ctrl_cost - contact_cost + survive_reward 23 | state = self.state_vector() 24 | done = False 25 | ob = self._get_obs() 26 | return ob, reward, done, dict( 27 | goal_forward=goal_reward, 28 | reward_ctrl=-ctrl_cost, 29 | reward_contact=-contact_cost, 30 | reward_survive=survive_reward, 31 | ) 32 | 33 | def sample_tasks(self, num_tasks): 34 | a = np.random.random(num_tasks) * 2 * np.pi 35 | r = 3 * np.random.random(num_tasks) ** 0.5 36 | goals = np.stack((r * np.cos(a), r * np.sin(a)), axis=-1) 37 | tasks = [{'goal': goal} for goal in goals] 38 | return tasks 39 | 40 | def _get_obs(self): 41 | return np.concatenate([ 42 | self.sim.data.qpos.flat, 43 | self.sim.data.qvel.flat, 44 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 45 | ]) 46 | -------------------------------------------------------------------------------- 
/rlkit/envs/pearl_envs/ant_multitask_base.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.ant import AntEnv 2 | 3 | 4 | class MultitaskAntEnv(AntEnv): 5 | def __init__(self, task=None, n_tasks=2, 6 | randomize_tasks=True, 7 | **kwargs): 8 | if task is None: 9 | task = {} 10 | self._task = task 11 | self.tasks = self.sample_tasks(n_tasks) 12 | self._goal = self.tasks[0]['goal'] 13 | super(MultitaskAntEnv, self).__init__(**kwargs) 14 | 15 | """ 16 | def step(self, action): 17 | xposbefore = self.sim.data.qpos[0] 18 | self.do_simulation(action, self.frame_skip) 19 | xposafter = self.sim.data.qpos[0] 20 | 21 | forward_vel = (xposafter - xposbefore) / self.dt 22 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel) 23 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 24 | 25 | observation = self._get_obs() 26 | reward = forward_reward - ctrl_cost 27 | done = False 28 | infos = dict(reward_forward=forward_reward, 29 | reward_ctrl=-ctrl_cost, task=self._task) 30 | return (observation, reward, done, infos) 31 | """ 32 | 33 | 34 | def get_all_task_idx(self): 35 | return range(len(self.tasks)) 36 | 37 | def reset_task(self, idx): 38 | try: 39 | self._task = self.tasks[idx] 40 | except IndexError as e: 41 | import ipdb; ipdb.set_trace() 42 | self._goal = self._task['goal'] # assume parameterization of task by single vector 43 | self.reset() 44 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_normal.py: -------------------------------------------------------------------------------- 1 | from gym.envs.mujoco import AntEnv 2 | 3 | 4 | class AntNormal(AntEnv): 5 | def __init__( 6 | self, 7 | *args, 8 | n_tasks=2, # number of distinct tasks in this domain, shoudl equal sum of train and eval tasks 9 | randomize_tasks=True, # shuffle the tasks after creating them 10 | **kwargs 11 | ): 12 | self.tasks = [0 for _ in range(n_tasks)] 13 | self._goal = 0 14 | super().__init__(*args, **kwargs) 15 | 16 | def get_all_task_idx(self): 17 | return self.tasks 18 | 19 | def reset_task(self, idx): 20 | # not tasks. just give the same reward every time step. 
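# (reset_task is intentionally a no-op here: AntNormal keeps a list of
# identical placeholder tasks, so every task index corresponds to the
# same standard Ant reward.)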
21 | pass 22 | 23 | def sample_tasks(self, num_tasks): 24 | return [0 for _ in range(num_tasks)] 25 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/half_cheetah.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_ 3 | 4 | class HalfCheetahEnv(HalfCheetahEnv_): 5 | def _get_obs(self): 6 | return np.concatenate([ 7 | self.sim.data.qpos.flat[1:], 8 | self.sim.data.qvel.flat, 9 | self.get_body_com("torso").flat, 10 | ]).astype(np.float32).flatten() 11 | 12 | def viewer_setup(self): 13 | camera_id = self.model.camera_name2id('track') 14 | self.viewer.cam.type = 2 15 | self.viewer.cam.fixedcamid = camera_id 16 | self.viewer.cam.distance = self.model.stat.extent * 0.35 17 | # Hide the overlay 18 | self.viewer._hide_overlay = True 19 | 20 | def render(self, mode='human', width=500, height=500, **kwargs): 21 | if mode == 'rgb_array': 22 | self._get_viewer(mode).render(width=width, height=height) 23 | data = self._get_viewer(mode).read_pixels(width, height, depth=False)[::-1, :, :] 24 | return data 25 | elif mode == 'human': 26 | self._get_viewer(mode).render() 27 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/half_cheetah_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .half_cheetah import HalfCheetahEnv 4 | 5 | 6 | class HalfCheetahDirEnv(HalfCheetahEnv): 7 | """Half-cheetah environment with target direction, as described in [1]. The 8 | code is adapted from 9 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand_direc.py 10 | 11 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each 12 | time step a reward composed of a control cost and a reward equal to its 13 | velocity in the target direction. The tasks are generated by sampling the 14 | target directions from a Bernoulli distribution on {-1, 1} with parameter 15 | 0.5 (-1: backward, +1: forward). 
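Typical usage (a sketch based on the methods defined below; task index
0 is the backward task {'direction': -1} and index 1 is forward):

    env = HalfCheetahDirEnv()
    env.reset_task(1)  # select the forward task
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())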
16 | 17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 19 | (https://arxiv.org/abs/1703.03400) 20 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 21 | model-based control", 2012 22 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 23 | """ 24 | def __init__(self, task={}, n_tasks=2, randomize_tasks=False): 25 | directions = [-1, 1] 26 | self.tasks = [{'direction': direction} for direction in directions] 27 | self._task = task 28 | self._goal_dir = task.get('direction', 1) 29 | self._goal = self._goal_dir 30 | super(HalfCheetahDirEnv, self).__init__() 31 | 32 | def step(self, action): 33 | xposbefore = self.sim.data.qpos[0] 34 | self.do_simulation(action, self.frame_skip) 35 | xposafter = self.sim.data.qpos[0] 36 | 37 | forward_vel = (xposafter - xposbefore) / self.dt 38 | forward_reward = self._goal_dir * forward_vel 39 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 40 | 41 | observation = self._get_obs() 42 | reward = forward_reward - ctrl_cost 43 | done = False 44 | infos = dict(reward_forward=forward_reward, 45 | reward_ctrl=-ctrl_cost, task=self._task) 46 | return (observation, reward, done, infos) 47 | 48 | def sample_tasks(self, num_tasks): 49 | directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1 50 | tasks = [{'direction': direction} for direction in directions] 51 | return tasks 52 | 53 | def get_all_task_idx(self): 54 | return list(range(len(self.tasks))) 55 | 56 | def reset_task(self, idx): 57 | self._task = self.tasks[idx] 58 | self._goal_dir = self._task['direction'] 59 | self._goal = self._goal_dir 60 | self.reset() 61 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/half_cheetah_vel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .half_cheetah import HalfCheetahEnv 4 | 5 | 6 | class HalfCheetahVelEnv(HalfCheetahEnv): 7 | """Half-cheetah environment with target velocity, as described in [1]. The 8 | code is adapted from 9 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand.py 10 | 11 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each 12 | time step a reward composed of a control cost and a penalty equal to the 13 | difference between its current velocity and the target velocity. The tasks 14 | are generated by sampling the target velocities from the uniform 15 | distribution on [0, 2]. 
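(Note: in this implementation, sample_tasks below seeds numpy with 1337
and draws the target velocities uniformly from [0.0, 3.0), not [0, 2].)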
16 | 17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 19 | (https://arxiv.org/abs/1703.03400) 20 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 21 | model-based control", 2012 22 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 23 | """ 24 | def __init__(self, task={}, presampled_tasks=None, n_tasks=2, randomize_tasks=True): 25 | self._task = task 26 | self.tasks = presampled_tasks or self.sample_tasks(n_tasks) 27 | self._goal_vel = self.tasks[0].get('velocity', 0.0) 28 | self._goal = self._goal_vel 29 | super(HalfCheetahVelEnv, self).__init__() 30 | 31 | def step(self, action): 32 | xposbefore = self.sim.data.qpos[0] 33 | self.do_simulation(action, self.frame_skip) 34 | xposafter = self.sim.data.qpos[0] 35 | 36 | forward_vel = (xposafter - xposbefore) / self.dt 37 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel) 38 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 39 | 40 | observation = self._get_obs() 41 | reward = forward_reward - ctrl_cost 42 | done = False 43 | infos = dict( 44 | reward_forward=forward_reward, 45 | reward_ctrl=-ctrl_cost, 46 | goal_vel=self._goal_vel, 47 | forward_vel=forward_vel, 48 | xposbefore=xposbefore, 49 | ) 50 | return (observation, reward, done, infos) 51 | 52 | def sample_tasks(self, num_tasks): 53 | np.random.seed(1337) 54 | velocities = np.random.uniform(0.0, 3.0, size=(num_tasks,)) 55 | tasks = [{'velocity': velocity} for velocity in velocities] 56 | return tasks 57 | 58 | def get_all_task_idx(self): 59 | return range(len(self.tasks)) 60 | 61 | def reset_task(self, idx): 62 | self._task = self.tasks[idx] 63 | self._goal_vel = self._task['velocity'] 64 | self._goal = self._goal_vel 65 | self.reset() 66 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.rand_param_envs.hopper_rand_params import HopperRandParamsEnv 2 | 3 | 4 | class HopperRandParamsWrappedEnv(HopperRandParamsEnv): 5 | def __init__(self, n_tasks=2, randomize_tasks=True): 6 | super(HopperRandParamsWrappedEnv, self).__init__() 7 | self.tasks = self.sample_tasks(n_tasks) 8 | self.reset_task(0) 9 | 10 | def get_all_task_idx(self): 11 | return range(len(self.tasks)) 12 | 13 | def reset_task(self, idx): 14 | self._task = self.tasks[idx] 15 | self._goal = idx 16 | self.set_task(self._task) 17 | self.reset() 18 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/humanoid_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import HumanoidEnv as HumanoidEnv 3 | 4 | 5 | def mass_center(model, sim): 6 | mass = np.expand_dims(model.body_mass, 1) 7 | xpos = sim.data.xipos 8 | return (np.sum(mass * xpos, 0) / np.sum(mass)) 9 | 10 | 11 | class HumanoidDirEnv(HumanoidEnv): 12 | 13 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True): 14 | self.tasks = self.sample_tasks(n_tasks) 15 | self.reset_task(0) 16 | super(HumanoidDirEnv, self).__init__() 17 | 18 | def step(self, action): 19 | pos_before = np.copy(mass_center(self.model, self.sim)[:2]) 20 | self.do_simulation(action, self.frame_skip) 21 | pos_after = mass_center(self.model, self.sim)[:2] 22 | 23 | alive_bonus = 5.0 24 | data = self.sim.data 25 | goal_direction = 
(np.cos(self._goal), np.sin(self._goal)) 26 | lin_vel_cost = 0.25 * np.sum(goal_direction * (pos_after - pos_before)) / self.model.opt.timestep 27 | quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum() 28 | quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum() 29 | quad_impact_cost = min(quad_impact_cost, 10) 30 | reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus 31 | qpos = self.sim.data.qpos 32 | done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) 33 | 34 | return self._get_obs(), reward, done, dict(reward_linvel=lin_vel_cost, 35 | reward_quadctrl=-quad_ctrl_cost, 36 | reward_alive=alive_bonus, 37 | reward_impact=-quad_impact_cost) 38 | 39 | def _get_obs(self): 40 | data = self.sim.data 41 | return np.concatenate([data.qpos.flat[2:], 42 | data.qvel.flat, 43 | data.cinert.flat, 44 | data.cvel.flat, 45 | data.qfrc_actuator.flat, 46 | data.cfrc_ext.flat]) 47 | 48 | def get_all_task_idx(self): 49 | return range(len(self.tasks)) 50 | 51 | def reset_task(self, idx): 52 | self._task = self.tasks[idx] 53 | self._goal = self._task['goal'] # assume parameterization of task by single vector 54 | 55 | def sample_tasks(self, num_tasks): 56 | # velocities = np.random.uniform(0., 1.0 * np.pi, size=(num_tasks,)) 57 | directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,)) 58 | tasks = [{'goal': d} for d in directions] 59 | return tasks 60 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from rlkit.core.serializable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | def __init__( 20 | self, 21 | model_path, 22 | frame_skip=1, 23 | model_path_is_local=True, 24 | automatically_set_obs_and_action_space=False, 25 | ): 26 | if model_path_is_local: 27 | model_path = get_asset_xml(model_path) 28 | if automatically_set_obs_and_action_space: 29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 30 | else: 31 | """ 32 | Code below is copy/pasted from MujocoEnv's __init__ function. 
33 | """ 34 | if model_path.startswith("/"): 35 | fullpath = model_path 36 | else: 37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 38 | if not path.exists(fullpath): 39 | raise IOError("File %s does not exist" % fullpath) 40 | self.frame_skip = frame_skip 41 | self.model = mujoco_py.MjModel(fullpath) 42 | self.data = self.model.data 43 | self.viewer = None 44 | 45 | self.metadata = { 46 | 'render.modes': ['human', 'rgb_array'], 47 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 48 | } 49 | 50 | self.init_qpos = self.model.data.qpos.ravel().copy() 51 | self.init_qvel = self.model.data.qvel.ravel().copy() 52 | self._seed() 53 | 54 | def init_serialization(self, locals): 55 | Serializable.quick_init(self, locals) 56 | 57 | def log_diagnostics(self, *args, **kwargs): 58 | pass 59 | 60 | 61 | def get_asset_xml(xml_name): 62 | return os.path.join(ENV_ASSET_DIR, xml_name) 63 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/envs/pearl_envs/rand_param_envs/__init__.py -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv 4 | 5 | 6 | class HopperRandParamsEnv(RandomEnv, utils.EzPickle): 7 | def __init__(self, log_scale_limit=3.0): 8 | RandomEnv.__init__(self, log_scale_limit, 'hopper.xml', 4) 9 | utils.EzPickle.__init__(self) 10 | 11 | def _step(self, a): 12 | posbefore = self.sim.data.qpos[0] 13 | self.do_simulation(a, self.frame_skip) 14 | posafter, height, ang = self.sim.data.qpos[0:3] 15 | alive_bonus = 1.0 16 | reward = (posafter - posbefore) / self.dt 17 | reward += alive_bonus 18 | reward -= 1e-3 * np.square(a).sum() 19 | s = self.state_vector() 20 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and 21 | (height > .7) and (abs(ang) < .2)) 22 | ob = self._get_obs() 23 | return ob, reward, done, {} 24 | 25 | def _get_obs(self): 26 | return np.concatenate([ 27 | self.sim.data.qpos.flat[1:], 28 | np.clip(self.sim.data.qvel.flat, -10, 10) 29 | ]) 30 | 31 | def reset_model(self): 32 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq) 33 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 34 | self.set_state(qpos, qvel) 35 | return self._get_obs() 36 | 37 | def viewer_setup(self): 38 | self.viewer.cam.trackbodyid = 2 39 | self.viewer.cam.distance = self.model.stat.extent * 0.75 40 | self.viewer.cam.lookat[2] += .8 41 | self.viewer.cam.elevation = -20 42 | 43 | if __name__ == "__main__": 44 | 45 | env = HopperRandParamsEnv() 46 | tasks = env.sample_tasks(40) 47 | while True: 48 | env.reset() 49 | env.set_task(np.random.choice(tasks)) 50 | print(env.model.body_mass) 51 | for _ in range(100): 52 | env.render() 53 | env.step(env.action_space.sample()) # take a random action 54 | 55 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 
| from gym import utils 3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv 4 | import os 5 | 6 | class PR2Env(RandomEnv, utils.EzPickle): 7 | 8 | FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets/pr2.xml') 9 | 10 | def __init__(self, log_scale_limit=1.): 11 | self.viewer = None 12 | RandomEnv.__init__(self, log_scale_limit, 'pr2.xml', 4) 13 | utils.EzPickle.__init__(self) 14 | 15 | def _get_obs(self): 16 | return np.concatenate([ 17 | self.model.data.qpos.flat[:7], 18 | self.model.data.qvel.flat[:7], # Do not include the velocity of the target (should be 0). 19 | self.get_tip_position().flat, 20 | self.get_vec_tip_to_goal().flat, 21 | ]) 22 | 23 | def get_tip_position(self): 24 | return self.model.data.site_xpos[0] 25 | 26 | def get_vec_tip_to_goal(self): 27 | tip_position = self.get_tip_position() 28 | goal_position = self.goal 29 | vec_tip_to_goal = goal_position - tip_position 30 | return vec_tip_to_goal 31 | 32 | @property 33 | def goal(self): 34 | return self.model.data.qpos.flat[-3:] 35 | 36 | def _step(self, action): 37 | 38 | self.do_simulation(action, self.frame_skip) 39 | 40 | vec_tip_to_goal = self.get_vec_tip_to_goal() 41 | distance_tip_to_goal = np.linalg.norm(vec_tip_to_goal) 42 | 43 | reward = - distance_tip_to_goal 44 | 45 | state = self.state_vector() 46 | notdone = np.isfinite(state).all() 47 | done = not notdone 48 | 49 | ob = self._get_obs() 50 | 51 | return ob, reward, done, {} 52 | 53 | def reset_model(self): 54 | qpos = self.init_qpos 55 | qvel = self.init_qvel 56 | goal = np.random.uniform((0.2, -0.4, 0.5), (0.5, 0.4, 1.5)) 57 | qpos[-3:] = goal 58 | qpos[:7] += self.np_random.uniform(low=-.005, high=.005, size=7) 59 | qvel[:7] += self.np_random.uniform(low=-.005, high=.005, size=7) 60 | self.set_state(qpos, qvel) 61 | return self._get_obs() 62 | 63 | def viewer_setup(self): 64 | self.viewer.cam.distance = self.model.stat.extent * 2 65 | # self.viewer.cam.lookat[2] += .8 66 | self.viewer.cam.elevation = -50 67 | # self.viewer.cam.lookat[0] = self.model.stat.center[0] 68 | # self.viewer.cam.lookat[1] = self.model.stat.center[1] 69 | # self.viewer.cam.lookat[2] = self.model.stat.center[2] 70 | 71 | 72 | if __name__ == "__main__": 73 | 74 | env = PR2Env() 75 | tasks = env.sample_tasks(40) 76 | while True: 77 | env.reset() 78 | env.set_task(np.random.choice(tasks)) 79 | print(env.model.body_mass) 80 | for _ in range(100): 81 | env.render() 82 | env.step(env.action_space.sample()) 83 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv 4 | 5 | 6 | class Walker2DRandParamsEnv(RandomEnv, utils.EzPickle): 7 | def __init__(self, log_scale_limit=3.0): 8 | RandomEnv.__init__(self, log_scale_limit, 'walker2d.xml', 5) 9 | utils.EzPickle.__init__(self) 10 | 11 | def _step(self, a): 12 | # import ipdb; ipdb.set_trace() 13 | # posbefore = self.model.data.qpos[0, 0] 14 | posbefore = self.sim.data.qpos[0] 15 | self.do_simulation(a, self.frame_skip) 16 | # posafter, height, ang = self.model.data.qpos[0:3, 0] 17 | posafter, height, ang = self.sim.data.qpos[0:3] 18 | alive_bonus = 1.0 19 | reward = ((posafter - posbefore) / self.dt) 20 | reward += alive_bonus 21 | reward -= 1e-3 * np.square(a).sum() 22 | done = not (height > 0.8 and height < 2.0 and 23 | ang 
> -1.0 and ang < 1.0) 24 | ob = self._get_obs() 25 | return ob, reward, done, {} 26 | 27 | def _get_obs(self): 28 | # qpos = self.model.data.qpos 29 | # qvel = self.model.data.qvel 30 | qpos = self.sim.data.qpos 31 | qvel = self.sim.data.qvel 32 | return np.concatenate([qpos[1:], np.clip(qvel, -10, 10)]).ravel() 33 | 34 | def reset_model(self): 35 | self.set_state( 36 | self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq), 37 | self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 38 | ) 39 | return self._get_obs() 40 | 41 | def viewer_setup(self): 42 | self.viewer.cam.trackbodyid = 2 43 | self.viewer.cam.distance = self.model.stat.extent * 0.5 44 | self.viewer.cam.lookat[2] += .8 45 | self.viewer.cam.elevation = -20 46 | 47 | if __name__ == "__main__": 48 | 49 | env = Walker2DRandParamsEnv() 50 | tasks = env.sample_tasks(40) 51 | while True: 52 | env.reset() 53 | env.set_task(np.random.choice(tasks)) 54 | print(env.model.body_mass) 55 | for _ in range(100): 56 | env.render() 57 | env.step(env.action_space.sample()) # take a random action 58 | 59 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/walker_rand_params_wrapper.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import Walker2DRandParamsEnv 2 | 3 | 4 | class WalkerRandParamsWrappedEnv(Walker2DRandParamsEnv): 5 | def __init__(self, n_tasks=2, randomize_tasks=True): 6 | super(WalkerRandParamsWrappedEnv, self).__init__() 7 | self.tasks = self.sample_tasks(n_tasks) 8 | self.reset_task(0) 9 | 10 | def get_all_task_idx(self): 11 | return range(len(self.tasks)) 12 | 13 | def reset_task(self, idx): 14 | self._task = self.tasks[idx] 15 | self._goal = idx 16 | self.set_task(self._task) 17 | self.reset() 18 | -------------------------------------------------------------------------------- /rlkit/envs/proxy_env.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | 3 | 4 | class ProxyEnv(Env): 5 | def __init__(self, wrapped_env): 6 | self._wrapped_env = wrapped_env 7 | self.action_space = self._wrapped_env.action_space 8 | self.observation_space = self._wrapped_env.observation_space 9 | 10 | @property 11 | def wrapped_env(self): 12 | return self._wrapped_env 13 | 14 | def reset(self, **kwargs): 15 | return self._wrapped_env.reset(**kwargs) 16 | 17 | def step(self, action): 18 | return self._wrapped_env.step(action) 19 | 20 | def render(self, *args, **kwargs): 21 | return self._wrapped_env.render(*args, **kwargs) 22 | 23 | @property 24 | def horizon(self): 25 | return self._wrapped_env.horizon 26 | 27 | def terminate(self): 28 | if hasattr(self.wrapped_env, "terminate"): 29 | self.wrapped_env.terminate() 30 | 31 | def __getattr__(self, attr): 32 | if attr == '_wrapped_env': 33 | raise AttributeError() 34 | return getattr(self._wrapped_env, attr) 35 | 36 | def __getstate__(self): 37 | """ 38 | This is useful to override in case the wrapped env has some funky 39 | __getstate__ that doesn't play well with overriding __getattr__. 40 | 41 | The main problematic case is/was gym's EzPickle serialization scheme. 
42 | :return: 43 | """ 44 | return self.__dict__ 45 | 46 | def __setstate__(self, state): 47 | self.__dict__.update(state) 48 | 49 | def __str__(self): 50 | return '{}({})'.format(type(self).__name__, self.wrapped_env) -------------------------------------------------------------------------------- /rlkit/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.wrappers.discretize_env import DiscretizeEnv 2 | from rlkit.envs.wrappers.history_env import HistoryEnv 3 | from rlkit.envs.wrappers.image_mujoco_env import ImageMujocoEnv 4 | from rlkit.envs.wrappers.image_mujoco_env_with_obs import ImageMujocoWithObsEnv 5 | from rlkit.envs.wrappers.normalized_box_env import NormalizedBoxEnv 6 | from rlkit.envs.proxy_env import ProxyEnv 7 | from rlkit.envs.wrappers.reward_wrapper_env import RewardWrapperEnv 8 | from rlkit.envs.wrappers.stack_observation_env import StackObservationEnv 9 | 10 | 11 | __all__ = [ 12 | 'DiscretizeEnv', 13 | 'HistoryEnv', 14 | 'ImageMujocoEnv', 15 | 'ImageMujocoWithObsEnv', 16 | 'NormalizedBoxEnv', 17 | 'ProxyEnv', 18 | 'RewardWrapperEnv', 19 | 'StackObservationEnv', 20 | ] -------------------------------------------------------------------------------- /rlkit/envs/wrappers/discretize_env.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from gym import Env 5 | from gym.spaces import Discrete 6 | 7 | from rlkit.envs.proxy_env import ProxyEnv 8 | 9 | 10 | class DiscretizeEnv(ProxyEnv, Env): 11 | def __init__(self, wrapped_env, num_bins): 12 | super().__init__(wrapped_env) 13 | low = self.wrapped_env.action_space.low 14 | high = self.wrapped_env.action_space.high 15 | action_ranges = [ 16 | np.linspace(low[i], high[i], num_bins) 17 | for i in range(len(low)) 18 | ] 19 | self.idx_to_continuous_action = [ 20 | np.array(x) for x in itertools.product(*action_ranges) 21 | ] 22 | self.action_space = Discrete(len(self.idx_to_continuous_action)) 23 | 24 | def step(self, action): 25 | continuous_action = self.idx_to_continuous_action[action] 26 | return super().step(continuous_action) 27 | 28 | 29 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/history_env.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | from gym import Env 5 | from gym.spaces import Box 6 | 7 | from rlkit.envs.proxy_env import ProxyEnv 8 | 9 | 10 | class HistoryEnv(ProxyEnv, Env): 11 | def __init__(self, wrapped_env, history_len): 12 | super().__init__(wrapped_env) 13 | self.history_len = history_len 14 | 15 | high = np.inf * np.ones( 16 | self.history_len * self.observation_space.low.size) 17 | low = -high 18 | self.observation_space = Box(low=low, 19 | high=high, 20 | ) 21 | self.history = deque(maxlen=self.history_len) 22 | 23 | def step(self, action): 24 | state, reward, done, info = super().step(action) 25 | self.history.append(state) 26 | flattened_history = self._get_history().flatten() 27 | return flattened_history, reward, done, info 28 | 29 | def reset(self, **kwargs): 30 | state = super().reset() 31 | self.history = deque(maxlen=self.history_len) 32 | self.history.append(state) 33 | flattened_history = self._get_history().flatten() 34 | return flattened_history 35 | 36 | def _get_history(self): 37 | observations = list(self.history) 38 | 39 | obs_count = len(observations) 40 | for _ in 
range(self.history_len - obs_count): 41 | dummy = np.zeros(self._wrapped_env.observation_space.low.size) 42 | observations.append(dummy) 43 | return np.c_[observations] 44 | 45 | 46 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/image_mujoco_env_with_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.wrappers.image_mujoco_env import ImageMujocoEnv 5 | 6 | 7 | class ImageMujocoWithObsEnv(ImageMujocoEnv): 8 | def __init__(self, env, **kwargs): 9 | super().__init__(env, **kwargs) 10 | self.observation_space = Box( 11 | low=0.0, 12 | high=1.0, 13 | shape=(self.image_length * self.history_length 14 | + self.wrapped_env.obs_dim,)) 15 | 16 | def _get_obs(self, history_flat, true_state): 17 | return np.concatenate([history_flat, true_state]) 18 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/normalized_box_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class NormalizedBoxEnv(ProxyEnv): 8 | """ 9 | Normalize action to in [-1, 1]. 10 | 11 | Optionally normalize observations and scale reward. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | env, 17 | reward_scale=1., 18 | obs_mean=None, 19 | obs_std=None, 20 | ): 21 | ProxyEnv.__init__(self, env) 22 | self._should_normalize = not (obs_mean is None and obs_std is None) 23 | if self._should_normalize: 24 | if obs_mean is None: 25 | obs_mean = np.zeros_like(env.observation_space.low) 26 | else: 27 | obs_mean = np.array(obs_mean) 28 | if obs_std is None: 29 | obs_std = np.ones_like(env.observation_space.low) 30 | else: 31 | obs_std = np.array(obs_std) 32 | self._reward_scale = reward_scale 33 | self._obs_mean = obs_mean 34 | self._obs_std = obs_std 35 | ub = np.ones(self._wrapped_env.action_space.shape) 36 | self.action_space = Box(-1 * ub, ub) 37 | 38 | def estimate_obs_stats(self, obs_batch, override_values=False): 39 | if self._obs_mean is not None and not override_values: 40 | raise Exception("Observation mean and std already set. To " 41 | "override, set override_values to True.") 42 | self._obs_mean = np.mean(obs_batch, axis=0) 43 | self._obs_std = np.std(obs_batch, axis=0) 44 | 45 | def _apply_normalize_obs(self, obs): 46 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 47 | 48 | def step(self, action): 49 | lb = self._wrapped_env.action_space.low 50 | ub = self._wrapped_env.action_space.high 51 | scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) 52 | scaled_action = np.clip(scaled_action, lb, ub) 53 | 54 | wrapped_step = self._wrapped_env.step(scaled_action) 55 | next_obs, reward, done, info = wrapped_step 56 | if self._should_normalize: 57 | next_obs = self._apply_normalize_obs(next_obs) 58 | return next_obs, reward * self._reward_scale, done, info 59 | 60 | def __str__(self): 61 | return "Normalized: %s" % self._wrapped_env 62 | 63 | 64 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/reward_wrapper_env.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.proxy_env import ProxyEnv 2 | 3 | 4 | class RewardWrapperEnv(ProxyEnv): 5 | """Substitute a different reward function""" 6 | 7 | def __init__( 8 | self, 9 | env, 10 | compute_reward_fn, 11 | ): 12 | ProxyEnv.__init__(self, env) 13 | self.spec = env.spec # hack for hand envs 14 | self.compute_reward_fn = compute_reward_fn 15 | 16 | def step(self, action): 17 | next_obs, reward, done, info = self._wrapped_env.step(action) 18 | info["env_reward"] = reward 19 | reward = self.compute_reward_fn(next_obs, reward, done, info) 20 | return next_obs, reward, done, info 21 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/stack_observation_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class StackObservationEnv(ProxyEnv): 8 | """ 9 | Env wrapper for passing history of observations as the new observation 10 | """ 11 | 12 | def __init__( 13 | self, 14 | env, 15 | stack_obs=1, 16 | ): 17 | ProxyEnv.__init__(self, env) 18 | self.stack_obs = stack_obs 19 | low = env.observation_space.low 20 | high = env.observation_space.high 21 | self.obs_dim = low.size 22 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim)) 23 | self.observation_space = Box( 24 | low=np.repeat(low, stack_obs), 25 | high=np.repeat(high, stack_obs), 26 | ) 27 | 28 | def reset(self): 29 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim)) 30 | next_obs = self._wrapped_env.reset() 31 | self._last_obs[-1, :] = next_obs 32 | return self._last_obs.copy().flatten() 33 | 34 | def step(self, action): 35 | next_obs, reward, done, info = self._wrapped_env.step(action) 36 | self._last_obs = np.vstack(( 37 | self._last_obs[1:, :], 38 | next_obs 39 | )) 40 | return self._last_obs.copy().flatten(), reward, done, info 41 | 42 | 43 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/exploration_strategies/__init__.py -------------------------------------------------------------------------------- /rlkit/exploration_strategies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rlkit.policies.base import ExplorationPolicy 4 | 5 | 6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta): 7 | @abc.abstractmethod 8 | def get_action(self, t, observation, policy, **kwargs): 9 | pass 10 | 11 | def reset(self): 12 | pass 13 | 14 | 15 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta): 16 | @abc.abstractmethod 17 | def get_action_from_raw_action(self, action, 
**kwargs): 18 | pass 19 | 20 | def get_action(self, t, policy, *args, **kwargs): 21 | action, agent_info = policy.get_action(*args, **kwargs) 22 | return self.get_action_from_raw_action(action, t=t), agent_info 23 | 24 | def reset(self): 25 | pass 26 | 27 | 28 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy): 29 | def __init__( 30 | self, 31 | exploration_strategy: ExplorationStrategy, 32 | policy, 33 | ): 34 | self.es = exploration_strategy 35 | self.policy = policy 36 | self.t = 0 37 | 38 | def set_num_steps_total(self, t): 39 | self.t = t 40 | 41 | def get_action(self, *args, **kwargs): 42 | return self.es.get_action(self.t, self.policy, *args, **kwargs) 43 | 44 | def reset(self): 45 | self.es.reset() 46 | self.policy.reset() 47 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from rlkit.exploration_strategies.base import RawExplorationStrategy 4 | 5 | 6 | class EpsilonGreedy(RawExplorationStrategy): 7 | """ 8 | Take a random discrete action with some probability. 9 | """ 10 | def __init__(self, action_space, prob_random_action=0.1): 11 | self.prob_random_action = prob_random_action 12 | self.action_space = action_space 13 | 14 | def get_action_from_raw_action(self, action, **kwargs): 15 | if random.random() <= self.prob_random_action: 16 | return self.action_space.sample() 17 | return action 18 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_and_epsilon_strategy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from rlkit.exploration_strategies.base import RawExplorationStrategy 3 | import numpy as np 4 | 5 | 6 | class GaussianAndEpsilonStrategy(RawExplorationStrategy): 7 | """ 8 | With probability epsilon, take a completely random action. 9 | with probability 1-epsilon, add Gaussian noise to the action taken by a 10 | deterministic policy. 11 | """ 12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None, 13 | decay_period=1000000): 14 | assert len(action_space.shape) == 1 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._max_sigma = max_sigma 18 | self._epsilon = epsilon 19 | self._min_sigma = min_sigma 20 | self._decay_period = decay_period 21 | self._action_space = action_space 22 | 23 | def get_action_from_raw_action(self, action, t=None, **kwargs): 24 | if random.random() < self._epsilon: 25 | return self._action_space.sample() 26 | else: 27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period) 28 | return np.clip( 29 | action + np.random.normal(size=len(action)) * sigma, 30 | self._action_space.low, 31 | self._action_space.high, 32 | ) 33 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_strategy.py: -------------------------------------------------------------------------------- 1 | from rlkit.exploration_strategies.base import RawExplorationStrategy 2 | import numpy as np 3 | 4 | 5 | class GaussianStrategy(RawExplorationStrategy): 6 | """ 7 | This strategy adds Gaussian noise to the action taken by the deterministic policy. 8 | 9 | Based on the rllab implementation. 
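The noise scale used below is annealed linearly with the step counter t that the wrapping policy passes in:

    sigma(t) = max_sigma - (max_sigma - min_sigma) * min(1, t / decay_period)

so the noise starts at max_sigma and reaches min_sigma after decay_period steps; when min_sigma is left as None it defaults to max_sigma and no decay happens.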
10 | """ 11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None, 12 | decay_period=1000000): 13 | assert len(action_space.shape) == 1 14 | self._max_sigma = max_sigma 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._min_sigma = min_sigma 18 | self._decay_period = decay_period 19 | self._action_space = action_space 20 | 21 | def get_action_from_raw_action(self, action, t=None, **kwargs): 22 | sigma = ( 23 | self._max_sigma - (self._max_sigma - self._min_sigma) * 24 | min(1.0, t * 1.0 / self._decay_period) 25 | ) 26 | return np.clip( 27 | action + np.random.normal(size=len(action)) * sigma, 28 | self._action_space.low, 29 | self._action_space.high, 30 | ) 31 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/ou_strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as nr 3 | 4 | from rlkit.exploration_strategies.base import RawExplorationStrategy 5 | 6 | 7 | class OUStrategy(RawExplorationStrategy): 8 | """ 9 | This strategy implements the Ornstein-Uhlenbeck process, which adds 10 | time-correlated noise to the actions taken by the deterministic policy. 11 | The OU process satisfies the following stochastic differential equation: 12 | dxt = theta*(mu - xt)*dt + sigma*dWt 13 | where Wt denotes the Wiener process 14 | 15 | Based on the rllab implementation. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | action_space, 21 | mu=0, 22 | theta=0.15, 23 | max_sigma=0.3, 24 | min_sigma=None, 25 | decay_period=100000, 26 | ): 27 | if min_sigma is None: 28 | min_sigma = max_sigma 29 | self.mu = mu 30 | self.theta = theta 31 | self.sigma = max_sigma 32 | self._max_sigma = max_sigma 33 | if min_sigma is None: 34 | min_sigma = max_sigma 35 | self._min_sigma = min_sigma 36 | self._decay_period = decay_period 37 | self.dim = np.prod(action_space.low.shape) 38 | self.low = action_space.low 39 | self.high = action_space.high 40 | self.state = np.ones(self.dim) * self.mu 41 | self.reset() 42 | 43 | def reset(self): 44 | self.state = np.ones(self.dim) * self.mu 45 | 46 | def evolve_state(self): 47 | x = self.state 48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x)) 49 | self.state = x + dx 50 | return self.state 51 | 52 | def get_action_from_raw_action(self, action, t=0, **kwargs): 53 | ou_state = self.evolve_state() 54 | self.sigma = ( 55 | self._max_sigma 56 | - (self._max_sigma - self._min_sigma) 57 | * min(1.0, t * 1.0 / self._decay_period) 58 | ) 59 | return np.clip(action + ou_state, self.low, self.high) 60 | -------------------------------------------------------------------------------- /rlkit/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains 'launchers', which are self-contained functions that take 3 | one dictionary and run a full experiment. The dictionary configures the 4 | experiment. 5 | 6 | It is important that the functions are completely self-contained (i.e. they 7 | import their own modules) so that they can be serialized. 
8 | """ 9 | -------------------------------------------------------------------------------- /rlkit/launchers/experiments/awac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/launchers/experiments/awac/__init__.py -------------------------------------------------------------------------------- /rlkit/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/policies/__init__.py -------------------------------------------------------------------------------- /rlkit/policies/argmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch argmax policy 3 | """ 4 | import numpy as np 5 | from torch import nn 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.policies.base import Policy 9 | 10 | 11 | class ArgmaxDiscretePolicy(nn.Module, Policy): 12 | def __init__(self, qf): 13 | super().__init__() 14 | self.qf = qf 15 | 16 | def get_action(self, obs): 17 | obs = np.expand_dims(obs, axis=0) 18 | obs = ptu.from_numpy(obs).float() 19 | q_values = self.qf(obs).squeeze(0) 20 | q_values_np = ptu.get_numpy(q_values) 21 | return q_values_np.argmax(), {} 22 | -------------------------------------------------------------------------------- /rlkit/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(object, metaclass=abc.ABCMeta): 5 | """ 6 | General policy interface. 7 | """ 8 | @abc.abstractmethod 9 | def get_action(self, observation): 10 | """ 11 | 12 | :param observation: 13 | :return: action, debug_dictionary 14 | """ 15 | pass 16 | 17 | def reset(self): 18 | pass 19 | 20 | 21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 22 | def set_num_steps_total(self, t): 23 | pass 24 | -------------------------------------------------------------------------------- /rlkit/policies/simple.py: -------------------------------------------------------------------------------- 1 | from rlkit.policies.base import Policy 2 | 3 | 4 | class RandomPolicy(Policy): 5 | """ 6 | Policy that always outputs zero. 
7 | """ 8 | 9 | def __init__(self, action_space): 10 | self.action_space = action_space 11 | 12 | def get_action(self, obs): 13 | return self.action_space.sample(), {} 14 | -------------------------------------------------------------------------------- /rlkit/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/samplers/__init__.py -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.data_collector.base import ( 2 | DataCollector, 3 | PathCollector, 4 | StepCollector, 5 | ) 6 | from rlkit.samplers.data_collector.path_collector import ( 7 | MdpPathCollector, 8 | ObsDictPathCollector, 9 | GoalConditionedPathCollector, 10 | VAEWrappedEnvPathCollector, 11 | ) 12 | from rlkit.samplers.data_collector.step_collector import ( 13 | GoalConditionedStepCollector 14 | ) 15 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class DataCollector(object, metaclass=abc.ABCMeta): 5 | def end_epoch(self, epoch): 6 | pass 7 | 8 | def get_diagnostics(self): 9 | return {} 10 | 11 | def get_snapshot(self): 12 | return {} 13 | 14 | @abc.abstractmethod 15 | def get_epoch_paths(self): 16 | pass 17 | 18 | 19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def collect_new_paths( 22 | self, 23 | max_path_length, 24 | num_steps, 25 | discard_incomplete_paths, 26 | ): 27 | pass 28 | 29 | 30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta): 31 | @abc.abstractmethod 32 | def collect_new_steps( 33 | self, 34 | max_path_length, 35 | num_steps, 36 | discard_incomplete_paths, 37 | ): 38 | pass 39 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/contextual_path_collector.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from rlkit.envs.contextual import ContextualEnv 4 | from rlkit.policies.base import Policy 5 | from rlkit.samplers.data_collector import MdpPathCollector 6 | from rlkit.samplers.rollout_functions import contextual_rollout 7 | 8 | 9 | class ContextualPathCollector(MdpPathCollector): 10 | def __init__( 11 | self, 12 | env: ContextualEnv, 13 | policy: Policy, 14 | max_num_epoch_paths_saved=None, 15 | observation_key='observation', 16 | context_keys_for_policy='context', 17 | render=False, 18 | render_kwargs=None, 19 | **kwargs 20 | ): 21 | rollout_fn = partial( 22 | contextual_rollout, 23 | context_keys_for_policy=context_keys_for_policy, 24 | observation_key=observation_key, 25 | ) 26 | super().__init__( 27 | env, policy, max_num_epoch_paths_saved, render, render_kwargs, 28 | rollout_fn=rollout_fn, 29 | **kwargs 30 | ) 31 | self._observation_key = observation_key 32 | self._context_keys_for_policy = context_keys_for_policy 33 | 34 | def get_snapshot(self): 35 | snapshot = super().get_snapshot() 36 | snapshot.update( 37 | observation_key=self._observation_key, 38 | context_keys_for_policy=self._context_keys_for_policy, 39 | ) 40 | return snapshot 41 | 
-------------------------------------------------------------------------------- /rlkit/samplers/data_collector/joint_path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict 3 | 4 | from rlkit.core.logging import add_prefix 5 | from rlkit.samplers.data_collector import PathCollector 6 | 7 | 8 | class JointPathCollector(PathCollector): 9 | def __init__(self, path_collectors: Dict[str, PathCollector]): 10 | self.path_collectors = path_collectors 11 | 12 | def collect_new_paths(self, max_path_length, num_steps, 13 | discard_incomplete_paths): 14 | paths = [] 15 | for collector in self.path_collectors.values(): 16 | paths += collector.collect_new_paths( 17 | max_path_length, num_steps, discard_incomplete_paths 18 | ) 19 | return paths 20 | 21 | def end_epoch(self, epoch): 22 | for collector in self.path_collectors.values(): 23 | collector.end_epoch(epoch) 24 | 25 | def get_diagnostics(self): 26 | diagnostics = OrderedDict() 27 | for name, collector in self.path_collectors.items(): 28 | diagnostics.update( 29 | add_prefix(collector.get_diagnostics(), name, divider='/'), 30 | ) 31 | return diagnostics 32 | 33 | def get_snapshot(self): 34 | snapshot = {} 35 | for name, collector in self.path_collectors.items(): 36 | snapshot.update( 37 | add_prefix(collector.get_snapshot(), name, divider='/'), 38 | ) 39 | return snapshot 40 | 41 | def get_epoch_paths(self): 42 | paths = {} 43 | for name, collector in self.path_collectors.items(): 44 | paths[name] = collector.get_epoch_paths() 45 | return paths 46 | 47 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/vae_env.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 2 | from rlkit.samplers.data_collector import GoalConditionedPathCollector 3 | 4 | 5 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector): 6 | def __init__( 7 | self, 8 | goal_sampling_mode, 9 | env: VAEWrappedEnv, 10 | policy, 11 | decode_goals=False, 12 | **kwargs 13 | ): 14 | super().__init__(env, policy, **kwargs) 15 | self._goal_sampling_mode = goal_sampling_mode 16 | self._decode_goals = decode_goals 17 | 18 | def collect_new_paths(self, *args, **kwargs): 19 | self._env.goal_sampling_mode = self._goal_sampling_mode 20 | self._env.decode_goals = self._decode_goals 21 | return super().collect_new_paths(*args, **kwargs) -------------------------------------------------------------------------------- /rlkit/samplers/in_place.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.util import rollout 2 | 3 | 4 | class InPlacePathSampler(object): 5 | """ 6 | A sampler that does not use serialization for sampling. Instead, it just uses 7 | the current policy and environment as-is. 8 | 9 | WARNING: This will affect the environment! So 10 | ``` 11 | sampler = InPlacePathSampler(env, ...) 12 | sampler.obtain_samples() # this has side-effects: env will change! 
13 | ``` 14 | """ 15 | def __init__(self, env, policy, max_samples, max_path_length, render=False): 16 | self.env = env 17 | self.policy = policy 18 | self.max_path_length = max_path_length 19 | self.max_samples = max_samples 20 | self.render = render 21 | assert max_samples >= max_path_length, "Need max_samples >= max_path_length" 22 | 23 | def start_worker(self): 24 | pass 25 | 26 | def shutdown_worker(self): 27 | pass 28 | 29 | def obtain_samples(self): 30 | paths = [] 31 | n_steps_total = 0 32 | while n_steps_total + self.max_path_length <= self.max_samples: 33 | path = rollout( 34 | self.env, self.policy, max_path_length=self.max_path_length, 35 | animated=self.render 36 | ) 37 | paths.append(path) 38 | n_steps_total += len(path['observations']) 39 | return paths 40 | -------------------------------------------------------------------------------- /rlkit/samplers/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def split_paths(paths): 4 | """ 5 | Stack multiples obs/actions/etc. from different paths 6 | :param paths: List of paths, where one path is something returned from 7 | the rollout functino above. 8 | :return: Tuple. Every element will have shape batch_size X DIM, including 9 | the rewards and terminal flags. 10 | """ 11 | rewards = [path["rewards"].reshape(-1, 1) for path in paths] 12 | terminals = [path["terminals"].reshape(-1, 1) for path in paths] 13 | actions = [path["actions"] for path in paths] 14 | obs = [path["observations"] for path in paths] 15 | next_obs = [path["next_observations"] for path in paths] 16 | rewards = np.vstack(rewards) 17 | terminals = np.vstack(terminals) 18 | obs = np.vstack(obs) 19 | actions = np.vstack(actions) 20 | next_obs = np.vstack(next_obs) 21 | assert len(rewards.shape) == 2 22 | assert len(terminals.shape) == 2 23 | assert len(obs.shape) == 2 24 | assert len(actions.shape) == 2 25 | assert len(next_obs.shape) == 2 26 | return rewards, terminals, obs, actions, next_obs 27 | 28 | 29 | def split_paths_to_dict(paths): 30 | rewards, terminals, obs, actions, next_obs = split_paths(paths) 31 | return dict( 32 | rewards=rewards, 33 | terminals=terminals, 34 | observations=obs, 35 | actions=actions, 36 | next_observations=next_obs, 37 | ) 38 | -------------------------------------------------------------------------------- /rlkit/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/testing/__init__.py -------------------------------------------------------------------------------- /rlkit/testing/csv_util.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import math 3 | 4 | def get_exp(fname): 5 | with open(fname) as csvfile: 6 | reader = csv.DictReader(csvfile) 7 | output = [] 8 | for row in reader: 9 | output.append(row) 10 | return output 11 | 12 | def check_equal(reference, output, keys, ): 13 | for i in range(len(reference)): 14 | reference_row = reference[i] 15 | output_row = output[i] 16 | for key in keys: 17 | assert key in reference_row, "line %d key %s not in reference" % (i, key) 18 | assert key in output_row, "line %d key %s not in output" % (i, key) 19 | r = float(reference_row[key]) 20 | o = float(output_row[key]) 21 | assert math.isclose(r, o, rel_tol=1e-5), "line %d key %s reference: %s, output: %s" % (i, key, r, o) 22 | 23 | def check_exactly_equal(reference, 
output, ): 24 | for i in range(len(reference)): 25 | reference_row = reference[i] 26 | output_row = output[i] 27 | for key in reference_row: 28 | assert key in output_row, key 29 | assert reference_row[key] == output_row[key], "%s reference: %s, output: %s" % (key, reference_row[key], output_row[key]) 30 | -------------------------------------------------------------------------------- /rlkit/testing/debug_util.py: -------------------------------------------------------------------------------- 1 | """For tracing programs and comparing outputs""" 2 | 3 | import torch 4 | 5 | i = 0 6 | 7 | def save(x): 8 | torch.save(x, "../tmp.pt") 9 | return x 10 | 11 | def load(): 12 | return torch.load("../tmp.pt") 13 | 14 | def savei(x): 15 | global i 16 | torch.save(x, "../tmp/%d.pt" % i) 17 | i = i + 1 18 | return x 19 | 20 | def loadi(): 21 | global i 22 | x = torch.load("../tmp/%d.pt" % i) 23 | i = i + 1 24 | return x 25 | -------------------------------------------------------------------------------- /rlkit/testing/np_test_case.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from rlkit.testing.testing_utils import are_np_arrays_equal, \ 6 | are_np_array_iterables_equal 7 | 8 | 9 | class NPTestCase(unittest.TestCase): 10 | """ 11 | Numpy test case, providing useful assert methods. 12 | """ 13 | def assertNpEqual(self, np_arr1, np_arr2, msg="Numpy arrays not equal."): 14 | self.assertTrue(are_np_arrays_equal(np_arr1, np_arr2), msg) 15 | 16 | def assertNpAlmostEqual( 17 | self, 18 | np_arr1, 19 | np_arr2, 20 | msg="Numpy arrays not equal.", 21 | threshold=1e-5, 22 | ): 23 | self.assertTrue( 24 | are_np_arrays_equal(np_arr1, np_arr2, threshold=threshold), 25 | msg 26 | ) 27 | 28 | def assertNpNotEqual(self, np_arr1, np_arr2, msg="Numpy arrays equal"): 29 | self.assertFalse(are_np_arrays_equal(np_arr1, np_arr2), msg) 30 | 31 | def assertNpArraysEqual( 32 | self, 33 | np_arrays1, 34 | np_arrays2, 35 | msg=None, 36 | ): 37 | msg = msg or "Numpy arrays {} and {} are not equal".format( 38 | np_arrays1, 39 | np_arrays2, 40 | ) 41 | self.assertTrue( 42 | are_np_array_iterables_equal( 43 | np_arrays1, 44 | np_arrays2, 45 | ), 46 | msg 47 | ) 48 | 49 | # TODO(vpong): see why such a high threshold is needed 50 | def assertNpArraysAlmostEqual( 51 | self, 52 | np_arrays1, 53 | np_arrays2, 54 | msg="Numpy array lists are not almost equal.", 55 | threshold=1e-4, 56 | ): 57 | self.assertTrue( 58 | are_np_array_iterables_equal( 59 | np_arrays1, 60 | np_arrays2, 61 | threshold=threshold, 62 | ), 63 | msg 64 | ) 65 | 66 | def assertNpArraysNotEqual( 67 | self, 68 | np_arrays1, 69 | np_arrays2, 70 | msg="Numpy array lists are equal." 
71 | ): 72 | self.assertFalse(are_np_array_iterables_equal(np_arrays1, np_arrays2), 73 | msg) 74 | 75 | def assertNpArraysNotAlmostEqual( 76 | self, 77 | np_arrays1, 78 | np_arrays2, 79 | msg="Numpy array lists are equal.", 80 | threshold=1e-4, 81 | ): 82 | self.assertFalse( 83 | are_np_array_iterables_equal( 84 | np_arrays1, 85 | np_arrays2, 86 | threshold=threshold, 87 | ), 88 | msg 89 | ) 90 | 91 | def assertNpArrayConstant( 92 | self, 93 | np_array: np.ndarray, 94 | constant 95 | ): 96 | self.assertTrue( 97 | (np_array == constant).all(), 98 | msg="Not all values equal {0}".format(constant) 99 | ) -------------------------------------------------------------------------------- /rlkit/testing/stub_classes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box, Dict 3 | 4 | from rlkit.exploration_strategies.base import RawExplorationStrategy 5 | 6 | class StubEnv(object): 7 | def __init__(self, obs_dim=1, action_dim=1, **kwargs): 8 | self.obs_dim = obs_dim 9 | obs_low = np.ones(obs_dim) * -1 10 | obs_high = np.ones(obs_dim) 11 | self._observation_space = Box(obs_low, obs_high) 12 | 13 | self.action_dim = action_dim 14 | action_low = np.ones(action_dim) * -1 15 | action_high = np.ones(action_dim) 16 | self._action_space = Box(action_low, action_high) 17 | 18 | print("stub env unused kwargs", kwargs) 19 | 20 | def reset(self): 21 | return np.zeros(self.obs_dim) 22 | 23 | def step(self, action): 24 | return np.zeros(self.obs_dim), 0, 0, {} 25 | 26 | @property 27 | def action_space(self): 28 | return self._action_space 29 | 30 | @property 31 | def horizon(self): 32 | return 99999 33 | 34 | @property 35 | def observation_space(self): 36 | return self._observation_space 37 | 38 | class StubMultiEnv(object): 39 | def __init__(self, obs_dims=None, action_dim=1, **kwargs): 40 | self.obs_dims = obs_dims 41 | 42 | spaces = [] 43 | for name in self.obs_dims: 44 | obs_dim = self.obs_dims[name] 45 | obs_low = np.ones(obs_dim) * -1 46 | obs_high = np.ones(obs_dim) 47 | spaces.append((name, Box(obs_low, obs_high))) 48 | self._observation_space = Dict(spaces) 49 | 50 | self.action_dim = action_dim 51 | action_low = np.ones(action_dim) * -1 52 | action_high = np.ones(action_dim) 53 | self._action_space = Box(action_low, action_high) 54 | 55 | print("stub env unused kwargs", kwargs) 56 | 57 | def reset(self): 58 | return self.get_obs() 59 | 60 | def step(self, action): 61 | return self.get_obs(), 0, 0, {} 62 | 63 | def get_obs(self): 64 | obs = dict() 65 | for name in self.obs_dims: 66 | obs_dim = self.obs_dims[name] 67 | obs[name] = np.zeros(obs_dim) 68 | return obs 69 | 70 | @property 71 | def action_space(self): 72 | return self._action_space 73 | 74 | @property 75 | def horizon(self): 76 | return 99999 77 | 78 | @property 79 | def observation_space(self): 80 | return self._observation_space 81 | 82 | 83 | class StubPolicy(object): 84 | def __init__(self, action): 85 | self._action = action 86 | 87 | def get_action(self, *arg, **kwargs): 88 | return self._action, {} 89 | 90 | 91 | class AddEs(RawExplorationStrategy): 92 | """ 93 | return action + constant 94 | """ 95 | def __init__(self, number): 96 | self._number = number 97 | 98 | def get_action(self, t, observation, policy, **kwargs): 99 | action, _ = policy.get_action(observation) 100 | return self.get_action_from_raw_action(action) 101 | 102 | def get_action_from_raw_action(self, action, **kwargs): 103 | return self._number + action 104 | 
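One way these stubs can be combined in a test (the dimensions and constants are arbitrary; AddEs follows the ExplorationStrategy.get_action(t, observation, policy) signature and is therefore called directly rather than through PolicyWrappedWithExplorationStrategy):

    import numpy as np

    from rlkit.testing.stub_classes import AddEs, StubEnv, StubPolicy

    env = StubEnv(obs_dim=3, action_dim=2)
    policy = StubPolicy(np.array([0.1, -0.2]))
    es = AddEs(0.5)

    obs = env.reset()  # StubEnv always returns a zero observation
    action = es.get_action(t=0, observation=obs, policy=policy)
    assert np.allclose(action, [0.6, 0.3])  # stub policy output + 0.5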
-------------------------------------------------------------------------------- /rlkit/testing/testing_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from numbers import Number 4 | 5 | 6 | def is_binomial_trial_likely(n, p, num_success, num_std=3): 7 | """ 8 | Returns whether or not seeing `num_sucesss` successes is likely. 9 | :param n: Number of trials. 10 | :param p: Probability of success. 11 | :param num_success: Number of successes 12 | :param num_std: Number of standard deviations the results must be within. 13 | :return: 14 | """ 15 | mean = n * p 16 | std = math.sqrt(n * p * (1 - p)) 17 | margin = num_std * std 18 | return mean - margin < num_success < mean + margin 19 | 20 | 21 | def are_np_array_iterables_equal(np_list1, np_list2, threshold=1e-5): 22 | # import ipdb; ipdb.set_trace() 23 | # if isinstance(np_list1.shape ==, Number) and isinstance(np_itr2, Number): 24 | # return are_np_arrays_equal(np_itr1, np_itr2) 25 | if np_list1.shape == () and np_list2.shape == (): 26 | return are_np_arrays_equal(np_list1, np_list2) 27 | # in case generators were passed in 28 | # np_list1 = list(np_itr1) 29 | # np_list2 = list(np_itr2) 30 | return ( 31 | len(np_list1) == len(np_list2) and 32 | all(are_np_arrays_equal(arr1, arr2, threshold=threshold) 33 | for arr1, arr2 in zip(np_list1, np_list2)) 34 | ) 35 | 36 | 37 | def are_np_arrays_equal(arr1, arr2, threshold=1e-5): 38 | if arr1.shape != arr2.shape: 39 | return False 40 | return (np.abs(arr1 - arr2) <= threshold).all() 41 | 42 | 43 | def is_list_subset(list1, list2): 44 | for a in list1: 45 | if a not in list2: 46 | return False 47 | return True 48 | 49 | 50 | def are_dict_lists_equal(list1, list2): 51 | return is_list_subset(list1, list2) and is_list_subset(list2, list1) 52 | 53 | -------------------------------------------------------------------------------- /rlkit/testing/tf_test_case.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from rlkit.testing.np_test_case import NPTestCase 5 | 6 | 7 | class TFTestCase(NPTestCase): 8 | """ 9 | Tensorflow test case, providing useful assert methods and clean default 10 | session. 
11 | """ 12 | def setUp(self): 13 | tf.reset_default_graph() 14 | self.sess = tf.get_default_session() or tf.Session() 15 | self.sess_context = self.sess.as_default() 16 | self.sess_context.__enter__() 17 | 18 | def tearDown(self): 19 | self.sess_context.__exit__(None, None, None) 20 | self.sess.close() 21 | 22 | def assertParamsEqual(self, network1, network2): 23 | self.assertNpArraysEqual( 24 | network1.get_param_values(), 25 | network2.get_param_values(), 26 | msg="Parameters are not equal.", 27 | ) 28 | 29 | def assertParamsNotEqual(self, network1, network2): 30 | self.assertNpArraysNotEqual( 31 | network1.get_param_values(), 32 | network2.get_param_values(), 33 | msg="Parameters are equal.", 34 | ) 35 | 36 | def randomize_param_values(self, network): 37 | for v in network.get_params(): 38 | self.sess.run( 39 | v.assign(np.random.rand(*v.get_shape())) 40 | ) 41 | -------------------------------------------------------------------------------- /rlkit/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/core.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | 7 | from rlkit.torch import pytorch_util as ptu 8 | 9 | 10 | class PyTorchModule(nn.Module, metaclass=abc.ABCMeta): 11 | """ 12 | Keeping wrapper around to be a bit more future-proof. 13 | """ 14 | pass 15 | 16 | 17 | def eval_np(module, *args, **kwargs): 18 | """ 19 | Eval this module with a numpy interface 20 | 21 | Same as a call to __call__ except all Variable input/outputs are 22 | replaced with numpy equivalents. 23 | 24 | Assumes the output is either a single object or a tuple of objects. 25 | """ 26 | torch_args = tuple(torch_ify(x) for x in args) 27 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 28 | outputs = module(*torch_args, **torch_kwargs) 29 | return elem_or_tuple_to_numpy(outputs) 30 | 31 | 32 | def torch_ify(np_array_or_other): 33 | if isinstance(np_array_or_other, np.ndarray): 34 | return ptu.from_numpy(np_array_or_other) 35 | else: 36 | return np_array_or_other 37 | 38 | 39 | def np_ify(tensor_or_other): 40 | if isinstance(tensor_or_other, torch.autograd.Variable): 41 | return ptu.get_numpy(tensor_or_other) 42 | else: 43 | return tensor_or_other 44 | 45 | 46 | def _elem_or_tuple_to_variable(elem_or_tuple): 47 | if isinstance(elem_or_tuple, tuple): 48 | return tuple( 49 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple 50 | ) 51 | return ptu.from_numpy(elem_or_tuple).float() 52 | 53 | 54 | def elem_or_tuple_to_numpy(elem_or_tuple): 55 | if isinstance(elem_or_tuple, tuple): 56 | return tuple(np_ify(x) for x in elem_or_tuple) 57 | else: 58 | return np_ify(elem_or_tuple) 59 | 60 | 61 | def _filter_batch(np_batch): 62 | for k, v in np_batch.items(): 63 | if v.dtype == np.bool: 64 | yield k, v.astype(int) 65 | else: 66 | yield k, v 67 | 68 | 69 | def np_to_pytorch_batch(np_batch): 70 | if isinstance(np_batch, dict): 71 | return { 72 | k: _elem_or_tuple_to_variable(x) 73 | for k, x in _filter_batch(np_batch) 74 | if x.dtype != np.dtype('O') # ignore object (e.g. 
dictionaries) 75 | } 76 | else: 77 | return _elem_or_tuple_to_variable(np_batch) 78 | -------------------------------------------------------------------------------- /rlkit/torch/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, Sampler 4 | 5 | # TODO: move this to a more reasonable place 6 | from rlkit.data_management.obs_dict_replay_buffer import normalize_image 7 | 8 | 9 | class ImageDataset(Dataset): 10 | 11 | def __init__(self, images, should_normalize=True): 12 | super().__init__() 13 | self.dataset = images 14 | self.dataset_len = len(self.dataset) 15 | assert should_normalize == (images.dtype == np.uint8) 16 | self.should_normalize = should_normalize 17 | 18 | def __len__(self): 19 | return self.dataset_len 20 | 21 | def __getitem__(self, idxs): 22 | samples = self.dataset[idxs, :] 23 | if self.should_normalize: 24 | samples = normalize_image(samples) 25 | return np.float32(samples) 26 | 27 | 28 | class InfiniteRandomSampler(Sampler): 29 | 30 | def __init__(self, data_source): 31 | self.data_source = data_source 32 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __next__(self): 38 | try: 39 | idx = next(self.iter) 40 | except StopIteration: 41 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 42 | idx = next(self.iter) 43 | return idx 44 | 45 | def __len__(self): 46 | return 2 ** 62 47 | 48 | 49 | class InfiniteWeightedRandomSampler(Sampler): 50 | 51 | def __init__(self, data_source, weights): 52 | assert len(data_source) == len(weights) 53 | assert len(weights.shape) == 1 54 | self.data_source = data_source 55 | # Always use CPU 56 | self._weights = torch.from_numpy(weights) 57 | self.iter = self._create_iterator() 58 | 59 | def update_weights(self, weights): 60 | self._weights = weights 61 | self.iter = self._create_iterator() 62 | 63 | def _create_iterator(self): 64 | return iter( 65 | torch.multinomial( 66 | self._weights, len(self._weights), replacement=True 67 | ).tolist() 68 | ) 69 | 70 | def __iter__(self): 71 | return self 72 | 73 | def __next__(self): 74 | try: 75 | idx = next(self.iter) 76 | except StopIteration: 77 | self.iter = self._create_iterator() 78 | idx = next(self.iter) 79 | return idx 80 | 81 | def __len__(self): 82 | return 2 ** 62 83 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/data_management/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rlkit.torch.pytorch_util as ptu 3 | import numpy as np 4 | 5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer 6 | 7 | 8 | class TorchNormalizer(Normalizer): 9 | """ 10 | Update with np array, but de/normalize pytorch Tensors. 
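Concretely, normalize() computes

    clamp((v - mean) / std, -clip_range, clip_range)

using the numpy statistics tracked by the base Normalizer, broadcast over the batch dimension when v is 2-D, and denormalize() inverts the affine part as mean + v * std (the clipping is not undone).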
11 | """ 12 | def normalize(self, v, clip_range=None): 13 | if not self.synchronized: 14 | self.synchronize() 15 | if clip_range is None: 16 | clip_range = self.default_clip_range 17 | mean = ptu.from_numpy(self.mean) 18 | std = ptu.from_numpy(self.std) 19 | if v.dim() == 2: 20 | # Unsqueeze along the batch use automatic broadcasting 21 | mean = mean.unsqueeze(0) 22 | std = std.unsqueeze(0) 23 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 24 | 25 | def denormalize(self, v): 26 | if not self.synchronized: 27 | self.synchronize() 28 | mean = ptu.from_numpy(self.mean) 29 | std = ptu.from_numpy(self.std) 30 | if v.dim() == 2: 31 | mean = mean.unsqueeze(0) 32 | std = std.unsqueeze(0) 33 | return mean + v * std 34 | 35 | 36 | class TorchFixedNormalizer(FixedNormalizer): 37 | def normalize(self, v, clip_range=None): 38 | if clip_range is None: 39 | clip_range = self.default_clip_range 40 | mean = ptu.from_numpy(self.mean) 41 | std = ptu.from_numpy(self.std) 42 | if v.dim() == 2: 43 | # Unsqueeze along the batch use automatic broadcasting 44 | mean = mean.unsqueeze(0) 45 | std = std.unsqueeze(0) 46 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 47 | 48 | def normalize_scale(self, v): 49 | """ 50 | Only normalize the scale. Do not subtract the mean. 51 | """ 52 | std = ptu.from_numpy(self.std) 53 | if v.dim() == 2: 54 | std = std.unsqueeze(0) 55 | return v / std 56 | 57 | def denormalize(self, v): 58 | mean = ptu.from_numpy(self.mean) 59 | std = ptu.from_numpy(self.std) 60 | if v.dim() == 2: 61 | mean = mean.unsqueeze(0) 62 | std = std.unsqueeze(0) 63 | return mean + v * std 64 | 65 | def denormalize_scale(self, v): 66 | """ 67 | Only denormalize the scale. Do not add the mean. 68 | """ 69 | std = ptu.from_numpy(self.std) 70 | if v.dim() == 2: 71 | std = std.unsqueeze(0) 72 | return v * std 73 | -------------------------------------------------------------------------------- /rlkit/torch/ddpg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/ddpg/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/dqn/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/dqn/double_dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import rlkit.torch.pytorch_util as ptu 5 | from rlkit.core.eval_util import create_stats_ordered_dict 6 | from rlkit.torch.dqn.dqn import DQNTrainer 7 | 8 | 9 | class DoubleDQNTrainer(DQNTrainer): 10 | def train_from_torch(self, batch): 11 | rewards = batch['rewards'] 12 | terminals = batch['terminals'] 13 | obs = batch['observations'] 14 | actions = batch['actions'] 15 | next_obs = batch['next_observations'] 16 | 17 | """ 18 | Compute loss 19 | """ 20 | 21 | best_action_idxs = self.qf(next_obs).max( 22 | 1, keepdim=True 23 | )[1] 24 | target_q_values = self.target_qf(next_obs).gather( 25 | 1, best_action_idxs 26 | ).detach() 27 | y_target = rewards + (1. 
- terminals) * self.discount * target_q_values 28 | y_target = y_target.detach() 29 | # actions is a one-hot vector 30 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) 31 | qf_loss = self.qf_criterion(y_pred, y_target) 32 | 33 | """ 34 | Update networks 35 | """ 36 | self.qf_optimizer.zero_grad() 37 | qf_loss.backward() 38 | self.qf_optimizer.step() 39 | 40 | """ 41 | Soft target network updates 42 | """ 43 | if self._n_train_steps_total % self.target_update_period == 0: 44 | ptu.soft_update_from_to( 45 | self.qf, self.target_qf, self.soft_target_tau 46 | ) 47 | 48 | """ 49 | Save some statistics for eval using just one batch. 50 | """ 51 | if self._need_to_update_eval_statistics: 52 | self._need_to_update_eval_statistics = False 53 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 54 | self.eval_statistics.update(create_stats_ordered_dict( 55 | 'Y Predictions', 56 | ptu.get_numpy(y_pred), 57 | )) 58 | 59 | self._n_train_steps_total += 1 60 | -------------------------------------------------------------------------------- /rlkit/torch/dqn/dqn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | 8 | import rlkit.torch.pytorch_util as ptu 9 | from rlkit.core.eval_util import create_stats_ordered_dict 10 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 11 | 12 | 13 | class DQNTrainer(TorchTrainer): 14 | def __init__( 15 | self, 16 | qf, 17 | target_qf, 18 | learning_rate=1e-3, 19 | soft_target_tau=1e-3, 20 | target_update_period=1, 21 | qf_criterion=None, 22 | 23 | discount=0.99, 24 | reward_scale=1.0, 25 | ): 26 | super().__init__() 27 | self.qf = qf 28 | self.target_qf = target_qf 29 | self.learning_rate = learning_rate 30 | self.soft_target_tau = soft_target_tau 31 | self.target_update_period = target_update_period 32 | self.qf_optimizer = optim.Adam( 33 | self.qf.parameters(), 34 | lr=self.learning_rate, 35 | ) 36 | self.discount = discount 37 | self.reward_scale = reward_scale 38 | self.qf_criterion = qf_criterion or nn.MSELoss() 39 | self.eval_statistics = OrderedDict() 40 | self._n_train_steps_total = 0 41 | self._need_to_update_eval_statistics = True 42 | 43 | def train_from_torch(self, batch): 44 | rewards = batch['rewards'] * self.reward_scale 45 | terminals = batch['terminals'] 46 | obs = batch['observations'] 47 | actions = batch['actions'] 48 | next_obs = batch['next_observations'] 49 | 50 | """ 51 | Compute loss 52 | """ 53 | 54 | target_q_values = self.target_qf(next_obs).detach().max( 55 | 1, keepdim=True 56 | )[0] 57 | y_target = rewards + (1. - terminals) * self.discount * target_q_values 58 | y_target = y_target.detach() 59 | # actions is a one-hot vector 60 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) 61 | qf_loss = self.qf_criterion(y_pred, y_target) 62 | 63 | """ 64 | Soft target network updates 65 | """ 66 | self.qf_optimizer.zero_grad() 67 | qf_loss.backward() 68 | self.qf_optimizer.step() 69 | 70 | """ 71 | Soft Updates 72 | """ 73 | if self._n_train_steps_total % self.target_update_period == 0: 74 | ptu.soft_update_from_to( 75 | self.qf, self.target_qf, self.soft_target_tau 76 | ) 77 | 78 | """ 79 | Save some statistics for eval using just one batch. 
80 | """ 81 | if self._need_to_update_eval_statistics: 82 | self._need_to_update_eval_statistics = False 83 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 84 | self.eval_statistics.update(create_stats_ordered_dict( 85 | 'Y Predictions', 86 | ptu.get_numpy(y_pred), 87 | )) 88 | self._n_train_steps_total += 1 89 | 90 | def get_diagnostics(self): 91 | return self.eval_statistics 92 | 93 | def end_epoch(self, epoch): 94 | self._need_to_update_eval_statistics = True 95 | 96 | @property 97 | def networks(self): 98 | return [ 99 | self.qf, 100 | self.target_qf, 101 | ] 102 | 103 | def get_snapshot(self): 104 | return dict( 105 | qf=self.qf, 106 | target_qf=self.target_qf, 107 | ) 108 | -------------------------------------------------------------------------------- /rlkit/torch/her/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/her/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/her/her.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 4 | 5 | 6 | class HERTrainer(TorchTrainer): 7 | def __init__(self, base_trainer: TorchTrainer): 8 | super().__init__() 9 | self._base_trainer = base_trainer 10 | 11 | def train_from_torch(self, data): 12 | obs = data['observations'] 13 | next_obs = data['next_observations'] 14 | goals = data['resampled_goals'] 15 | data['observations'] = torch.cat((obs, goals), dim=1) 16 | data['next_observations'] = torch.cat((next_obs, goals), dim=1) 17 | self._base_trainer.train_from_torch(data) 18 | 19 | def get_diagnostics(self): 20 | return self._base_trainer.get_diagnostics() 21 | 22 | def end_epoch(self, epoch): 23 | self._base_trainer.end_epoch(epoch) 24 | 25 | @property 26 | def networks(self): 27 | return self._base_trainer.networks 28 | 29 | def get_snapshot(self): 30 | return self._base_trainer.get_snapshot() 31 | -------------------------------------------------------------------------------- /rlkit/torch/lvm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/lvm/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/lvm/latent_variable_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.policies.base import ExplorationPolicy 11 | from rlkit.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from rlkit.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from rlkit.torch.networks import Mlp, CNN 16 | from rlkit.torch.networks.basic import MultiInputSequential 17 | from rlkit.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | from rlkit.torch.sac.policies.base import ( 21 | TorchStochasticPolicy, 22 | PolicyFromDistributionGenerator, 23 | MakeDeterministic, 24 | ) 25 | 26 | 27 | class LatentVariableModel(nn.Module): 28 | def __init__( 29 | self, 30 | encoder, 
31 | decoder, 32 | **kwargs 33 | ): 34 | super().__init__() 35 | self.encoder = encoder 36 | self.decoder = decoder 37 | -------------------------------------------------------------------------------- /rlkit/torch/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class HuberLoss(nn.Module): 9 | def __init__(self, delta=1): 10 | super().__init__() 11 | self.huber_loss_delta1 = nn.SmoothL1Loss() 12 | self.delta = delta 13 | 14 | def forward(self, x, x_hat): 15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta) 16 | return loss * self.delta * self.delta 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | """ 21 | Simple 1D LayerNorm. 22 | """ 23 | 24 | def __init__(self, features, center=True, scale=False, eps=1e-6): 25 | super().__init__() 26 | self.center = center 27 | self.scale = scale 28 | self.eps = eps 29 | if self.scale: 30 | self.scale_param = nn.Parameter(torch.ones(features)) 31 | else: 32 | self.scale_param = None 33 | if self.center: 34 | self.center_param = nn.Parameter(torch.zeros(features)) 35 | else: 36 | self.center_param = None 37 | 38 | def forward(self, x): 39 | mean = x.mean(-1, keepdim=True) 40 | std = x.std(-1, keepdim=True) 41 | output = (x - mean) / (std + self.eps) 42 | if self.scale: 43 | output = output * self.scale_param 44 | if self.center: 45 | output = output + self.center_param 46 | return output 47 | -------------------------------------------------------------------------------- /rlkit/torch/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 
5 | """ 6 | from rlkit.torch.networks.basic import ( 7 | Clamp, ConcatTuple, Detach, Flatten, FlattenEach, Split, Reshape, 8 | ) 9 | from rlkit.torch.networks.cnn import BasicCNN, CNN, MergedCNN, CNNPolicy 10 | from rlkit.torch.networks.dcnn import DCNN, TwoHeadDCNN 11 | from rlkit.torch.networks.feat_point_mlp import FeatPointMlp 12 | from rlkit.torch.networks.image_state import ImageStatePolicy, ImageStateQ 13 | from rlkit.torch.networks.linear_transform import LinearTransform 14 | from rlkit.torch.networks.normalization import LayerNorm 15 | from rlkit.torch.networks.mlp import ( 16 | Mlp, ConcatMlp, MlpPolicy, TanhMlpPolicy, 17 | MlpQf, 18 | MlpQfWithObsProcessor, 19 | ConcatMultiHeadedMlp, 20 | ) 21 | from rlkit.torch.networks.pretrained_cnn import PretrainedCNN 22 | from rlkit.torch.networks.two_headed_mlp import TwoHeadMlp 23 | 24 | __all__ = [ 25 | 'Clamp', 26 | 'ConcatMlp', 27 | 'ConcatMultiHeadedMlp', 28 | 'ConcatTuple', 29 | 'BasicCNN', 30 | 'CNN', 31 | 'CNNPolicy', 32 | 'DCNN', 33 | 'Detach', 34 | 'FeatPointMlp', 35 | 'Flatten', 36 | 'FlattenEach', 37 | 'LayerNorm', 38 | 'LinearTransform', 39 | 'ImageStatePolicy', 40 | 'ImageStateQ', 41 | 'MergedCNN', 42 | 'Mlp', 43 | 'PretrainedCNN', 44 | 'Reshape', 45 | 'Split', 46 | 'TwoHeadDCNN', 47 | 'TwoHeadMlp', 48 | ] 49 | 50 | -------------------------------------------------------------------------------- /rlkit/torch/networks/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Clamp(nn.Module): 6 | def __init__(self, **kwargs): 7 | super().__init__() 8 | self.kwargs = kwargs 9 | self.__name__ = "Clamp" 10 | 11 | def forward(self, x): 12 | return torch.clamp(x, **self.kwargs) 13 | 14 | 15 | class Split(nn.Module): 16 | """ 17 | Split input and process each chunk with a separate module. 
18 | """ 19 | def __init__(self, module1, module2, split_idx): 20 | super().__init__() 21 | self.module1 = module1 22 | self.module2 = module2 23 | self.split_idx = split_idx 24 | 25 | def forward(self, x): 26 | in1 = x[:, :self.split_idx] 27 | out1 = self.module1(in1) 28 | 29 | in2 = x[:, self.split_idx:] 30 | out2 = self.module2(in2) 31 | 32 | return out1, out2 33 | 34 | 35 | class FlattenEach(nn.Module): 36 | def forward(self, inputs): 37 | return tuple(x.view(x.size(0), -1) for x in inputs) 38 | 39 | 40 | class FlattenEachParallel(nn.Module): 41 | def forward(self, *inputs): 42 | return tuple(x.view(x.size(0), -1) for x in inputs) 43 | 44 | 45 | class Flatten(nn.Module): 46 | def forward(self, inputs): 47 | return inputs.view(inputs.size(0), -1) 48 | 49 | 50 | class Map(nn.Module): 51 | """Apply a module to each input.""" 52 | def __init__(self, module): 53 | super().__init__() 54 | self.module = module 55 | 56 | def forward(self, inputs): 57 | return tuple(self.module(x) for x in inputs) 58 | 59 | 60 | class ApplyMany(nn.Module): 61 | """Apply many modules to one input.""" 62 | def __init__(self, *modules): 63 | super().__init__() 64 | self.modules_to_apply = nn.ModuleList(modules) 65 | 66 | def forward(self, inputs): 67 | return tuple(m(inputs) for m in self.modules_to_apply) 68 | 69 | 70 | class LearnedPositiveConstant(nn.Module): 71 | def __init__(self, init_value): 72 | super().__init__() 73 | self._constant = nn.Parameter(init_value) 74 | 75 | def forward(self, _): 76 | return self._constant 77 | 78 | 79 | class Reshape(nn.Module): 80 | def __init__(self, *output_shape): 81 | super().__init__() 82 | self._output_shape_with_batch_size = (-1, *output_shape) 83 | 84 | def forward(self, inputs): 85 | return inputs.view(self._output_shape_with_batch_size) 86 | 87 | 88 | class ConcatTuple(nn.Module): 89 | def __init__(self, dim=1): 90 | super().__init__() 91 | self.dim = dim 92 | 93 | def forward(self, inputs): 94 | return torch.cat(inputs, dim=self.dim) 95 | 96 | 97 | class Concat(nn.Module): 98 | def __init__(self, dim=1): 99 | super().__init__() 100 | self.dim = dim 101 | 102 | def forward(self, *inputs): 103 | return torch.cat(inputs, dim=self.dim) 104 | 105 | 106 | class MultiInputSequential(nn.Sequential): 107 | def forward(self, *input): 108 | for module in self._modules.values(): 109 | if isinstance(input, tuple): 110 | input = module(*input) 111 | else: 112 | input = module(input) 113 | return input 114 | 115 | 116 | class Detach(nn.Module): 117 | def __init__(self, wrapped_mlp): 118 | super().__init__() 119 | self.wrapped_mlp = wrapped_mlp 120 | 121 | def forward(self, inputs): 122 | return self.wrapped_mlp.forward(inputs).detach() 123 | 124 | def __getattr__(self, attr_name): 125 | try: 126 | return super().__getattr__(attr_name) 127 | except AttributeError: 128 | return getattr(self.wrapped_mlp, attr_name) 129 | -------------------------------------------------------------------------------- /rlkit/torch/networks/custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Random networks 3 | """ 4 | -------------------------------------------------------------------------------- /rlkit/torch/networks/image_state.py: -------------------------------------------------------------------------------- 1 | from rlkit.policies.base import Policy 2 | from rlkit.torch.core import PyTorchModule, eval_np 3 | 4 | 5 | class ImageStatePolicy(PyTorchModule, Policy): 6 | """Switches between image or state inputs""" 7 | 8 | def __init__( 9 | self, 
10 | image_conv_net, 11 | state_fc_net, 12 | ): 13 | super().__init__() 14 | 15 | assert image_conv_net is None or state_fc_net is None 16 | self.image_conv_net = image_conv_net 17 | self.state_fc_net = state_fc_net 18 | 19 | def forward(self, input, return_preactivations=False): 20 | if self.image_conv_net is not None: 21 | image = input[:, :21168] 22 | return self.image_conv_net(image) 23 | if self.state_fc_net is not None: 24 | state = input[:, 21168:] 25 | return self.state_fc_net(state) 26 | 27 | def get_action(self, obs_np): 28 | actions = self.get_actions(obs_np[None]) 29 | return actions[0, :], {} 30 | 31 | def get_actions(self, obs): 32 | return eval_np(self, obs) 33 | 34 | 35 | class ImageStateQ(PyTorchModule): 36 | """Switches between image or state inputs""" 37 | 38 | def __init__( 39 | self, 40 | # obs_dim, 41 | # action_dim, 42 | # goal_dim, 43 | image_conv_net, # assumed to be a MergedCNN 44 | state_fc_net, 45 | ): 46 | super().__init__() 47 | 48 | assert image_conv_net is None or state_fc_net is None 49 | # self.obs_dim = obs_dim 50 | # self.action_dim = action_dim 51 | # self.goal_dim = goal_dim 52 | self.image_conv_net = image_conv_net 53 | self.state_fc_net = state_fc_net 54 | 55 | def forward(self, input, action, return_preactivations=False): 56 | if self.image_conv_net is not None: 57 | image = input[:, :21168] 58 | return self.image_conv_net(image, action) 59 | if self.state_fc_net is not None: 60 | state = input[:, 21168:] # action + state 61 | return self.state_fc_net(state, action) 62 | 63 | 64 | -------------------------------------------------------------------------------- /rlkit/torch/networks/linear_transform.py: -------------------------------------------------------------------------------- 1 | from rlkit.torch.core import PyTorchModule 2 | 3 | 4 | class LinearTransform(PyTorchModule): 5 | def __init__(self, m, b): 6 | super().__init__() 7 | self.m = m 8 | self.b = b 9 | 10 | def __call__(self, t): 11 | return self.m * t + self.b 12 | -------------------------------------------------------------------------------- /rlkit/torch/networks/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. Maybe depend on pytorch_util. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from rlkit.torch import pytorch_util as ptu 7 | 8 | 9 | class LayerNorm(nn.Module): 10 | """ 11 | Simple 1D LayerNorm. 
12 | """ 13 | def __init__(self, features, center=True, scale=False, eps=1e-6): 14 | super().__init__() 15 | self.center = center 16 | self.scale = scale 17 | self.eps = eps 18 | if self.scale: 19 | self.scale_param = nn.Parameter(torch.ones(features)) 20 | else: 21 | self.scale_param = None 22 | if self.center: 23 | self.center_param = nn.Parameter(torch.zeros(features)) 24 | else: 25 | self.center_param = None 26 | 27 | def forward(self, x): 28 | mean = x.mean(-1, keepdim=True) 29 | std = x.std(-1, keepdim=True) 30 | output = (x - mean) / (std + self.eps) 31 | if self.scale: 32 | output = output * self.scale_param 33 | if self.center: 34 | output = output + self.center_param 35 | return output 36 | -------------------------------------------------------------------------------- /rlkit/torch/networks/stochastic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/networks/stochastic/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/networks/stochastic/distribution_generator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from torch import nn 4 | 5 | from rlkit.torch.distributions import ( 6 | Bernoulli, 7 | Beta, 8 | Distribution, 9 | Independent, 10 | GaussianMixture as GaussianMixtureDistribution, 11 | GaussianMixtureFull as GaussianMixtureFullDistribution, 12 | MultivariateDiagonalNormal, 13 | TanhNormal, 14 | ) 15 | from rlkit.torch.networks.basic import MultiInputSequential 16 | 17 | 18 | class DistributionGenerator(nn.Module, metaclass=abc.ABCMeta): 19 | def forward(self, *input, **kwarg) -> Distribution: 20 | raise NotImplementedError 21 | 22 | 23 | class ModuleToDistributionGenerator( 24 | MultiInputSequential, 25 | DistributionGenerator, 26 | metaclass=abc.ABCMeta 27 | ): 28 | pass 29 | 30 | 31 | class Beta(ModuleToDistributionGenerator): 32 | def forward(self, *input): 33 | alpha, beta = super().forward(*input) 34 | return Beta(alpha, beta) 35 | 36 | 37 | class Gaussian(ModuleToDistributionGenerator): 38 | def __init__(self, module, std=None, reinterpreted_batch_ndims=1): 39 | super().__init__(module) 40 | self.std = std 41 | self.reinterpreted_batch_ndims = reinterpreted_batch_ndims 42 | 43 | def forward(self, *input): 44 | if self.std: 45 | mean = super().forward(*input) 46 | std = self.std 47 | else: 48 | mean, log_std = super().forward(*input) 49 | std = log_std.exp() 50 | return MultivariateDiagonalNormal( 51 | mean, std, reinterpreted_batch_ndims=self.reinterpreted_batch_ndims) 52 | 53 | 54 | class BernoulliGenerator(ModuleToDistributionGenerator): 55 | def forward(self, *input): 56 | probs = super().forward(*input) 57 | return Bernoulli(probs) 58 | 59 | 60 | class IndependentGenerator(ModuleToDistributionGenerator): 61 | def __init__(self, *args, reinterpreted_batch_ndims=1): 62 | super().__init__(*args) 63 | self.reinterpreted_batch_ndims = reinterpreted_batch_ndims 64 | 65 | def forward(self, *input): 66 | distribution = super().forward(*input) 67 | return Independent( 68 | distribution, 69 | reinterpreted_batch_ndims=self.reinterpreted_batch_ndims, 70 | ) 71 | 72 | 73 | class GaussianMixture(ModuleToDistributionGenerator): 74 | def forward(self, *input): 75 | mixture_means, mixture_stds, weights = super().forward(*input) 76 | return GaussianMixtureDistribution(mixture_means, mixture_stds, weights) 77 | 78 | 79 
| class GaussianMixtureFull(ModuleToDistributionGenerator): 80 | def forward(self, *input): 81 | mixture_means, mixture_stds, weights = super().forward(*input) 82 | return GaussianMixtureFullDistribution(mixture_means, mixture_stds, weights) 83 | 84 | 85 | class TanhGaussian(ModuleToDistributionGenerator): 86 | def forward(self, *input): 87 | mean, log_std = super().forward(*input) 88 | std = log_std.exp() 89 | return TanhNormal(mean, std) 90 | -------------------------------------------------------------------------------- /rlkit/torch/networks/two_headed_mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from rlkit.pythonplusplus import identity 5 | from rlkit.torch import pytorch_util as ptu 6 | from rlkit.torch.core import PyTorchModule 7 | from rlkit.torch.networks import LayerNorm 8 | 9 | 10 | class TwoHeadMlp(PyTorchModule): 11 | def __init__( 12 | self, 13 | hidden_sizes, 14 | first_head_size, 15 | second_head_size, 16 | input_size, 17 | init_w=3e-3, 18 | hidden_activation=F.relu, 19 | output_activation=identity, 20 | hidden_init=ptu.fanin_init, 21 | b_init_value=0., 22 | layer_norm=False, 23 | layer_norm_kwargs=None, 24 | ): 25 | super().__init__() 26 | 27 | if layer_norm_kwargs is None: 28 | layer_norm_kwargs = dict() 29 | 30 | self.input_size = input_size 31 | self.first_head_size = first_head_size 32 | self.second_head_size = second_head_size 33 | self.hidden_activation = hidden_activation 34 | self.output_activation = output_activation 35 | self.layer_norm = layer_norm 36 | self.fcs = [] 37 | self.layer_norms = [] 38 | in_size = input_size 39 | 40 | for i, next_size in enumerate(hidden_sizes): 41 | fc = nn.Linear(in_size, next_size) 42 | in_size = next_size 43 | hidden_init(fc.weight) 44 | fc.bias.data.fill_(b_init_value) 45 | self.__setattr__("fc{}".format(i), fc) 46 | self.fcs.append(fc) 47 | 48 | if self.layer_norm: 49 | ln = LayerNorm(next_size) 50 | self.__setattr__("layer_norm{}".format(i), ln) 51 | self.layer_norms.append(ln) 52 | 53 | self.first_head = nn.Linear(in_size, self.first_head_size) 54 | self.first_head.weight.data.uniform_(-init_w, init_w) 55 | 56 | self.second_head = nn.Linear(in_size, self.second_head_size) 57 | self.second_head.weight.data.uniform_(-init_w, init_w) 58 | 59 | def forward(self, input, return_preactivations=False): 60 | h = input 61 | for i, fc in enumerate(self.fcs): 62 | h = fc(h) 63 | if self.layer_norm and i < len(self.fcs) - 1: 64 | h = self.layer_norms[i](h) 65 | h = self.hidden_activation(h) 66 | preactivation = self.first_head(h) 67 | first_output = self.output_activation(preactivation) 68 | preactivation = self.second_head(h) 69 | second_output = self.output_activation(preactivation) 70 | 71 | return first_output, second_output 72 | -------------------------------------------------------------------------------- /rlkit/torch/sac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/sac/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/sac/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.torch.sac.policies.base import ( 2 | TorchStochasticPolicy, 3 | PolicyFromDistributionGenerator, 4 | MakeDeterministic, 5 | ) 6 | from 
rlkit.torch.sac.policies.gaussian_policy import ( 7 | TanhGaussianPolicyAdapter, 8 | TanhGaussianPolicy, 9 | GaussianPolicy, 10 | GaussianCNNPolicy, 11 | GaussianMixturePolicy, 12 | BinnedGMMPolicy, 13 | TanhGaussianObsProcessorPolicy, 14 | TanhCNNGaussianPolicy, 15 | ) 16 | from rlkit.torch.sac.policies.lvm_policy import LVMPolicy 17 | from rlkit.torch.sac.policies.policy_from_q import PolicyFromQ 18 | 19 | 20 | __all__ = [ 21 | 'TorchStochasticPolicy', 22 | 'PolicyFromDistributionGenerator', 23 | 'MakeDeterministic', 24 | 'TanhGaussianPolicyAdapter', 25 | 'TanhGaussianPolicy', 26 | 'GaussianPolicy', 27 | 'GaussianCNNPolicy', 28 | 'GaussianMixturePolicy', 29 | 'BinnedGMMPolicy', 30 | 'TanhGaussianObsProcessorPolicy', 31 | 'TanhCNNGaussianPolicy', 32 | 'LVMPolicy', 33 | 'PolicyFromQ', 34 | ] 35 | -------------------------------------------------------------------------------- /rlkit/torch/sac/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.policies.base import ExplorationPolicy 11 | from rlkit.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from rlkit.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from rlkit.torch.networks import Mlp, CNN 16 | from rlkit.torch.networks.basic import MultiInputSequential 17 | from rlkit.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | 21 | 22 | class TorchStochasticPolicy( 23 | DistributionGenerator, 24 | ExplorationPolicy, metaclass=abc.ABCMeta 25 | ): 26 | def get_action(self, obs_np, ): 27 | actions = self.get_actions(obs_np[None]) 28 | return actions[0, :], {} 29 | 30 | def get_actions(self, obs_np, ): 31 | dist = self._get_dist_from_np(obs_np) 32 | actions = dist.sample() 33 | return elem_or_tuple_to_numpy(actions) 34 | 35 | def _get_dist_from_np(self, *args, **kwargs): 36 | torch_args = tuple(torch_ify(x) for x in args) 37 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 38 | dist = self(*torch_args, **torch_kwargs) 39 | return dist 40 | 41 | 42 | class PolicyFromDistributionGenerator( 43 | MultiInputSequential, 44 | TorchStochasticPolicy, 45 | ): 46 | """ 47 | Usage: 48 | ``` 49 | distribution_generator = FancyGenerativeModel() 50 | policy = PolicyFromBatchDistributionModule(distribution_generator) 51 | ``` 52 | """ 53 | pass 54 | 55 | 56 | class MakeDeterministic(TorchStochasticPolicy): 57 | def __init__( 58 | self, 59 | action_distribution_generator: DistributionGenerator, 60 | ): 61 | super().__init__() 62 | self._action_distribution_generator = action_distribution_generator 63 | 64 | def forward(self, *args, **kwargs): 65 | dist = self._action_distribution_generator.forward(*args, **kwargs) 66 | return Delta(dist.mle_estimate()) 67 | -------------------------------------------------------------------------------- /rlkit/torch/sac/policies/lvm_policy.py: -------------------------------------------------------------------------------- 1 | from rlkit.torch.networks.stochastic.distribution_generator import ( 2 | DistributionGenerator 3 | ) 4 | from rlkit.torch.sac.policies.base import ( 5 | TorchStochasticPolicy, 6 | PolicyFromDistributionGenerator, 7 | MakeDeterministic, 8 | ) 9 | 10 | from rlkit.torch.lvm.latent_variable_model import 
LatentVariableModel 11 | 12 | 13 | class LVMPolicy(LatentVariableModel, TorchStochasticPolicy): 14 | """Expects encoder p(z|s) and decoder p(u|s,z)""" 15 | 16 | def forward(self, obs): 17 | z_dist = self.encoder(obs) 18 | z = z_dist.sample() 19 | return self.decoder(obs, z) 20 | -------------------------------------------------------------------------------- /rlkit/torch/sac/policies/policy_from_q.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.policies.base import ExplorationPolicy 11 | from rlkit.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from rlkit.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from rlkit.torch.networks import Mlp, CNN 16 | from rlkit.torch.networks.basic import MultiInputSequential 17 | from rlkit.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | from rlkit.torch.sac.policies.base import ( 21 | TorchStochasticPolicy, 22 | PolicyFromDistributionGenerator, 23 | MakeDeterministic, 24 | ) 25 | 26 | 27 | class PolicyFromQ(TorchStochasticPolicy): 28 | def __init__( 29 | self, 30 | qf, 31 | policy, 32 | num_samples=10, 33 | **kwargs 34 | ): 35 | super().__init__() 36 | self.qf = qf 37 | self.policy = policy 38 | self.num_samples = num_samples 39 | 40 | def forward(self, obs): 41 | with torch.no_grad(): 42 | state = obs.repeat(self.num_samples, 1) 43 | action = self.policy(state).sample() 44 | q_values = self.qf(state, action) 45 | ind = q_values.max(0)[1] 46 | return Delta(action[ind]) 47 | -------------------------------------------------------------------------------- /rlkit/torch/smac/diagnostics.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs import ( 2 | AntDirEnv, 3 | HalfCheetahVelEnv, 4 | ) 5 | 6 | 7 | def get_env_info_sizes(env): 8 | info_sizes = {} 9 | if isinstance(env.wrapped_env, AntDirEnv): 10 | info_sizes = dict( 11 | reward_forward=1, 12 | reward_ctrl=1, 13 | reward_contact=1, 14 | reward_survive=1, 15 | torso_velocity=3, 16 | torso_xy=2, 17 | ) 18 | if isinstance(env.wrapped_env, HalfCheetahVelEnv): 19 | info_sizes = dict( 20 | reward_forward=1, 21 | reward_ctrl=1, 22 | goal_vel=1, 23 | forward_vel=1, 24 | xposbefore=1, 25 | ) 26 | 27 | return info_sizes 28 | -------------------------------------------------------------------------------- /rlkit/torch/smac/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import rlkit.torch.pytorch_util as ptu 4 | from rlkit.torch.networks import ConcatMlp 5 | 6 | 7 | class MlpEncoder(ConcatMlp): 8 | ''' 9 | encode context via MLP 10 | ''' 11 | def __init__(self, *args, use_ground_truth_context=False, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.use_ground_truth_context = use_ground_truth_context 14 | 15 | def forward(self, context): 16 | if self.use_ground_truth_context: 17 | return context 18 | else: 19 | return super().forward(context) 20 | 21 | def reset(self, num_tasks=1): 22 | pass 23 | 24 | 25 | class MlpDecoder(ConcatMlp): 26 | ''' 27 | decoder context via MLP 28 | ''' 29 | pass 30 | 31 | 32 | class DummyMlpEncoder(MlpEncoder): 33 | def forward(self, *args, 
**kwargs): 34 | output = super().forward(*args, **kwargs) 35 | return 0 * output 36 | # TODO: check if this caused issues 37 | 38 | 39 | class RecurrentEncoder(ConcatMlp): 40 | ''' 41 | encode context via recurrent network 42 | ''' 43 | 44 | def __init__(self, 45 | *args, 46 | **kwargs 47 | ): 48 | self.save_init_params(locals()) 49 | super().__init__(*args, **kwargs) 50 | self.hidden_dim = self.hidden_sizes[-1] 51 | self.register_buffer('hidden', torch.zeros(1, 1, self.hidden_dim)) 52 | 53 | # input should be (task, seq, feat) and hidden should be (task, 1, feat) 54 | 55 | self.lstm = nn.LSTM(self.hidden_dim, self.hidden_dim, num_layers=1, batch_first=True) 56 | 57 | def forward(self, in_, return_preactivations=False): 58 | # expects inputs of dimension (task, seq, feat) 59 | task, seq, feat = in_.size() 60 | out = in_.view(task * seq, feat) 61 | 62 | # embed with MLP 63 | for i, fc in enumerate(self.fcs): 64 | out = fc(out) 65 | out = self.hidden_activation(out) 66 | 67 | out = out.view(task, seq, -1) 68 | out, (hn, cn) = self.lstm(out, (self.hidden, torch.zeros(self.hidden.size()).to(ptu.device))) 69 | self.hidden = hn 70 | # take the last hidden state to predict z 71 | out = out[:, -1, :] 72 | 73 | # output layer 74 | preactivation = self.last_fc(out) 75 | output = self.output_activation(preactivation) 76 | if return_preactivations: 77 | return output, preactivation 78 | else: 79 | return output 80 | 81 | def reset(self, num_tasks=1): 82 | self.hidden = self.hidden.new_full((1, num_tasks, self.hidden_dim), 0) 83 | 84 | -------------------------------------------------------------------------------- /rlkit/torch/td3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/rlkit/ac45a9db24b89d97369bef302487273bcc3e3d84/rlkit/torch/td3/__init__.py -------------------------------------------------------------------------------- /rlkit/torch/torch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | from typing import Iterable 5 | from torch import nn as nn 6 | 7 | from rlkit.core.batch_rl_algorithm import BatchRLAlgorithm 8 | from rlkit.core.online_rl_algorithm import OnlineRLAlgorithm 9 | from rlkit.core.trainer import Trainer 10 | from rlkit.torch.core import np_to_pytorch_batch 11 | 12 | 13 | class TorchOnlineRLAlgorithm(OnlineRLAlgorithm): 14 | def to(self, device): 15 | for net in self.trainer.networks: 16 | net.to(device) 17 | 18 | def training_mode(self, mode): 19 | for net in self.trainer.networks: 20 | net.train(mode) 21 | 22 | 23 | class TorchBatchRLAlgorithm(BatchRLAlgorithm): 24 | def to(self, device): 25 | for net in self.trainer.networks: 26 | net.to(device) 27 | 28 | def training_mode(self, mode): 29 | for net in self.trainer.networks: 30 | net.train(mode) 31 | 32 | 33 | class TorchTrainer(Trainer, metaclass=abc.ABCMeta): 34 | def __init__(self): 35 | self._num_train_steps = 0 36 | 37 | def train(self, np_batch): 38 | self._num_train_steps += 1 39 | batch = np_to_pytorch_batch(np_batch) 40 | self.train_from_torch(batch) 41 | 42 | def get_diagnostics(self): 43 | return OrderedDict([ 44 | ('num train calls', self._num_train_steps), 45 | ]) 46 | 47 | @abc.abstractmethod 48 | def train_from_torch(self, batch): 49 | pass 50 | 51 | @property 52 | @abc.abstractmethod 53 | def networks(self) -> Iterable[nn.Module]: 54 | pass 55 | 
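A minimal sketch, not a file in this repository, of how a concrete trainer plugs into the TorchTrainer interface defined in rlkit/torch/torch_rl_algorithm.py above: a subclass only needs to implement train_from_torch and expose its modules through the networks property so that TorchBatchRLAlgorithm.to() and training_mode() can reach them. The batch keys and the one-step Q-learning update below follow the DQN-style trainer shown earlier in this listing, but the class, its name, and its hyperparameters are illustrative assumptions only.

import torch
from torch import nn, optim

from rlkit.torch.torch_rl_algorithm import TorchTrainer


class ToyQTrainer(TorchTrainer):
    """Illustrative trainer that satisfies the TorchTrainer interface."""

    def __init__(self, qf, target_qf, discount=0.99, learning_rate=1e-3):
        super().__init__()
        self.qf = qf
        self.target_qf = target_qf
        self.discount = discount
        self.qf_criterion = nn.MSELoss()
        self.qf_optimizer = optim.Adam(self.qf.parameters(), lr=learning_rate)

    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        # One-step TD target computed from the (frozen) target network.
        with torch.no_grad():
            target_q = self.target_qf(next_obs).max(dim=1, keepdim=True)[0]
            y_target = rewards + (1. - terminals) * self.discount * target_q

        # Q-value of the action actually taken (actions assumed one-hot).
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

    @property
    def networks(self):
        # Exposed so TorchBatchRLAlgorithm.to() / training_mode() can reach them.
        return [self.qf, self.target_qf]

Because TorchTrainer.train converts numpy batches via np_to_pytorch_batch and increments the train-step counter, a replay-buffer batch can be passed directly as trainer.train(np_batch).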
-------------------------------------------------------------------------------- /rlkit/torch/vae/vae_schedules.py: -------------------------------------------------------------------------------- 1 | def always_train(epoch): 2 | return True, 300 3 | 4 | 5 | def custom_schedule(epoch): 6 | if epoch < 10: 7 | return True, 1000 8 | elif epoch < 300: 9 | return True, 200 10 | else: 11 | return epoch % 3 == 0, 200 12 | 13 | 14 | def custom_schedule_2(epoch): 15 | if epoch < 10: 16 | return True, 1000 17 | elif epoch < 100: 18 | return True, 200 19 | else: 20 | return epoch % 2 == 0, 200 21 | 22 | 23 | def every_other(epoch): 24 | return epoch % 2 == 0, 400 25 | 26 | 27 | def every_three(epoch): 28 | return epoch % 3 == 0, 600 29 | 30 | 31 | def every_three_a_lot(epoch): 32 | return epoch % 3 == 0, 1200 33 | 34 | 35 | def every_six(epoch): 36 | return epoch % 6 == 0, 1200 37 | 38 | 39 | def every_six_less(epoch): 40 | return epoch % 6 == 0, 600 41 | 42 | 43 | def every_six_much_less(epoch): 44 | return epoch % 6 == 0, 300 45 | 46 | 47 | def every_ten(epoch): 48 | return epoch % 10 == 0 or epoch == 5, 1000 49 | 50 | 51 | def every_twenty(epoch): 52 | return epoch % 10 == 0 or epoch == 5 or epoch == 10, 1000 53 | 54 | 55 | def never_train(epoch): 56 | return False, 0 57 | -------------------------------------------------------------------------------- /rlkit/util/ml_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility functions for machine learning. 3 | """ 4 | import abc 5 | import math 6 | import numpy as np 7 | 8 | 9 | class ScalarSchedule(object, metaclass=abc.ABCMeta): 10 | @abc.abstractmethod 11 | def get_value(self, t): 12 | pass 13 | 14 | 15 | class ConstantSchedule(ScalarSchedule): 16 | def __init__(self, value): 17 | self._value = value 18 | 19 | def get_value(self, t): 20 | return self._value 21 | 22 | 23 | class LinearSchedule(ScalarSchedule): 24 | """ 25 | Linearly interpolate and then stop at a final value. 
26 | """ 27 | def __init__( 28 | self, 29 | init_value, 30 | final_value, 31 | ramp_duration, 32 | ): 33 | self._init_value = init_value 34 | self._final_value = final_value 35 | self._ramp_duration = ramp_duration 36 | 37 | def get_value(self, t): 38 | return ( 39 | self._init_value 40 | + (self._final_value - self._init_value) 41 | * min(1.0, t * 1.0 / self._ramp_duration) 42 | ) 43 | 44 | 45 | class IntLinearSchedule(LinearSchedule): 46 | """ 47 | Same as RampUpSchedule but round output to an int 48 | """ 49 | def get_value(self, t): 50 | return int(super().get_value(t)) 51 | 52 | 53 | class PiecewiseLinearSchedule(ScalarSchedule): 54 | """ 55 | Given a list of (x, t) value-time pairs, return value x at time t, 56 | and linearly interpolate between the two 57 | """ 58 | def __init__( 59 | self, 60 | x_values, 61 | y_values, 62 | ): 63 | self._x_values = x_values 64 | self._y_values = y_values 65 | 66 | def get_value(self, t): 67 | return np.interp(t, self._x_values, self._y_values) 68 | 69 | 70 | class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule): 71 | def get_value(self, t): 72 | return int(super().get_value(t)) 73 | 74 | 75 | def none_to_infty(bounds): 76 | if bounds is None: 77 | bounds = -math.inf, math.inf 78 | lb, ub = bounds 79 | if lb is None: 80 | lb = -math.inf 81 | if ub is None: 82 | ub = math.inf 83 | return lb, ub 84 | -------------------------------------------------------------------------------- /rlkit/util/wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | LOG_DIR = os.getcwd() 3 | 4 | 5 | class Wrapper(object): 6 | """ 7 | Mixin for deferring attributes to a wrapped, inner object. 8 | """ 9 | 10 | def __init__(self, inner): 11 | self.inner = inner 12 | 13 | def __getattr__(self, attr): 14 | """ 15 | Dispatch attributes by their status as magic, members, or missing. 16 | - magic is handled by the standard getattr 17 | - existing attributes are returned 18 | - missing attributes are deferred to the inner object. 19 | """ 20 | # don't make magic any more magical 21 | is_magic = attr.startswith('__') and attr.endswith('__') 22 | if is_magic: 23 | return super().__getattr__(attr) 24 | try: 25 | # try to return the attribute... 26 | return self.__dict__[attr] 27 | except: 28 | # ...and defer to the inner dataset if it's not here 29 | return getattr(self.inner, attr) 30 | 31 | 32 | class SimpleWrapper(object): 33 | """ 34 | Mixin for deferring attributes to a wrapped, inner object. 
35 | """ 36 | 37 | def __init__(self, inner): 38 | self._inner = inner 39 | 40 | def __getattr__(self, attr): 41 | if attr == '_inner': 42 | raise AttributeError() 43 | return getattr(self._inner, attr) 44 | -------------------------------------------------------------------------------- /scripts/run_experiment_from_doodad.py: -------------------------------------------------------------------------------- 1 | import doodad as dd 2 | import torch.multiprocessing as mp 3 | 4 | from rlkit.launchers.launcher_util import run_experiment_here 5 | 6 | if __name__ == "__main__": 7 | import matplotlib 8 | matplotlib.use('agg') 9 | 10 | mp.set_start_method('forkserver') 11 | args_dict = dd.get_args() 12 | method_call = args_dict['method_call'] 13 | run_experiment_kwargs = args_dict['run_experiment_kwargs'] 14 | output_dir = args_dict['output_dir'] 15 | run_mode = args_dict.get('mode', None) 16 | if run_mode and run_mode in ['slurm_singularity', 'sss']: 17 | import os 18 | run_experiment_kwargs['variant']['slurm-job-id'] = os.environ.get( 19 | 'SLURM_JOB_ID', None 20 | ) 21 | if run_mode and (run_mode == 'ec2' or run_mode == 'gcp'): 22 | if run_mode == 'ec2': 23 | try: 24 | import urllib.request 25 | instance_id = urllib.request.urlopen( 26 | 'http://169.254.169.254/latest/meta-data/instance-id' 27 | ).read().decode() 28 | run_experiment_kwargs['variant']['EC2_instance_id'] = instance_id 29 | except Exception as e: 30 | print("Could not get AWS instance ID. Error was...") 31 | print(e) 32 | if run_mode == 'gcp': 33 | try: 34 | import urllib.request 35 | request = urllib.request.Request( 36 | "http://metadata/computeMetadata/v1/instance/name", 37 | ) 38 | # See this URL for why we need this header: 39 | # https://cloud.google.com/compute/docs/storing-retrieving-metadata 40 | request.add_header("Metadata-Flavor", "Google") 41 | instance_name = urllib.request.urlopen(request).read().decode() 42 | run_experiment_kwargs['variant']['GCP_instance_name'] = ( 43 | instance_name 44 | ) 45 | except Exception as e: 46 | print("Could not get GCP instance name. 
Error was...") 47 | print(e) 48 | # Do this in case base_log_dir was already set 49 | run_experiment_kwargs['base_log_dir'] = output_dir 50 | run_experiment_here( 51 | method_call, 52 | include_exp_prefix_sub_dir=False, 53 | **run_experiment_kwargs 54 | ) 55 | else: 56 | run_experiment_here( 57 | method_call, 58 | log_dir=output_dir, 59 | **run_experiment_kwargs 60 | ) 61 | -------------------------------------------------------------------------------- /scripts/run_goal_conditioned_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from rlkit.core import logger 5 | from rlkit.samplers.rollout_functions import multitask_rollout 6 | from rlkit.torch import pytorch_util as ptu 7 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 8 | 9 | 10 | def simulate_policy(args): 11 | data = torch.load(args.file) 12 | policy = data['evaluation/policy'] 13 | env = data['evaluation/env'] 14 | print("Policy and environment loaded") 15 | if args.gpu: 16 | ptu.set_gpu_mode(True) 17 | policy.to(ptu.device) 18 | if isinstance(env, VAEWrappedEnv) and hasattr(env, 'mode'): 19 | env.mode(args.mode) 20 | if args.enable_render or hasattr(env, 'enable_render'): 21 | # some environments need to be reconfigured for visualization 22 | env.enable_render() 23 | paths = [] 24 | while True: 25 | paths.append(multitask_rollout( 26 | env, 27 | policy, 28 | max_path_length=args.H, 29 | render=not args.hide, 30 | observation_key='observation', 31 | desired_goal_key='desired_goal', 32 | )) 33 | if hasattr(env, "log_diagnostics"): 34 | env.log_diagnostics(paths) 35 | if hasattr(env, "get_diagnostics"): 36 | for k, v in env.get_diagnostics(paths).items(): 37 | logger.record_tabular(k, v) 38 | logger.dump_tabular() 39 | 40 | 41 | if __name__ == "__main__": 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('file', type=str, 45 | help='path to the snapshot file') 46 | parser.add_argument('--H', type=int, default=300, 47 | help='Max length of rollout') 48 | parser.add_argument('--speedup', type=float, default=10, 49 | help='Speedup') 50 | parser.add_argument('--mode', default='video_env', type=str, 51 | help='env mode') 52 | parser.add_argument('--gpu', action='store_true') 53 | parser.add_argument('--enable_render', action='store_true') 54 | parser.add_argument('--hide', action='store_true') 55 | args = parser.parse_args() 56 | 57 | simulate_policy(args) 58 | -------------------------------------------------------------------------------- /scripts/run_policy.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.rollout_functions import rollout 2 | from rlkit.torch.pytorch_util import set_gpu_mode 3 | import argparse 4 | import torch 5 | import uuid 6 | from rlkit.core import logger 7 | 8 | filename = str(uuid.uuid4()) 9 | 10 | 11 | def simulate_policy(args): 12 | data = torch.load(args.file) 13 | policy = data['evaluation/policy'] 14 | env = data['evaluation/env'] 15 | print("Policy loaded") 16 | if args.gpu: 17 | set_gpu_mode(True) 18 | policy.cuda() 19 | while True: 20 | path = rollout( 21 | env, 22 | policy, 23 | max_path_length=args.H, 24 | render=True, 25 | ) 26 | if hasattr(env, "log_diagnostics"): 27 | env.log_diagnostics([path]) 28 | logger.dump_tabular() 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('file', type=str, 34 | help='path to the snapshot file') 35 | parser.add_argument('--H', type=int, 
default=300, 36 | help='Max length of rollout') 37 | parser.add_argument('--gpu', action='store_true') 38 | args = parser.parse_args() 39 | 40 | simulate_policy(args) 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='rlkit', 6 | version='0.2.1dev', 7 | packages=find_packages(), 8 | license='MIT License', 9 | long_description=open('README.md').read(), 10 | ) -------------------------------------------------------------------------------- /tests/regression/iql/halfcheetah_offline_progress.csv: -------------------------------------------------------------------------------- 1 | epoch,eval/num paths total,eval/num steps total,expl/Actions Max,expl/Actions Mean,expl/Actions Min,expl/Actions Std,expl/Average Returns,expl/Num Paths,expl/Returns Max,expl/Returns Mean,expl/Returns Min,expl/Returns Std,expl/Rewards Max,expl/Rewards Mean,expl/Rewards Min,expl/Rewards Std,expl/num paths total,expl/num steps total,expl/path length Max,expl/path length Mean,expl/path length Min,expl/path length Std,replay_buffer/size,time/epoch_time (s),time/evaluation sampling (s),time/exploration sampling (s),time/global_time (s),time/replay buffer data storing (s),time/saving (s),time/training (s),trainer/Advantage Score Max,trainer/Advantage Score Mean,trainer/Advantage Score Min,trainer/Advantage Score Std,trainer/Advantage Weights Max,trainer/Advantage Weights Mean,trainer/Advantage Weights Min,trainer/Advantage Weights Std,trainer/Policy Loss,trainer/Q Targets Max,trainer/Q Targets Mean,trainer/Q Targets Min,trainer/Q Targets Std,trainer/Q1 Predictions Max,trainer/Q1 Predictions Mean,trainer/Q1 Predictions Min,trainer/Q1 Predictions Std,trainer/Q2 Predictions Max,trainer/Q2 Predictions Mean,trainer/Q2 Predictions Min,trainer/Q2 Predictions Std,trainer/QF1 Loss,trainer/QF2 Loss,trainer/V1 Predictions Max,trainer/V1 Predictions Mean,trainer/V1 Predictions Min,trainer/V1 Predictions Std,trainer/VF Loss,trainer/num train calls,trainer/policy/mean Max,trainer/policy/mean Mean,trainer/policy/mean Min,trainer/policy/mean Std,trainer/policy/std Max,trainer/policy/std Mean,trainer/policy/std Min,trainer/policy/std Std,trainer/replay_buffer_len,trainer/rewards Max,trainer/rewards Mean,trainer/rewards Min,trainer/rewards Std,trainer/terminals Max,trainer/terminals Mean,trainer/terminals Min,trainer/terminals Std 2 | -2,0,0,0.13717124,0.03626487,-0.049319427,0.045948837,-0.19655459402248504,1,-0.19655459402248504,-0.19655459402248504,-0.19655459402248504,0.0,-0.09325349516403507,-0.09827729701124252,-0.10330109885844999,0.005023801847207458,1,2,2,2.0,2,0.0,998999,0.7398281097412109,0.022743940353393555,0.0016701221466064453,20.234392166137695,1.9073486328125e-06,0.008835077285766602,0.7058250904083252,-0.0036168185,-0.008904904,-0.01419299,0.005288086,0.9892082,0.9737615,0.9583148,0.0154467225,857.7699,1.0095485,0.91235673,0.8151649,0.09719181,-3.695485e-05,-0.0023042315,-0.004571508,0.0022672766,0.0,0.0,0.0,0.0,0.8464968,0.84184104,0.01419299,0.007483647,0.0007743044,0.006709343,3.2178352e-05,100,0.004211932,0.0006116896,-0.0027325656,0.0019366256,0.049787067,0.049787063,0.049787067,3.7252903e-09,998999,1.0072821,0.90650094,0.80571973,0.1007812,0.0,0.0,0.0,0.0 3 | 
-1,0,0,0.049955986,-0.051808134,-0.24854904,0.08398724,0.4338117202694962,1,0.4338117202694962,0.4338117202694962,0.4338117202694962,0.0,0.24300404877931653,0.2169058601347481,0.19080767149017966,0.026098188644568435,2,4,2,2.0,2,0.0,998999,0.6178200244903564,0.0024394989013671875,0.002116680145263672,20.855294704437256,1.6689300537109375e-06,0.008217334747314453,0.6044015884399414,-0.004432007,-0.0058829254,-0.0073338435,0.0014509181,0.98679197,0.98251534,0.97823876,0.0042766035,440.564,1.0878772,0.8795528,0.6712284,0.20832437,0.07119025,0.055561554,0.039932854,0.0156287,0.06366714,0.046736993,0.029806845,0.016930148,0.7160932,0.7302138,0.013697105,0.0122417575,0.0107864095,0.001455348,1.1014193e-05,200,-0.17424563,-0.48487774,-0.8831331,0.2023349,0.051678017,0.05148844,0.051258426,0.00012519717,998999,1.0813023,0.8714981,0.661694,0.20980415,0.0,0.0,0.0,0.0 4 | -------------------------------------------------------------------------------- /tests/regression/iql/test_iql_offline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from examples.iql import mujoco_finetune as iql 5 | 6 | from rlkit.core import logger 7 | from rlkit.testing import csv_util 8 | 9 | def test_iql(): 10 | logger.reset() 11 | 12 | # make tests small by mutating variant 13 | iql.variant["algo_kwargs"]["start_epoch"] = -2 14 | iql.variant["algo_kwargs"]["num_epochs"] = 0 15 | iql.variant["algo_kwargs"]["batch_size"] = 2 16 | iql.variant["algo_kwargs"]["num_eval_steps_per_epoch"] = 2 17 | iql.variant["algo_kwargs"]["num_expl_steps_per_train_loop"] = 2 18 | iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 100 19 | iql.variant["algo_kwargs"]["min_num_steps_before_training"] = 2 20 | iql.variant["qf_kwargs"] = dict(hidden_sizes=[2, 2]) 21 | 22 | iql.variant["seed"] = 25580 23 | 24 | iql.main() 25 | 26 | reference_csv = "tests/regression/iql/halfcheetah_offline_progress.csv" 27 | output_csv = os.path.join(logger.get_snapshot_dir(), "progress.csv") 28 | print("comparing reference %s against output %s" % (reference_csv, output_csv)) 29 | output = csv_util.get_exp(output_csv) 30 | reference = csv_util.get_exp(reference_csv) 31 | keys = ["epoch", "trainer/Q1 Predictions Mean", ] 32 | csv_util.check_equal(reference, output, keys) 33 | 34 | if __name__ == "__main__": 35 | test_iql() 36 | -------------------------------------------------------------------------------- /tests/regression/iql/test_iql_online.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from examples.iql import mujoco_finetune as iql 5 | 6 | from rlkit.core import logger 7 | from rlkit.testing import csv_util 8 | 9 | def test_iql(): 10 | logger.reset() 11 | 12 | # make tests small by mutating variant 13 | iql.variant["algo_kwargs"]["start_epoch"] = -2 14 | iql.variant["algo_kwargs"]["num_epochs"] = 2 15 | iql.variant["algo_kwargs"]["batch_size"] = 2 16 | iql.variant["algo_kwargs"]["num_eval_steps_per_epoch"] = 2 17 | iql.variant["algo_kwargs"]["num_expl_steps_per_train_loop"] = 2 18 | iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 100 19 | iql.variant["algo_kwargs"]["min_num_steps_before_training"] = 2 20 | iql.variant["qf_kwargs"] = dict(hidden_sizes=[2, 2]) 21 | 22 | iql.variant["seed"] = 25580 23 | 24 | iql.main() 25 | 26 | reference_csv = "tests/regression/iql/halfcheetah_online_progress.csv" 27 | output_csv = os.path.join(logger.get_snapshot_dir(), "progress.csv") 28 | print("comparing reference 
%s against output %s" % (reference_csv, output_csv)) 29 | output = csv_util.get_exp(output_csv) 30 | reference = csv_util.get_exp(reference_csv) 31 | keys = ["epoch", "expl/num steps total", "expl/Average Returns", "trainer/Q1 Predictions Mean", ] 32 | csv_util.check_equal(reference, output, keys) 33 | 34 | if __name__ == "__main__": 35 | test_iql() 36 | -------------------------------------------------------------------------------- /tests/regression/simplegym/test_sac.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from examples.simplegym import sac 4 | 5 | from rlkit.core import logger 6 | from rlkit.testing import csv_util 7 | 8 | def test_sac_online(): 9 | logger.reset() 10 | 11 | # make tests small by mutating variant 12 | sac.variant["algorithm_kwargs"]["num_epochs"] = 2 13 | sac.variant["algorithm_kwargs"]["batch_size"] = 2 14 | sac.variant["algorithm_kwargs"]["num_trains_per_train_loop"] = 100 15 | sac.variant["qf_kwargs"] = dict(hidden_sizes=[2, 2]) 16 | sac.variant["policy_kwargs"] = dict(hidden_sizes=[2, 2]) 17 | sac.variant["seed"] = 25580 18 | 19 | sac.main() 20 | 21 | reference_csv = "tests/regression/simplegym/test_sac_progress.csv" 22 | output_csv = os.path.join(logger.get_snapshot_dir(), "progress.csv") 23 | print("comparing reference %s against output %s" % (reference_csv, output_csv)) 24 | output = csv_util.get_exp(output_csv) 25 | reference = csv_util.get_exp(reference_csv) 26 | keys = ["epoch", "expl/num steps total", "eval/Average Returns", "trainer/Q1 Predictions Mean", ] 27 | csv_util.check_equal(reference, output, keys) 28 | 29 | if __name__ == "__main__": 30 | test_sac_online() 31 | --------------------------------------------------------------------------------
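A small usage sketch, with made-up values, of the schedule helpers defined in rlkit/util/ml_util.py earlier in this listing: LinearSchedule ramps linearly from init_value to final_value over ramp_duration steps and then holds the final value, while PiecewiseLinearSchedule interpolates between arbitrary (x, y) breakpoints via np.interp.

from rlkit.util.ml_util import LinearSchedule, PiecewiseLinearSchedule

# Ramp a coefficient from 0 to 1 over 100 steps, then hold it at 1.
beta_schedule = LinearSchedule(init_value=0.0, final_value=1.0, ramp_duration=100)
assert beta_schedule.get_value(0) == 0.0
assert beta_schedule.get_value(50) == 0.5
assert beta_schedule.get_value(1000) == 1.0  # clamped once the ramp is over

# Piecewise-linear interpolation between breakpoints.
lr_schedule = PiecewiseLinearSchedule(
    x_values=[0, 10, 100],
    y_values=[1e-3, 1e-3, 1e-4],
)
print(lr_schedule.get_value(55))  # 55 is halfway between 10 and 100, so this is 5.5e-4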