├── rlkit ├── __init__.py ├── torch │ ├── __init__.py │ ├── ddpg │ │ └── __init__.py │ ├── dqn │ │ ├── __init__.py │ │ ├── double_dqn.py │ │ └── dqn.py │ ├── her │ │ ├── __init__.py │ │ └── her.py │ ├── sac │ │ ├── __init__.py │ │ └── policies.py │ ├── td3 │ │ └── __init__.py │ ├── data_management │ │ ├── __init__.py │ │ └── normalizer.py │ ├── vae │ │ ├── vae_schedules.py │ │ └── vae_base.py │ ├── modules.py │ ├── torch_rl_algorithm.py │ ├── core.py │ ├── data.py │ ├── distributions.py │ ├── pytorch_util.py │ ├── networks.py │ └── skewfit │ │ └── video_gen.py ├── policies │ ├── __init__.py │ ├── simple.py │ ├── base.py │ └── argmax.py ├── samplers │ ├── __init__.py │ ├── data_collector │ │ ├── __init__.py │ │ ├── vae_env.py │ │ ├── base.py │ │ └── path_collector.py │ ├── util.py │ └── rollout_functions.py ├── data_management │ ├── __init__.py │ ├── path_builder.py │ ├── env_replay_buffer.py │ ├── replay_buffer.py │ ├── simple_replay_buffer.py │ ├── normalizer.py │ └── shared_obs_dict_replay_buffer.py ├── exploration_strategies │ ├── __init__.py │ ├── epsilon_greedy.py │ ├── gaussian_strategy.py │ ├── base.py │ ├── gaussian_and_epsilon_strategy.py │ └── ou_strategy.py ├── core │ ├── __init__.py │ ├── trainer.py │ ├── serializable.py │ ├── batch_rl_algorithm.py │ ├── online_rl_algorithm.py │ ├── eval_util.py │ └── rl_algorithm.py ├── launchers │ ├── __init__.py │ └── conf.py ├── envs │ ├── env_utils.py │ ├── mujoco_env.py │ ├── ant.py │ ├── goal_generation │ │ └── pickup_goal_dataset.py │ ├── assets │ │ ├── reacher_7dof.xml │ │ └── low_gear_ratio_ant.xml │ ├── mujoco_image_env.py │ └── wrappers.py └── util │ ├── ml_util.py │ ├── io.py │ └── video.py ├── docs ├── images │ ├── her_dqn.png │ ├── skewfit_door.png │ ├── skewfit_pickup.png │ ├── skewfit_pusher.png │ ├── FetchReach-v1_HER-TD3.png │ ├── her_td3_sawyer_reacher.png │ └── SawyerReachXYZEnv-v0_HER-TD3.png ├── RIG.md ├── SkewFit.md ├── goal_based_envs.md ├── TDMs.md └── HER.md ├── environment ├── docker │ ├── vendor │ │ ├── 10_nvidia.json │ │ └── Xdummy-entrypoint │ └── Dockerfile ├── mac-env.yml ├── linux-gpu-env.yml └── linux-cpu-env.yml ├── .gitignore ├── setup.py ├── LICENSE ├── scripts ├── run_policy.py ├── run_goal_conditioned_policy.py └── run_experiment_from_doodad.py └── examples ├── doodad ├── ec2_example.py └── gcp_example.py ├── dqn_and_double_dqn.py ├── ddpg.py ├── sac.py ├── her ├── her_dqn_gridworld.py ├── her_sac_gym_fetch_reach.py └── her_td3_multiworld_sawyer_reach.py ├── td3.py └── skewfit ├── sawyer_door.py └── sawyer_push.py /rlkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/policies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/dqn/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/her/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/sac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/td3/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/her_dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/her_dqn.png -------------------------------------------------------------------------------- /docs/images/skewfit_door.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/skewfit_door.png -------------------------------------------------------------------------------- /docs/images/skewfit_pickup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/skewfit_pickup.png -------------------------------------------------------------------------------- /docs/images/skewfit_pusher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/skewfit_pusher.png -------------------------------------------------------------------------------- /docs/images/FetchReach-v1_HER-TD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/FetchReach-v1_HER-TD3.png -------------------------------------------------------------------------------- /docs/images/her_td3_sawyer_reacher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/her_td3_sawyer_reacher.png -------------------------------------------------------------------------------- /docs/images/SawyerReachXYZEnv-v0_HER-TD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png -------------------------------------------------------------------------------- /environment/docker/vendor/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : 
"1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */*/mjkey.txt 2 | **/.DS_STORE 3 | **/*.pyc 4 | **/*.swp 5 | rlkit/launchers/config.py 6 | rlkit/launchers/conf_private.py 7 | MANIFEST 8 | *.egg-info 9 | \.idea/ 10 | -------------------------------------------------------------------------------- /rlkit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout rlkit. 3 | """ 4 | from rlkit.core.logging import logger 5 | 6 | __all__ = ['logger'] 7 | 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='rlkit', 6 | version='0.2.1dev', 7 | packages=find_packages(), 8 | license='MIT License', 9 | long_description=open('README.md').read(), 10 | ) -------------------------------------------------------------------------------- /rlkit/core/trainer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Trainer(object, metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def train(self, data): 7 | pass 8 | 9 | def end_epoch(self, epoch): 10 | pass 11 | 12 | def get_snapshot(self): 13 | return {} 14 | 15 | def get_diagnostics(self): 16 | return {} 17 | -------------------------------------------------------------------------------- /rlkit/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains 'launchers', which are self-contained functions that take 3 | one dictionary and run a full experiment. The dictionary configures the 4 | experiment. 5 | 6 | It is important that the functions are completely self-contained (i.e. they 7 | import their own modules) so that they can be serialized. 8 | """ 9 | -------------------------------------------------------------------------------- /rlkit/policies/simple.py: -------------------------------------------------------------------------------- 1 | from rlkit.policies.base import Policy 2 | 3 | 4 | class RandomPolicy(Policy): 5 | """ 6 | Policy that always outputs zero. 7 | """ 8 | 9 | def __init__(self, action_space): 10 | self.action_space = action_space 11 | 12 | def get_action(self, obs): 13 | return self.action_space.sample(), {} 14 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.data_collector.base import ( 2 | DataCollector, 3 | PathCollector, 4 | StepCollector, 5 | ) 6 | from rlkit.samplers.data_collector.path_collector import ( 7 | MdpPathCollector, 8 | GoalConditionedPathCollector, 9 | ) 10 | from rlkit.samplers.data_collector.step_collector import ( 11 | GoalConditionedStepCollector 12 | ) 13 | -------------------------------------------------------------------------------- /rlkit/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(object, metaclass=abc.ABCMeta): 5 | """ 6 | General policy interface. 
7 | """ 8 | @abc.abstractmethod 9 | def get_action(self, observation): 10 | """ 11 | 12 | :param observation: 13 | :return: action, debug_dictionary 14 | """ 15 | pass 16 | 17 | def reset(self): 18 | pass 19 | 20 | 21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 22 | def set_num_steps_total(self, t): 23 | pass 24 | -------------------------------------------------------------------------------- /rlkit/policies/argmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch argmax policy 3 | """ 4 | import numpy as np 5 | from torch import nn 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.policies.base import Policy 9 | 10 | 11 | class ArgmaxDiscretePolicy(nn.Module, Policy): 12 | def __init__(self, qf): 13 | super().__init__() 14 | self.qf = qf 15 | 16 | def get_action(self, obs): 17 | obs = np.expand_dims(obs, axis=0) 18 | obs = ptu.from_numpy(obs).float() 19 | q_values = self.qf(obs).squeeze(0) 20 | q_values_np = ptu.get_numpy(q_values) 21 | return q_values_np.argmax(), {} 22 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from rlkit.exploration_strategies.base import RawExplorationStrategy 4 | 5 | 6 | class EpsilonGreedy(RawExplorationStrategy): 7 | """ 8 | Take a random discrete action with some probability. 9 | """ 10 | def __init__(self, action_space, prob_random_action=0.1): 11 | self.prob_random_action = prob_random_action 12 | self.action_space = action_space 13 | 14 | def get_action_from_raw_action(self, action, **kwargs): 15 | if random.random() <= self.prob_random_action: 16 | return self.action_space.sample() 17 | return action 18 | -------------------------------------------------------------------------------- /docs/RIG.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning with Imagined Goals 2 | Implementation of reinforcement learning with imagined goals (RIG) 3 | To find out 4 | more, see any of the following links: 5 | * arXiv: https://arxiv.org/abs/1807.04742 6 | * Website: https://sites.google.com/site/visualrlwithimaginedgoals/ 7 | * Blog Post: https://bair.berkeley.edu/blog/2018/09/06/rig/ 8 | 9 | To see the original implementation, checkout version `v0.1.2` of this repo. 10 | 11 | In versions `0.2+`, RIG is a special case of [Skew-Fit](SkewFit.md) with the 12 | power set to `0`. 13 | 14 | ## Goal-based environments and `ObsDictRelabelingBuffer` 15 | [See here.](goal_based_envs.md) 16 | -------------------------------------------------------------------------------- /environment/docker/vendor/Xdummy-entrypoint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import argparse 3 | import os 4 | import sys 5 | import subprocess 6 | 7 | parser = argparse.ArgumentParser() 8 | args, extra_args = parser.parse_known_args() 9 | subprocess.Popen(["nohup", "Xdummy"], stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 10 | os.environ['DISPLAY'] = ':0' 11 | if not extra_args: 12 | sys.argv = ['/bin/bash'] 13 | else: 14 | sys.argv = extra_args 15 | # Explicitly flush right before the exec since otherwise things might get 16 | # lost in Python's buffers around stdout/stderr (!). 
17 | sys.stdout.flush() 18 | sys.stderr.flush() 19 | os.execvpe(sys.argv[0], sys.argv, os.environ) 20 | 21 | -------------------------------------------------------------------------------- /environment/mac-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch=0.4.1 21 | - scipy=1.0.1 22 | - pip: 23 | - cloudpickle==0.5.2 24 | - gym[all]==0.10.5 25 | - gitpython==2.1.7 26 | - gtimer==1.0.0b5 27 | - pygame==1.9.2 28 | - ipdb # technically unnecessary 29 | -------------------------------------------------------------------------------- /environment/linux-gpu-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch=0.4.1 21 | - scipy=1.0.1 22 | - patchelf 23 | - pip: 24 | - cloudpickle==0.5.2 25 | - gym[all]==0.10.5 26 | - gitpython==2.1.7 27 | - gtimer==1.0.0b5 28 | - pygame==1.9.2 29 | - ipdb # technically unnecessary 30 | -------------------------------------------------------------------------------- /environment/linux-cpu-env.yml: -------------------------------------------------------------------------------- 1 | name: rlkit 2 | channels: 3 | - kne # for pybox2d 4 | - pytorch 5 | - anaconda # for mkl 6 | dependencies: 7 | - cython 8 | - ipython # technically unnecessary 9 | - joblib=0.9.4 10 | - lockfile 11 | - mako=1.0.6=py35_0 12 | - matplotlib=2.0.2=np111py35_0 13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch 14 | - numba=0.35.0=np111py35_0 15 | - numpy=1.11.3 16 | - path.py=10.3.1=py35_0 17 | - pybox2d=2.3.1post2=py35_0 18 | - python=3.5.2 19 | - python-dateutil=2.6.1=py35_0 20 | - pytorch-cpu=0.4.1 21 | - scipy=1.0.1 22 | - patchelf 23 | - pip: 24 | - cloudpickle==0.5.2 25 | - gym[all]==0.10.5 26 | - gitpython==2.1.7 27 | - gtimer==1.0.0b5 28 | - pygame==1.9.2 29 | - ipdb # technically unnecessary 30 | -------------------------------------------------------------------------------- /rlkit/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.spaces import Box, Discrete, Tuple 4 | 5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 6 | 7 | 8 | def get_asset_full_path(file_name): 9 | return os.path.join(ENV_ASSET_DIR, file_name) 10 | 11 | 12 | def get_dim(space): 13 | if isinstance(space, Box): 14 | return space.low.size 15 | elif isinstance(space, Discrete): 16 | return space.n 17 | elif isinstance(space, Tuple): 18 | return sum(get_dim(subspace) for subspace in space.spaces) 19 | elif hasattr(space, 
'flat_dim'): 20 | return space.flat_dim 21 | else: 22 | raise TypeError("Unknown space: {}".format(space)) 23 | 24 | 25 | def mode(env, mode_type): 26 | try: 27 | getattr(env, mode_type)() 28 | except AttributeError: 29 | pass 30 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/vae_env.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 2 | from rlkit.samplers.data_collector import GoalConditionedPathCollector 3 | 4 | 5 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector): 6 | def __init__( 7 | self, 8 | goal_sampling_mode, 9 | env: VAEWrappedEnv, 10 | policy, 11 | decode_goals=False, 12 | **kwargs 13 | ): 14 | super().__init__(env, policy, **kwargs) 15 | self._goal_sampling_mode = goal_sampling_mode 16 | self._decode_goals = decode_goals 17 | 18 | def collect_new_paths(self, *args, **kwargs): 19 | self._env.goal_sampling_mode = self._goal_sampling_mode 20 | self._env.decode_goals = self._decode_goals 21 | return super().collect_new_paths(*args, **kwargs) -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class DataCollector(object, metaclass=abc.ABCMeta): 5 | def end_epoch(self, epoch): 6 | pass 7 | 8 | def get_diagnostics(self): 9 | return {} 10 | 11 | def get_snapshot(self): 12 | return {} 13 | 14 | @abc.abstractmethod 15 | def get_epoch_paths(self): 16 | pass 17 | 18 | 19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def collect_new_paths( 22 | self, 23 | max_path_length, 24 | num_steps, 25 | discard_incomplete_paths, 26 | ): 27 | pass 28 | 29 | 30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta): 31 | @abc.abstractmethod 32 | def collect_new_steps( 33 | self, 34 | max_path_length, 35 | num_steps, 36 | discard_incomplete_paths, 37 | ): 38 | pass 39 | -------------------------------------------------------------------------------- /rlkit/torch/her/her.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 4 | 5 | 6 | class HERTrainer(TorchTrainer): 7 | def __init__(self, base_trainer: TorchTrainer): 8 | super().__init__() 9 | self._base_trainer = base_trainer 10 | 11 | def train_from_torch(self, data): 12 | obs = data['observations'] 13 | next_obs = data['next_observations'] 14 | goals = data['resampled_goals'] 15 | data['observations'] = torch.cat((obs, goals), dim=1) 16 | data['next_observations'] = torch.cat((next_obs, goals), dim=1) 17 | self._base_trainer.train_from_torch(data) 18 | 19 | def get_diagnostics(self): 20 | return self._base_trainer.get_diagnostics() 21 | 22 | def end_epoch(self, epoch): 23 | self._base_trainer.end_epoch(epoch) 24 | 25 | @property 26 | def networks(self): 27 | return self._base_trainer.networks 28 | 29 | def get_snapshot(self): 30 | return self._base_trainer.get_snapshot() 31 | -------------------------------------------------------------------------------- /docs/SkewFit.md: -------------------------------------------------------------------------------- 1 | # Skew-Fit 2 | Requires [multiworld](https://github.com/vitchyr/multiworld) to be installed: 3 | ``` 4 | pip install 
git+https://github.com/vitchyr/multiworld.git@28ee206f60a45690d484737466b558abdef191ea 5 | ``` 6 | 7 | Implementation of Skew-Fit. For more information: 8 | - [Videos](https://sites.google.com/view/skew-fit) 9 | - [arXiv](https://arxiv.org/abs/1903.03698) 10 | 11 | Here are the results you should expect from each script. 12 | These plots are generated with [viskit](https://github.com/vitchyr/viskit) 13 | with smoothing on. 14 | 15 | Note that [RIG](RIG.md) is a special-case of Skew-Fit with `power=0`. 16 | 17 | 18 | [examples/skewfit/sawyer_door.py](../examples/skewfit/sawyer_door.py). 1 Seed: 19 | ![Skew-Fit Sawyer Door results](images/skewfit_door.png) 20 | 21 | [examples/skewfit/sawyer_pickup.py](../examples/skewfit/sawyer_pickup.py). 3 Seeds: 22 | ![Skew-Fit Sawyer Pickup results](images/skewfit_pickup.png) 23 | 24 | [examples/skewfit/sawyer_pusher.py](../examples/skewfit/sawyer_pusher.py). 9 Seeds: 25 | ![Skew-Fit Sawyer Pusher results](images/skewfit_pusher.png) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vitchyr Pong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_strategy.py: -------------------------------------------------------------------------------- 1 | from rlkit.exploration_strategies.base import RawExplorationStrategy 2 | import numpy as np 3 | 4 | 5 | class GaussianStrategy(RawExplorationStrategy): 6 | """ 7 | This strategy adds Gaussian noise to the action taken by the deterministic policy. 8 | 9 | Based on the rllab implementation. 
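    The noise scale decays linearly with the step count t:
    sigma(t) = max_sigma - (max_sigma - min_sigma) * min(1, t / decay_period),
    and the noisy action is clipped to the action space bounds. A hedged usage
    sketch (the Box bounds below are made up for illustration):

    ```
    import numpy as np
    from gym.spaces import Box

    space = Box(low=-np.ones(2), high=np.ones(2))
    es = GaussianStrategy(space, max_sigma=0.3, min_sigma=0.1, decay_period=10000)
    noisy_action = es.get_action_from_raw_action(np.zeros(2), t=0)
    ```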
10 | """ 11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None, 12 | decay_period=1000000): 13 | assert len(action_space.shape) == 1 14 | self._max_sigma = max_sigma 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._min_sigma = min_sigma 18 | self._decay_period = decay_period 19 | self._action_space = action_space 20 | 21 | def get_action_from_raw_action(self, action, t=None, **kwargs): 22 | sigma = ( 23 | self._max_sigma - (self._max_sigma - self._min_sigma) * 24 | min(1.0, t * 1.0 / self._decay_period) 25 | ) 26 | return np.clip( 27 | action + np.random.normal(size=len(action)) * sigma, 28 | self._action_space.low, 29 | self._action_space.high, 30 | ) 31 | -------------------------------------------------------------------------------- /rlkit/torch/vae/vae_schedules.py: -------------------------------------------------------------------------------- 1 | def always_train(epoch): 2 | return True, 300 3 | 4 | 5 | def custom_schedule(epoch): 6 | if epoch < 10: 7 | return True, 1000 8 | elif epoch < 300: 9 | return True, 200 10 | else: 11 | return epoch % 3 == 0, 200 12 | 13 | 14 | def custom_schedule_2(epoch): 15 | if epoch < 10: 16 | return True, 1000 17 | elif epoch < 100: 18 | return True, 200 19 | else: 20 | return epoch % 2 == 0, 200 21 | 22 | 23 | def every_other(epoch): 24 | return epoch % 2 == 0, 400 25 | 26 | 27 | def every_three(epoch): 28 | return epoch % 3 == 0, 600 29 | 30 | 31 | def every_three_a_lot(epoch): 32 | return epoch % 3 == 0, 1200 33 | 34 | 35 | def every_six(epoch): 36 | return epoch % 6 == 0, 1200 37 | 38 | 39 | def every_six_less(epoch): 40 | return epoch % 6 == 0, 600 41 | 42 | 43 | def every_six_much_less(epoch): 44 | return epoch % 6 == 0, 300 45 | 46 | 47 | def every_ten(epoch): 48 | return epoch % 10 == 0 or epoch == 5, 1000 49 | 50 | 51 | def every_twenty(epoch): 52 | return epoch % 10 == 0 or epoch == 5 or epoch == 10, 1000 53 | 54 | 55 | def never_train(epoch): 56 | return False, 0 57 | -------------------------------------------------------------------------------- /scripts/run_policy.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.rollout_functions import rollout 2 | from rlkit.torch.pytorch_util import set_gpu_mode 3 | import argparse 4 | import torch 5 | import uuid 6 | from rlkit.core import logger 7 | 8 | filename = str(uuid.uuid4()) 9 | 10 | 11 | def simulate_policy(args): 12 | data = torch.load(args.file) 13 | policy = data['evaluation/policy'] 14 | env = data['evaluation/env'] 15 | print("Policy loaded") 16 | if args.gpu: 17 | set_gpu_mode(True) 18 | policy.cuda() 19 | while True: 20 | path = rollout( 21 | env, 22 | policy, 23 | max_path_length=args.H, 24 | render=True, 25 | ) 26 | if hasattr(env, "log_diagnostics"): 27 | env.log_diagnostics([path]) 28 | logger.dump_tabular() 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('file', type=str, 34 | help='path to the snapshot file') 35 | parser.add_argument('--H', type=int, default=300, 36 | help='Max length of rollout') 37 | parser.add_argument('--gpu', action='store_true') 38 | args = parser.parse_args() 39 | 40 | simulate_policy(args) 41 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rlkit.policies.base import ExplorationPolicy 4 | 5 | 6 | class ExplorationStrategy(object, 
metaclass=abc.ABCMeta): 7 | @abc.abstractmethod 8 | def get_action(self, t, observation, policy, **kwargs): 9 | pass 10 | 11 | def reset(self): 12 | pass 13 | 14 | 15 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta): 16 | @abc.abstractmethod 17 | def get_action_from_raw_action(self, action, **kwargs): 18 | pass 19 | 20 | def get_action(self, t, policy, *args, **kwargs): 21 | action, agent_info = policy.get_action(*args, **kwargs) 22 | return self.get_action_from_raw_action(action, t=t), agent_info 23 | 24 | def reset(self): 25 | pass 26 | 27 | 28 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy): 29 | def __init__( 30 | self, 31 | exploration_strategy: ExplorationStrategy, 32 | policy, 33 | ): 34 | self.es = exploration_strategy 35 | self.policy = policy 36 | self.t = 0 37 | 38 | def set_num_steps_total(self, t): 39 | self.t = t 40 | 41 | def get_action(self, *args, **kwargs): 42 | return self.es.get_action(self.t, self.policy, *args, **kwargs) 43 | 44 | def reset(self): 45 | self.es.reset() 46 | self.policy.reset() 47 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_and_epsilon_strategy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from rlkit.exploration_strategies.base import RawExplorationStrategy 3 | import numpy as np 4 | 5 | 6 | class GaussianAndEpislonStrategy(RawExplorationStrategy): 7 | """ 8 | With probability epsilon, take a completely random action. 9 | with probability 1-epsilon, add Gaussian noise to the action taken by a 10 | deterministic policy. 11 | """ 12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None, 13 | decay_period=1000000): 14 | assert len(action_space.shape) == 1 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._max_sigma = max_sigma 18 | self._epsilon = epsilon 19 | self._min_sigma = min_sigma 20 | self._decay_period = decay_period 21 | self._action_space = action_space 22 | 23 | def get_action_from_raw_action(self, action, t=None, **kwargs): 24 | if random.random() < self._epsilon: 25 | return self._action_space.sample() 26 | else: 27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period) 28 | return np.clip( 29 | action + np.random.normal(size=len(action)) * sigma, 30 | self._action_space.low, 31 | self._action_space.high, 32 | ) 33 | -------------------------------------------------------------------------------- /rlkit/torch/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class HuberLoss(nn.Module): 9 | def __init__(self, delta=1): 10 | super().__init__() 11 | self.huber_loss_delta1 = nn.SmoothL1Loss() 12 | self.delta = delta 13 | 14 | def forward(self, x, x_hat): 15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta) 16 | return loss * self.delta * self.delta 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | """ 21 | Simple 1D LayerNorm. 
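    Normalizes over the last dimension, output = (x - mean) / (std + eps),
    then optionally multiplies by a learned scale (scale=True) and adds a
    learned offset (center=True, the default). Shape sketch (illustrative):

    ```
    import torch
    ln = LayerNorm(5)
    y = ln(torch.randn(3, 5))  # same shape as the input, normalized per row
    ```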
22 | """ 23 | 24 | def __init__(self, features, center=True, scale=False, eps=1e-6): 25 | super().__init__() 26 | self.center = center 27 | self.scale = scale 28 | self.eps = eps 29 | if self.scale: 30 | self.scale_param = nn.Parameter(torch.ones(features)) 31 | else: 32 | self.scale_param = None 33 | if self.center: 34 | self.center_param = nn.Parameter(torch.zeros(features)) 35 | else: 36 | self.center_param = None 37 | 38 | def forward(self, x): 39 | mean = x.mean(-1, keepdim=True) 40 | std = x.std(-1, keepdim=True) 41 | output = (x - mean) / (std + self.eps) 42 | if self.scale: 43 | output = output * self.scale_param 44 | if self.center: 45 | output = output + self.center_param 46 | return output 47 | -------------------------------------------------------------------------------- /docs/goal_based_envs.md: -------------------------------------------------------------------------------- 1 | # Goal-based environments and `ObsDictRelabelingBuffer` 2 | Some algorithms, like HER, are for goal-conditioned environments, like 3 | the [OpenAI Gym GoalEnv](https://blog.openai.com/ingredients-for-robotics-research/) 4 | or the [multiworld MultitaskEnv](https://github.com/vitchyr/multiworld/) 5 | environments. 6 | 7 | These environments are different from normal gym environments in that they 8 | return dictionaries for observations, like so: the environments work like this: 9 | 10 | ``` 11 | env = CarEnv() 12 | obs = env.reset() 13 | next_obs, reward, done, info = env.step(action) 14 | print(obs) 15 | 16 | # Output: 17 | # { 18 | # 'observation': ..., 19 | # 'desired_goal': ..., 20 | # 'achieved_goal': ..., 21 | # } 22 | ``` 23 | The `GoalEnv` environments also have a function with signature 24 | ``` 25 | def compute_rewards (achieved_goal, desired_goal): 26 | # achieved_goal and desired_goal are vectors 27 | ``` 28 | while the `MultitaskEnv` has a signature like 29 | ``` 30 | def compute_rewards (observation, action, next_observation): 31 | # observation and next_observations are dictionaries 32 | ``` 33 | To learn more about these environments, check out the URLs above. 34 | This means that normal RL algorithms won't even "type check" with these 35 | environments. 36 | 37 | `ObsDictRelabelingBuffer` perform hindsight experience replay with 38 | either types of environments and works by saving specific values in the 39 | observation dictionary. 40 | 41 | -------------------------------------------------------------------------------- /rlkit/data_management/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PathBuilder(dict): 5 | """ 6 | Usage: 7 | ``` 8 | path_builder = PathBuilder() 9 | path.add_sample( 10 | observations=1, 11 | actions=2, 12 | next_observations=3, 13 | ... 14 | ) 15 | path.add_sample( 16 | observations=4, 17 | actions=5, 18 | next_observations=6, 19 | ... 20 | ) 21 | 22 | path = path_builder.get_all_stacked() 23 | 24 | path['observations'] 25 | # output: [1, 4] 26 | path['actions'] 27 | # output: [2, 5] 28 | ``` 29 | 30 | Note that the key should be "actions" and not "action" since the 31 | resulting dictionary will have those keys. 
32 | """ 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self._path_length = 0 37 | 38 | def add_all(self, **key_to_value): 39 | for k, v in key_to_value.items(): 40 | if k not in self: 41 | self[k] = [v] 42 | else: 43 | self[k].append(v) 44 | self._path_length += 1 45 | 46 | def get_all_stacked(self): 47 | output_dict = dict() 48 | for k, v in self.items(): 49 | output_dict[k] = stack_list(v) 50 | return output_dict 51 | 52 | def __len__(self): 53 | return self._path_length 54 | 55 | 56 | def stack_list(lst): 57 | if isinstance(lst[0], dict): 58 | return lst 59 | else: 60 | return np.array(lst) 61 | -------------------------------------------------------------------------------- /rlkit/torch/torch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | from typing import Iterable 5 | from torch import nn as nn 6 | 7 | from rlkit.core.batch_rl_algorithm import BatchRLAlgorithm 8 | from rlkit.core.online_rl_algorithm import OnlineRLAlgorithm 9 | from rlkit.core.trainer import Trainer 10 | from rlkit.torch.core import np_to_pytorch_batch 11 | 12 | 13 | class TorchOnlineRLAlgorithm(OnlineRLAlgorithm): 14 | def to(self, device): 15 | for net in self.trainer.networks: 16 | net.to(device) 17 | 18 | def training_mode(self, mode): 19 | for net in self.trainer.networks: 20 | net.train(mode) 21 | 22 | 23 | class TorchBatchRLAlgorithm(BatchRLAlgorithm): 24 | def to(self, device): 25 | for net in self.trainer.networks: 26 | net.to(device) 27 | 28 | def training_mode(self, mode): 29 | for net in self.trainer.networks: 30 | net.train(mode) 31 | 32 | 33 | class TorchTrainer(Trainer, metaclass=abc.ABCMeta): 34 | def __init__(self): 35 | self._num_train_steps = 0 36 | 37 | def train(self, np_batch): 38 | self._num_train_steps += 1 39 | batch = np_to_pytorch_batch(np_batch) 40 | self.train_from_torch(batch) 41 | 42 | def get_diagnostics(self): 43 | return OrderedDict([ 44 | ('num train calls', self._num_train_steps), 45 | ]) 46 | 47 | @abc.abstractmethod 48 | def train_from_torch(self, batch): 49 | pass 50 | 51 | @property 52 | @abc.abstractmethod 53 | def networks(self) -> Iterable[nn.Module]: 54 | pass 55 | -------------------------------------------------------------------------------- /rlkit/data_management/env_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from gym.spaces import Discrete 2 | 3 | from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer 4 | from rlkit.envs.env_utils import get_dim 5 | import numpy as np 6 | 7 | 8 | class EnvReplayBuffer(SimpleReplayBuffer): 9 | def __init__( 10 | self, 11 | max_replay_buffer_size, 12 | env, 13 | env_info_sizes=None 14 | ): 15 | """ 16 | :param max_replay_buffer_size: 17 | :param env: 18 | """ 19 | self.env = env 20 | self._ob_space = env.observation_space 21 | self._action_space = env.action_space 22 | 23 | if env_info_sizes is None: 24 | if hasattr(env, 'info_sizes'): 25 | env_info_sizes = env.info_sizes 26 | else: 27 | env_info_sizes = dict() 28 | 29 | super().__init__( 30 | max_replay_buffer_size=max_replay_buffer_size, 31 | observation_dim=get_dim(self._ob_space), 32 | action_dim=get_dim(self._action_space), 33 | env_info_sizes=env_info_sizes 34 | ) 35 | 36 | def add_sample(self, observation, action, reward, terminal, 37 | next_observation, **kwargs): 38 | if isinstance(self._action_space, Discrete): 39 | new_action = np.zeros(self._action_dim) 40 | 
new_action[action] = 1 41 | else: 42 | new_action = action 43 | return super().add_sample( 44 | observation=observation, 45 | action=new_action, 46 | reward=reward, 47 | next_observation=next_observation, 48 | terminal=terminal, 49 | **kwargs 50 | ) 51 | -------------------------------------------------------------------------------- /examples/doodad/ec2_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running stuff on EC2 3 | """ 4 | import time 5 | 6 | from rlkit.core import logger 7 | from rlkit.launchers.launcher_util import run_experiment 8 | from datetime import datetime 9 | from pytz import timezone 10 | import pytz 11 | 12 | 13 | def example(variant): 14 | import torch 15 | logger.log(torch.__version__) 16 | date_format = '%m/%d/%Y %H:%M:%S %Z' 17 | date = datetime.now(tz=pytz.utc) 18 | logger.log("start") 19 | logger.log('Current date & time is: {}'.format(date.strftime(date_format))) 20 | if torch.cuda.is_available(): 21 | x = torch.randn(3) 22 | logger.log(str(x.to(ptu.device))) 23 | 24 | date = date.astimezone(timezone('US/Pacific')) 25 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 26 | for i in range(variant['num_seconds']): 27 | logger.log("Tick, {}".format(i)) 28 | time.sleep(1) 29 | logger.log("end") 30 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 31 | 32 | logger.log("start mujoco") 33 | from gym.envs.mujoco import HalfCheetahEnv 34 | e = HalfCheetahEnv() 35 | img = e.sim.render(32, 32) 36 | logger.log(str(sum(img))) 37 | logger.log("end mujocoy") 38 | 39 | 40 | if __name__ == "__main__": 41 | # noinspection PyTypeChecker 42 | date_format = '%m/%d/%Y %H:%M:%S %Z' 43 | date = datetime.now(tz=pytz.utc) 44 | logger.log("start") 45 | variant = dict( 46 | num_seconds=10, 47 | launch_time=str(date.strftime(date_format)), 48 | ) 49 | run_experiment( 50 | example, 51 | exp_prefix="ec2-test", 52 | mode='ec2', 53 | variant=variant, 54 | # use_gpu=True, # GPUs are much more expensive! 55 | ) 56 | -------------------------------------------------------------------------------- /rlkit/torch/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from rlkit.torch import pytorch_util as ptu 5 | 6 | 7 | def eval_np(module, *args, **kwargs): 8 | """ 9 | Eval this module with a numpy interface 10 | 11 | Same as a call to __call__ except all Variable input/outputs are 12 | replaced with numpy equivalents. 13 | 14 | Assumes the output is either a single object or a tuple of objects. 
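    Hedged usage sketch (the network below is made up for illustration):

    ```
    import numpy as np
    from torch import nn

    net = nn.Linear(4, 2)
    out = eval_np(net, np.zeros((1, 4), dtype=np.float32))
    # `out` is a numpy array of shape (1, 2)
    ```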
15 | """ 16 | torch_args = tuple(torch_ify(x) for x in args) 17 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 18 | outputs = module(*torch_args, **torch_kwargs) 19 | if isinstance(outputs, tuple): 20 | return tuple(np_ify(x) for x in outputs) 21 | else: 22 | return np_ify(outputs) 23 | 24 | 25 | def torch_ify(np_array_or_other): 26 | if isinstance(np_array_or_other, np.ndarray): 27 | return ptu.from_numpy(np_array_or_other) 28 | else: 29 | return np_array_or_other 30 | 31 | 32 | def np_ify(tensor_or_other): 33 | if isinstance(tensor_or_other, torch.autograd.Variable): 34 | return ptu.get_numpy(tensor_or_other) 35 | else: 36 | return tensor_or_other 37 | 38 | 39 | def _elem_or_tuple_to_variable(elem_or_tuple): 40 | if isinstance(elem_or_tuple, tuple): 41 | return tuple( 42 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple 43 | ) 44 | return ptu.from_numpy(elem_or_tuple).float() 45 | 46 | 47 | def _filter_batch(np_batch): 48 | for k, v in np_batch.items(): 49 | if v.dtype == np.bool: 50 | yield k, v.astype(int) 51 | else: 52 | yield k, v 53 | 54 | 55 | def np_to_pytorch_batch(np_batch): 56 | return { 57 | k: _elem_or_tuple_to_variable(x) 58 | for k, x in _filter_batch(np_batch) 59 | if x.dtype != np.dtype('O') # ignore object (e.g. dictionaries) 60 | } 61 | 62 | -------------------------------------------------------------------------------- /docs/TDMs.md: -------------------------------------------------------------------------------- 1 | # Temporal Difference Models (TDMs) 2 | The TDM implementation is a bit different from the other algorithms. One reason for this is that the goals and rewards are retroactively relabelled. Some notable implementation details: 3 | - The networks (policy and QF) take in the goal and tau (the number of time steps left). 4 | - The algorithm relabels the terminal and rewards, so the terminal/reward from the environment are ignored completely. 5 | - TdmNormalizer is used to normalize the observations/states. If you want, you can totally ignore it and set `num_pretrain_path=0`. 6 | - The environments need to be [MultitaskEnv](../rlkit/torch/tdm/envs/multitask_env), meaning standard gym environments won't work out of the box. See below for details. 7 | 8 | The example scripts have tuned hyperparameters. Specifically, the following hyperparameters are tuned, as they seem to be the most important ones to tune: 9 | - `num_updates_per_env_step` 10 | - `reward_scale` 11 | - `max_tau` 12 | 13 | 14 | ## Creating your own environment for TDM 15 | A [MultitaskEnv](../envs/multitask_env.py) instances needs to implement 3 functions: 16 | 17 | ```python 18 | def goal_dim(self) -> int: 19 | """ 20 | :return: int, dimension of goal vector 21 | """ 22 | pass 23 | 24 | @abc.abstractmethod 25 | def sample_goals(self, batch_size): 26 | pass 27 | 28 | @abc.abstractmethod 29 | def convert_obs_to_goals(self, obs): 30 | pass 31 | ``` 32 | 33 | If you want to see how to make an existing environment multitask, see [GoalXVelHalfCheetah](../envs/half_cheetah_env.py), which builds off of Gym's HalfCheetah environments. 34 | 35 | Another useful example might be [GoalXYPosAnt](../envs/ant_env.py), which builds off a custom environment. 36 | 37 | One important thing is that the environment should *not* include the goal as part of the state, since the goal will be separately given to the networks. 
38 | -------------------------------------------------------------------------------- /rlkit/torch/dqn/double_dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import rlkit.torch.pytorch_util as ptu 5 | from rlkit.core.eval_util import create_stats_ordered_dict 6 | from rlkit.torch.dqn.dqn import DQNTrainer 7 | 8 | 9 | class DoubleDQNTrainer(DQNTrainer): 10 | def train_from_torch(self, batch): 11 | rewards = batch['rewards'] 12 | terminals = batch['terminals'] 13 | obs = batch['observations'] 14 | actions = batch['actions'] 15 | next_obs = batch['next_observations'] 16 | 17 | """ 18 | Compute loss 19 | """ 20 | 21 | best_action_idxs = self.qf(next_obs).max( 22 | 1, keepdim=True 23 | )[1] 24 | target_q_values = self.target_qf(next_obs).gather( 25 | 1, best_action_idxs 26 | ).detach() 27 | y_target = rewards + (1. - terminals) * self.discount * target_q_values 28 | y_target = y_target.detach() 29 | # actions is a one-hot vector 30 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) 31 | qf_loss = self.qf_criterion(y_pred, y_target) 32 | 33 | """ 34 | Update networks 35 | """ 36 | self.qf_optimizer.zero_grad() 37 | qf_loss.backward() 38 | self.qf_optimizer.step() 39 | 40 | """ 41 | Soft target network updates 42 | """ 43 | if self._n_train_steps_total % self.target_update_period == 0: 44 | ptu.soft_update_from_to( 45 | self.qf, self.target_qf, self.soft_target_tau 46 | ) 47 | 48 | """ 49 | Save some statistics for eval using just one batch. 50 | """ 51 | if self._need_to_update_eval_statistics: 52 | self._need_to_update_eval_statistics = False 53 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 54 | self.eval_statistics.update(create_stats_ordered_dict( 55 | 'Y Predictions', 56 | ptu.get_numpy(y_pred), 57 | )) 58 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/ou_strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as nr 3 | 4 | from rlkit.exploration_strategies.base import RawExplorationStrategy 5 | 6 | 7 | class OUStrategy(RawExplorationStrategy): 8 | """ 9 | This strategy implements the Ornstein-Uhlenbeck process, which adds 10 | time-correlated noise to the actions taken by the deterministic policy. 11 | The OU process satisfies the following stochastic differential equation: 12 | dxt = theta*(mu - xt)*dt + sigma*dWt 13 | where Wt denotes the Wiener process 14 | 15 | Based on the rllab implementation. 
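    In this implementation the process is discretized with an implicit dt = 1:
    state <- state + theta * (mu - state) + sigma * N(0, 1), and sigma itself
    decays linearly from max_sigma to min_sigma over decay_period steps before
    the resulting action is clipped to the action space bounds.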
16 | """ 17 | 18 | def __init__( 19 | self, 20 | action_space, 21 | mu=0, 22 | theta=0.15, 23 | max_sigma=0.3, 24 | min_sigma=None, 25 | decay_period=100000, 26 | ): 27 | if min_sigma is None: 28 | min_sigma = max_sigma 29 | self.mu = mu 30 | self.theta = theta 31 | self.sigma = max_sigma 32 | self._max_sigma = max_sigma 33 | if min_sigma is None: 34 | min_sigma = max_sigma 35 | self._min_sigma = min_sigma 36 | self._decay_period = decay_period 37 | self.dim = np.prod(action_space.low.shape) 38 | self.low = action_space.low 39 | self.high = action_space.high 40 | self.state = np.ones(self.dim) * self.mu 41 | self.reset() 42 | 43 | def reset(self): 44 | self.state = np.ones(self.dim) * self.mu 45 | 46 | def evolve_state(self): 47 | x = self.state 48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x)) 49 | self.state = x + dx 50 | return self.state 51 | 52 | def get_action_from_raw_action(self, action, t=0, **kwargs): 53 | ou_state = self.evolve_state() 54 | self.sigma = ( 55 | self._max_sigma 56 | - (self._max_sigma - self._min_sigma) 57 | * min(1.0, t * 1.0 / self._decay_period) 58 | ) 59 | return np.clip(action + ou_state, self.low, self.high) 60 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from rlkit.core.serializable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | def __init__( 20 | self, 21 | model_path, 22 | frame_skip=1, 23 | model_path_is_local=True, 24 | automatically_set_obs_and_action_space=False, 25 | ): 26 | if model_path_is_local: 27 | model_path = get_asset_xml(model_path) 28 | if automatically_set_obs_and_action_space: 29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 30 | else: 31 | """ 32 | Code below is copy/pasted from MujocoEnv's __init__ function. 
33 | """ 34 | if model_path.startswith("/"): 35 | fullpath = model_path 36 | else: 37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 38 | if not path.exists(fullpath): 39 | raise IOError("File %s does not exist" % fullpath) 40 | self.frame_skip = frame_skip 41 | self.model = mujoco_py.MjModel(fullpath) 42 | self.data = self.model.data 43 | self.viewer = None 44 | 45 | self.metadata = { 46 | 'render.modes': ['human', 'rgb_array'], 47 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 48 | } 49 | 50 | self.init_qpos = self.model.data.qpos.ravel().copy() 51 | self.init_qvel = self.model.data.qvel.ravel().copy() 52 | self._seed() 53 | 54 | def init_serialization(self, locals): 55 | Serializable.quick_init(self, locals) 56 | 57 | def log_diagnostics(self, paths): 58 | pass 59 | 60 | 61 | def get_asset_xml(xml_name): 62 | return os.path.join(ENV_ASSET_DIR, xml_name) 63 | -------------------------------------------------------------------------------- /examples/doodad/gcp_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running stuff on GCP 3 | """ 4 | import time 5 | 6 | from rlkit.core import logger 7 | from rlkit.launchers.launcher_util import run_experiment 8 | from datetime import datetime 9 | from pytz import timezone 10 | import pytz 11 | 12 | 13 | def example(variant): 14 | import torch 15 | import rlkit.torch.pytorch_util as ptu 16 | print("Starting") 17 | logger.log(torch.__version__) 18 | date_format = '%m/%d/%Y %H:%M:%S %Z' 19 | date = datetime.now(tz=pytz.utc) 20 | logger.log("start") 21 | logger.log('Current date & time is: {}'.format(date.strftime(date_format))) 22 | logger.log("Cuda available: {}".format(torch.cuda.is_available())) 23 | if torch.cuda.is_available(): 24 | x = torch.randn(3) 25 | logger.log(str(x.to(ptu.device))) 26 | 27 | date = date.astimezone(timezone('US/Pacific')) 28 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 29 | for i in range(variant['num_seconds']): 30 | logger.log("Tick, {}".format(i)) 31 | time.sleep(1) 32 | logger.log("end") 33 | logger.log('Local date & time is: {}'.format(date.strftime(date_format))) 34 | 35 | logger.log("start mujoco") 36 | from gym.envs.mujoco import HalfCheetahEnv 37 | e = HalfCheetahEnv() 38 | img = e.sim.render(32, 32) 39 | logger.log(str(sum(img))) 40 | logger.log("end mujoco") 41 | 42 | logger.record_tabular('Epoch', 1) 43 | logger.dump_tabular() 44 | logger.record_tabular('Epoch', 2) 45 | logger.dump_tabular() 46 | logger.record_tabular('Epoch', 3) 47 | logger.dump_tabular() 48 | print("Done") 49 | 50 | 51 | if __name__ == "__main__": 52 | # noinspection PyTypeChecker 53 | date_format = '%m/%d/%Y %H:%M:%S %Z' 54 | date = datetime.now(tz=pytz.utc) 55 | logger.log("start") 56 | variant = dict( 57 | num_seconds=10, 58 | launch_time=str(date.strftime(date_format)), 59 | ) 60 | run_experiment( 61 | example, 62 | exp_prefix="gcp-test", 63 | mode='gcp', 64 | variant=variant, 65 | # use_gpu=True, # GPUs are much more expensive! 
66 | ) 67 | -------------------------------------------------------------------------------- /rlkit/envs/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.mujoco_env import MujocoEnv 4 | 5 | 6 | class AntEnv(MujocoEnv): 7 | def __init__(self, use_low_gear_ratio=True): 8 | self.init_serialization(locals()) 9 | if use_low_gear_ratio: 10 | xml_path = 'low_gear_ratio_ant.xml' 11 | else: 12 | xml_path = 'normal_gear_ratio_ant.xml' 13 | super().__init__( 14 | xml_path, 15 | frame_skip=5, 16 | automatically_set_obs_and_action_space=True, 17 | ) 18 | 19 | def step(self, a): 20 | torso_xyz_before = self.get_body_com("torso") 21 | self.do_simulation(a, self.frame_skip) 22 | torso_xyz_after = self.get_body_com("torso") 23 | torso_velocity = torso_xyz_after - torso_xyz_before 24 | forward_reward = torso_velocity[0]/self.dt 25 | ctrl_cost = .5 * np.square(a).sum() 26 | contact_cost = 0.5 * 1e-3 * np.sum( 27 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 28 | survive_reward = 1.0 29 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 30 | state = self.state_vector() 31 | notdone = np.isfinite(state).all() \ 32 | and state[2] >= 0.2 and state[2] <= 1.0 33 | done = not notdone 34 | ob = self._get_obs() 35 | return ob, reward, done, dict( 36 | reward_forward=forward_reward, 37 | reward_ctrl=-ctrl_cost, 38 | reward_contact=-contact_cost, 39 | reward_survive=survive_reward, 40 | torso_velocity=torso_velocity, 41 | ) 42 | 43 | def _get_obs(self): 44 | return np.concatenate([ 45 | self.sim.data.qpos.flat[2:], 46 | self.sim.data.qvel.flat, 47 | ]) 48 | 49 | def reset_model(self): 50 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 51 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 52 | self.set_state(qpos, qvel) 53 | return self._get_obs() 54 | 55 | def viewer_setup(self): 56 | self.viewer.cam.distance = self.model.stat.extent * 0.5 57 | -------------------------------------------------------------------------------- /rlkit/util/ml_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility functions for machine learning. 3 | """ 4 | import abc 5 | import math 6 | import numpy as np 7 | 8 | 9 | class ScalarSchedule(object, metaclass=abc.ABCMeta): 10 | @abc.abstractmethod 11 | def get_value(self, t): 12 | pass 13 | 14 | 15 | class ConstantSchedule(ScalarSchedule): 16 | def __init__(self, value): 17 | self._value = value 18 | 19 | def get_value(self, t): 20 | return self._value 21 | 22 | 23 | class LinearSchedule(ScalarSchedule): 24 | """ 25 | Linearly interpolate and then stop at a final value. 
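    value(t) = init_value + (final_value - init_value) * min(1, t / ramp_duration)

    For example, LinearSchedule(0.0, 1.0, ramp_duration=100).get_value(50)
    returns 0.5 and .get_value(200) returns 1.0.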
26 | """ 27 | def __init__( 28 | self, 29 | init_value, 30 | final_value, 31 | ramp_duration, 32 | ): 33 | self._init_value = init_value 34 | self._final_value = final_value 35 | self._ramp_duration = ramp_duration 36 | 37 | def get_value(self, t): 38 | return ( 39 | self._init_value 40 | + (self._final_value - self._init_value) 41 | * min(1.0, t * 1.0 / self._ramp_duration) 42 | ) 43 | 44 | 45 | class IntLinearSchedule(LinearSchedule): 46 | """ 47 | Same as RampUpSchedule but round output to an int 48 | """ 49 | def get_value(self, t): 50 | return int(super().get_value(t)) 51 | 52 | 53 | class PiecewiseLinearSchedule(ScalarSchedule): 54 | """ 55 | Given a list of (x, t) value-time pairs, return value x at time t, 56 | and linearly interpolate between the two 57 | """ 58 | def __init__( 59 | self, 60 | x_values, 61 | y_values, 62 | ): 63 | self._x_values = x_values 64 | self._y_values = y_values 65 | 66 | def get_value(self, t): 67 | return np.interp(t, self._x_values, self._y_values) 68 | 69 | 70 | class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule): 71 | def get_value(self, t): 72 | return int(super().get_value(t)) 73 | 74 | 75 | def none_to_infty(bounds): 76 | if bounds is None: 77 | bounds = -math.inf, math.inf 78 | lb, ub = bounds 79 | if lb is None: 80 | lb = -math.inf 81 | if ub is None: 82 | ub = math.inf 83 | return lb, ub 84 | -------------------------------------------------------------------------------- /scripts/run_goal_conditioned_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | from rlkit.core import logger 5 | from rlkit.samplers.rollout_functions import multitask_rollout 6 | from rlkit.torch import pytorch_util as ptu 7 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 8 | 9 | 10 | def simulate_policy(args): 11 | data = pickle.load(open(args.file, "rb")) 12 | policy = data['evaluation/policy'] 13 | env = data['evaluation/env'] 14 | print("Policy and environment loaded") 15 | if args.gpu: 16 | ptu.set_gpu_mode(True) 17 | policy.to(ptu.device) 18 | if isinstance(env, VAEWrappedEnv) and hasattr(env, 'mode'): 19 | env.mode(args.mode) 20 | if args.enable_render or hasattr(env, 'enable_render'): 21 | # some environments need to be reconfigured for visualization 22 | env.enable_render() 23 | paths = [] 24 | while True: 25 | paths.append(multitask_rollout( 26 | env, 27 | policy, 28 | max_path_length=args.H, 29 | render=not args.hide, 30 | observation_key='observation', 31 | desired_goal_key='desired_goal', 32 | )) 33 | if hasattr(env, "log_diagnostics"): 34 | env.log_diagnostics(paths) 35 | if hasattr(env, "get_diagnostics"): 36 | for k, v in env.get_diagnostics(paths).items(): 37 | logger.record_tabular(k, v) 38 | logger.dump_tabular() 39 | 40 | 41 | if __name__ == "__main__": 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('file', type=str, 45 | help='path to the snapshot file') 46 | parser.add_argument('--H', type=int, default=300, 47 | help='Max length of rollout') 48 | parser.add_argument('--speedup', type=float, default=10, 49 | help='Speedup') 50 | parser.add_argument('--mode', default='video_env', type=str, 51 | help='env mode') 52 | parser.add_argument('--gpu', action='store_true') 53 | parser.add_argument('--enable_render', action='store_true') 54 | parser.add_argument('--hide', action='store_true') 55 | args = parser.parse_args() 56 | 57 | simulate_policy(args) 58 | -------------------------------------------------------------------------------- 
/rlkit/core/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | class Serializable(object): 12 | 13 | def __init__(self, *args, **kwargs): 14 | self.__args = args 15 | self.__kwargs = kwargs 16 | 17 | def quick_init(self, locals_): 18 | if getattr(self, "_serializable_initialized", False): 19 | return 20 | if sys.version_info >= (3, 0): 21 | spec = inspect.getfullargspec(self.__init__) 22 | # Exclude the first "self" parameter 23 | if spec.varkw: 24 | kwargs = locals_[spec.varkw].copy() 25 | else: 26 | kwargs = dict() 27 | if spec.kwonlyargs: 28 | for key in spec.kwonlyargs: 29 | kwargs[key] = locals_[key] 30 | else: 31 | spec = inspect.getargspec(self.__init__) 32 | if spec.keywords: 33 | kwargs = locals_[spec.keywords] 34 | else: 35 | kwargs = dict() 36 | if spec.varargs: 37 | varargs = locals_[spec.varargs] 38 | else: 39 | varargs = tuple() 40 | in_order_args = [locals_[arg] for arg in spec.args][1:] 41 | self.__args = tuple(in_order_args) + varargs 42 | self.__kwargs = kwargs 43 | setattr(self, "_serializable_initialized", True) 44 | 45 | def __getstate__(self): 46 | return {"__args": self.__args, "__kwargs": self.__kwargs} 47 | 48 | def __setstate__(self, d): 49 | # convert all __args to keyword-based arguments 50 | if sys.version_info >= (3, 0): 51 | spec = inspect.getfullargspec(self.__init__) 52 | else: 53 | spec = inspect.getargspec(self.__init__) 54 | in_order_args = spec.args[1:] 55 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 56 | self.__dict__.update(out.__dict__) 57 | 58 | @classmethod 59 | def clone(cls, obj, **kwargs): 60 | assert isinstance(obj, Serializable) 61 | d = obj.__getstate__() 62 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 63 | out = type(obj).__new__(type(obj)) 64 | out.__setstate__(d) 65 | return out 66 | -------------------------------------------------------------------------------- /docs/HER.md: -------------------------------------------------------------------------------- 1 | # Hindsight Experience Replay 2 | Some notes on the implementation of 3 | [Hindsight Experience Replay](https://arxiv.org/abs/1707.01495). 4 | ## Expected Results 5 | If you run the [Fetch example](../examples/her/her_td3_gym_fetch_reach.py), then 6 | you should get results like this: 7 | ![Fetch HER results](images/FetchReach-v1_HER-TD3.png) 8 | 9 | If you run the [GridWorld example](../examples/her/her_dqn_gridworld.py) 10 | , then you should get results like this: 11 | ![HER Gridworld results](images/her_dqn.png) 12 | 13 | Note that these examples use HER combined with DQN and SAC, and not DDPG. 14 | 15 | These plots are generated using [viskit](https://github.com/vitchyr/viskit). 16 | 17 | ## Goal-based environments and `ObsDictRelabelingBuffer` 18 | [See here.](goal_based_envs.md) 19 | 20 | ## Implementation Difference 21 | This HER implementation is slightly different from the one presented in the paper. 22 | Rather than relabeling goals when saving data to the replay buffer, the goals 23 | are relabeled when sampling from the replay buffer. 24 | 25 | 26 | In other words, HER in the paper does this: 27 | 28 | Data collection 29 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$. 30 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$. 31 | For i = 1, ..., K: 32 | Sample $g_i$ using the future strategy. 33 | Recompute rewards $r_i = f(s', g_i)$.
34 | Save $(s, a, r_i, s', g_i)$ into replay buffer $\mathcal B$. 35 | Train time 36 | 1. Sample $(s, a, r, s', g)$ from replay buffer 37 | 2. Train Q function $(s, a, r, s', g)$ 38 | 39 | The implementation here does: 40 | 41 | Data collection 42 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$. 43 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$. 44 | Train time 45 | 1. Sample $(s, a, r, s', g)$ from replay buffer 46 | 2a. With probability 1/(K+1): 47 | Train Q function $(s, a, r, s', g)$ 48 | 2b. With probability 1 - 1/(K+1): 49 | Sample $g'$ using the future strategy. 50 | Recompute rewards $r' = f(s', g')$. 51 | Train Q function on $(s, a, r', s', g')$ 52 | 53 | Both implementations effectively do the same thing: with probability 1/(K+1), 54 | you train the policy on the goal used during rollout. Otherwise, train the 55 | policy on a resampled goal. 56 | 57 | -------------------------------------------------------------------------------- /rlkit/torch/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, Sampler 4 | 5 | # TODO: move this to more reasonable place 6 | from rlkit.data_management.obs_dict_replay_buffer import normalize_image 7 | 8 | 9 | class ImageDataset(Dataset): 10 | 11 | def __init__(self, images, should_normalize=True): 12 | super().__init__() 13 | self.dataset = images 14 | self.dataset_len = len(self.dataset) 15 | assert should_normalize == (images.dtype == np.uint8) 16 | self.should_normalize = should_normalize 17 | 18 | def __len__(self): 19 | return self.dataset_len 20 | 21 | def __getitem__(self, idxs): 22 | samples = self.dataset[idxs, :] 23 | if self.should_normalize: 24 | samples = normalize_image(samples) 25 | return np.float32(samples) 26 | 27 | 28 | class InfiniteRandomSampler(Sampler): 29 | 30 | def __init__(self, data_source): 31 | self.data_source = data_source 32 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __next__(self): 38 | try: 39 | idx = next(self.iter) 40 | except StopIteration: 41 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 42 | idx = next(self.iter) 43 | return idx 44 | 45 | def __len__(self): 46 | return 2 ** 62 47 | 48 | 49 | class InfiniteWeightedRandomSampler(Sampler): 50 | 51 | def __init__(self, data_source, weights): 52 | assert len(data_source) == len(weights) 53 | assert len(weights.shape) == 1 54 | self.data_source = data_source 55 | # Always use CPU 56 | self._weights = torch.from_numpy(weights) 57 | self.iter = self._create_iterator() 58 | 59 | def update_weights(self, weights): 60 | self._weights = weights 61 | self.iter = self._create_iterator() 62 | 63 | def _create_iterator(self): 64 | return iter( 65 | torch.multinomial( 66 | self._weights, len(self._weights), replacement=True 67 | ).tolist() 68 | ) 69 | 70 | def __iter__(self): 71 | return self 72 | 73 | def __next__(self): 74 | try: 75 | idx = next(self.iter) 76 | except StopIteration: 77 | self.iter = self._create_iterator() 78 | idx = next(self.iter) 79 | return idx 80 | 81 | def __len__(self): 82 | return 2 ** 62 83 | -------------------------------------------------------------------------------- /rlkit/torch/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Distribution, Normal 3 | import rlkit.torch.pytorch_util as ptu 4 | 5 | 6 | class
TanhNormal(Distribution): 7 | """ 8 | Represent distribution of X where 9 | X ~ tanh(Z) 10 | Z ~ N(mean, std) 11 | 12 | Note: this is not very numerically stable. 13 | """ 14 | def __init__(self, normal_mean, normal_std, epsilon=1e-6): 15 | """ 16 | :param normal_mean: Mean of the normal distribution 17 | :param normal_std: Std of the normal distribution 18 | :param epsilon: Numerical stability epsilon when computing log-prob. 19 | """ 20 | self.normal_mean = normal_mean 21 | self.normal_std = normal_std 22 | self.normal = Normal(normal_mean, normal_std) 23 | self.epsilon = epsilon 24 | 25 | def sample_n(self, n, return_pre_tanh_value=False): 26 | z = self.normal.sample_n(n) 27 | if return_pre_tanh_value: 28 | return torch.tanh(z), z 29 | else: 30 | return torch.tanh(z) 31 | 32 | def log_prob(self, value, pre_tanh_value=None): 33 | """ 34 | 35 | :param value: some value, x 36 | :param pre_tanh_value: arctanh(x) 37 | :return: 38 | """ 39 | if pre_tanh_value is None: 40 | pre_tanh_value = torch.log( 41 | (1+value) / (1-value) 42 | ) / 2 43 | return self.normal.log_prob(pre_tanh_value) - torch.log( 44 | 1 - value * value + self.epsilon 45 | ) 46 | 47 | def sample(self, return_pretanh_value=False): 48 | """ 49 | Gradients will and should *not* pass through this operation. 50 | 51 | See https://github.com/pytorch/pytorch/issues/4620 for discussion. 52 | """ 53 | z = self.normal.sample().detach() 54 | 55 | if return_pretanh_value: 56 | return torch.tanh(z), z 57 | else: 58 | return torch.tanh(z) 59 | 60 | def rsample(self, return_pretanh_value=False): 61 | """ 62 | Sampling in the reparameterization case. 63 | """ 64 | z = ( 65 | self.normal_mean + 66 | self.normal_std * 67 | Normal( 68 | ptu.zeros(self.normal_mean.size()), 69 | ptu.ones(self.normal_std.size()) 70 | ).sample() 71 | ) 72 | z.requires_grad_() 73 | 74 | if return_pretanh_value: 75 | return torch.tanh(z), z 76 | else: 77 | return torch.tanh(z) 78 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rlkit.torch.pytorch_util as ptu 3 | import numpy as np 4 | 5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer 6 | 7 | 8 | class TorchNormalizer(Normalizer): 9 | """ 10 | Update with np array, but de/normalize pytorch Tensors. 
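    Illustrative usage (variable names and shapes are assumptions, not from a test):

        normalizer = TorchNormalizer(size=obs_dim)
        normalizer.update(obs_batch_np)            # numpy array, shape (N, obs_dim)
        normed = normalizer.normalize(obs_torch)   # torch tensor, shape (B, obs_dim) or (obs_dim,)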
11 | """ 12 | def normalize(self, v, clip_range=None): 13 | if not self.synchronized: 14 | self.synchronize() 15 | if clip_range is None: 16 | clip_range = self.default_clip_range 17 | mean = ptu.from_numpy(self.mean) 18 | std = ptu.from_numpy(self.std) 19 | if v.dim() == 2: 20 | # Unsqueeze along the batch use automatic broadcasting 21 | mean = mean.unsqueeze(0) 22 | std = std.unsqueeze(0) 23 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 24 | 25 | def denormalize(self, v): 26 | if not self.synchronized: 27 | self.synchronize() 28 | mean = ptu.from_numpy(self.mean) 29 | std = ptu.from_numpy(self.std) 30 | if v.dim() == 2: 31 | mean = mean.unsqueeze(0) 32 | std = std.unsqueeze(0) 33 | return mean + v * std 34 | 35 | 36 | class TorchFixedNormalizer(FixedNormalizer): 37 | def normalize(self, v, clip_range=None): 38 | if clip_range is None: 39 | clip_range = self.default_clip_range 40 | mean = ptu.from_numpy(self.mean) 41 | std = ptu.from_numpy(self.std) 42 | if v.dim() == 2: 43 | # Unsqueeze along the batch use automatic broadcasting 44 | mean = mean.unsqueeze(0) 45 | std = std.unsqueeze(0) 46 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 47 | 48 | def normalize_scale(self, v): 49 | """ 50 | Only normalize the scale. Do not subtract the mean. 51 | """ 52 | std = ptu.from_numpy(self.std) 53 | if v.dim() == 2: 54 | std = std.unsqueeze(0) 55 | return v / std 56 | 57 | def denormalize(self, v): 58 | mean = ptu.from_numpy(self.mean) 59 | std = ptu.from_numpy(self.std) 60 | if v.dim() == 2: 61 | mean = mean.unsqueeze(0) 62 | std = std.unsqueeze(0) 63 | return mean + v * std 64 | 65 | def denormalize_scale(self, v): 66 | """ 67 | Only denormalize the scale. Do not add the mean. 68 | """ 69 | std = ptu.from_numpy(self.std) 70 | if v.dim() == 2: 71 | std = std.unsqueeze(0) 72 | return v * std 73 | -------------------------------------------------------------------------------- /scripts/run_experiment_from_doodad.py: -------------------------------------------------------------------------------- 1 | import doodad as dd 2 | import torch.multiprocessing as mp 3 | 4 | from rlkit.launchers.launcher_util import run_experiment_here 5 | 6 | if __name__ == "__main__": 7 | import matplotlib 8 | matplotlib.use('agg') 9 | 10 | mp.set_start_method('forkserver') 11 | args_dict = dd.get_args() 12 | method_call = args_dict['method_call'] 13 | run_experiment_kwargs = args_dict['run_experiment_kwargs'] 14 | output_dir = args_dict['output_dir'] 15 | run_mode = args_dict.get('mode', None) 16 | if run_mode and run_mode in ['slurm_singularity', 'sss']: 17 | import os 18 | run_experiment_kwargs['variant']['slurm-job-id'] = os.environ.get( 19 | 'SLURM_JOB_ID', None 20 | ) 21 | if run_mode and (run_mode == 'ec2' or run_mode == 'gcp'): 22 | if run_mode == 'ec2': 23 | try: 24 | import urllib.request 25 | instance_id = urllib.request.urlopen( 26 | 'http://169.254.169.254/latest/meta-data/instance-id' 27 | ).read().decode() 28 | run_experiment_kwargs['variant']['EC2_instance_id'] = instance_id 29 | except Exception as e: 30 | print("Could not get AWS instance ID. 
Error was...") 31 | print(e) 32 | if run_mode == 'gcp': 33 | try: 34 | import urllib.request 35 | request = urllib.request.Request( 36 | "http://metadata/computeMetadata/v1/instance/name", 37 | ) 38 | # See this URL for why we need this header: 39 | # https://cloud.google.com/compute/docs/storing-retrieving-metadata 40 | request.add_header("Metadata-Flavor", "Google") 41 | instance_name = urllib.request.urlopen(request).read().decode() 42 | run_experiment_kwargs['variant']['GCP_instance_name'] = ( 43 | instance_name 44 | ) 45 | except Exception as e: 46 | print("Could not get GCP instance name. Error was...") 47 | print(e) 48 | # Do this in case base_log_dir was already set 49 | run_experiment_kwargs['base_log_dir'] = output_dir 50 | run_experiment_here( 51 | method_call, 52 | include_exp_prefix_sub_dir=False, 53 | **run_experiment_kwargs 54 | ) 55 | else: 56 | run_experiment_here( 57 | method_call, 58 | log_dir=output_dir, 59 | **run_experiment_kwargs 60 | ) 61 | -------------------------------------------------------------------------------- /rlkit/data_management/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayBuffer(object, metaclass=abc.ABCMeta): 5 | """ 6 | A class used to save and replay data. 7 | """ 8 | 9 | @abc.abstractmethod 10 | def add_sample(self, observation, action, reward, next_observation, 11 | terminal, **kwargs): 12 | """ 13 | Add a transition tuple. 14 | """ 15 | pass 16 | 17 | @abc.abstractmethod 18 | def terminate_episode(self): 19 | """ 20 | Let the replay buffer know that the episode has terminated in case some 21 | special book-keeping has to happen. 22 | :return: 23 | """ 24 | pass 25 | 26 | @abc.abstractmethod 27 | def num_steps_can_sample(self, **kwargs): 28 | """ 29 | :return: # of unique items that can be sampled. 30 | """ 31 | pass 32 | 33 | def add_path(self, path): 34 | """ 35 | Add a path to the replay buffer. 36 | 37 | This default implementation naively goes through every step, but you 38 | may want to optimize this. 39 | 40 | NOTE: You should NOT call "terminate_episode" after calling add_path. 41 | It's assumed that this function handles the episode termination. 42 | 43 | :param path: Dict like one outputted by rlkit.samplers.util.rollout 44 | """ 45 | for i, ( 46 | obs, 47 | action, 48 | reward, 49 | next_obs, 50 | terminal, 51 | agent_info, 52 | env_info 53 | ) in enumerate(zip( 54 | path["observations"], 55 | path["actions"], 56 | path["rewards"], 57 | path["next_observations"], 58 | path["terminals"], 59 | path["agent_infos"], 60 | path["env_infos"], 61 | )): 62 | self.add_sample( 63 | observation=obs, 64 | action=action, 65 | reward=reward, 66 | next_observation=next_obs, 67 | terminal=terminal, 68 | agent_info=agent_info, 69 | env_info=env_info, 70 | ) 71 | self.terminate_episode() 72 | 73 | def add_paths(self, paths): 74 | for path in paths: 75 | self.add_path(path) 76 | 77 | @abc.abstractmethod 78 | def random_batch(self, batch_size): 79 | """ 80 | Return a batch of size `batch_size`. 
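        Concrete implementations (e.g. SimpleReplayBuffer) return a dict mapping
        'observations', 'actions', 'rewards', 'next_observations', and 'terminals'
        to arrays whose first dimension is `batch_size`.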
81 | :param batch_size: 82 | :return: 83 | """ 84 | pass 85 | 86 | def get_diagnostics(self): 87 | return {} 88 | 89 | def get_snapshot(self): 90 | return {} 91 | 92 | def end_epoch(self, epoch): 93 | return 94 | 95 | -------------------------------------------------------------------------------- /rlkit/util/io.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | import pickle 4 | 5 | import boto3 6 | 7 | from rlkit.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH 8 | import os 9 | 10 | PICKLE = 'pickle' 11 | NUMPY = 'numpy' 12 | JOBLIB = 'joblib' 13 | 14 | 15 | def local_path_from_s3_or_local_path(filename): 16 | relative_filename = os.path.join(LOCAL_LOG_DIR, filename) 17 | if os.path.isfile(filename): 18 | return filename 19 | elif os.path.isfile(relative_filename): 20 | return relative_filename 21 | else: 22 | return sync_down(filename) 23 | 24 | 25 | def sync_down(path, check_exists=True): 26 | is_docker = os.path.isfile("/.dockerenv") 27 | if is_docker: 28 | local_path = "/tmp/%s" % (path) 29 | else: 30 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path) 31 | 32 | if check_exists and os.path.isfile(local_path): 33 | return local_path 34 | 35 | local_dir = os.path.dirname(local_path) 36 | os.makedirs(local_dir, exist_ok=True) 37 | 38 | if is_docker: 39 | from doodad.ec2.autoconfig import AUTOCONFIG 40 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key() 41 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret() 42 | 43 | full_s3_path = os.path.join(AWS_S3_PATH, path) 44 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path) 45 | try: 46 | bucket = boto3.resource('s3').Bucket(bucket_name) 47 | bucket.download_file(bucket_relative_path, local_path) 48 | except Exception as e: 49 | local_path = None 50 | print("Failed to sync! path: ", path) 51 | print("Exception: ", e) 52 | return local_path 53 | 54 | 55 | def split_s3_full_path(s3_path): 56 | """ 57 | Split "s3://foo/bar/baz" into "foo" and "bar/baz" 58 | """ 59 | bucket_name_and_directories = s3_path.split('//')[1] 60 | bucket_name, *directories = bucket_name_and_directories.split('/') 61 | directory_path = '/'.join(directories) 62 | return bucket_name, directory_path 63 | 64 | 65 | def load_local_or_remote_file(filepath, file_type=None): 66 | local_path = local_path_from_s3_or_local_path(filepath) 67 | if file_type is None: 68 | extension = local_path.split('.')[-1] 69 | if extension == 'npy': 70 | file_type = NUMPY 71 | else: 72 | file_type = PICKLE 73 | else: 74 | file_type = PICKLE 75 | if file_type == NUMPY: 76 | object = np.load(open(local_path, "rb")) 77 | elif file_type == JOBLIB: 78 | object = joblib.load(local_path) 79 | else: 80 | object = pickle.load(open(local_path, "rb")) 81 | print("loaded", local_path) 82 | return object 83 | 84 | 85 | if __name__ == "__main__": 86 | p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl") 87 | print("got", p) -------------------------------------------------------------------------------- /examples/dqn_and_double_dqn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run DQN on grid world. 
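(As written, the script below trains on gym's CartPole-v0 rather than a grid world.)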
3 | """ 4 | 5 | import gym 6 | from torch import nn as nn 7 | 8 | from rlkit.exploration_strategies.base import \ 9 | PolicyWrappedWithExplorationStrategy 10 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy 11 | from rlkit.policies.argmax import ArgmaxDiscretePolicy 12 | from rlkit.torch.dqn.dqn import DQNTrainer 13 | from rlkit.torch.networks import Mlp 14 | import rlkit.torch.pytorch_util as ptu 15 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 16 | from rlkit.launchers.launcher_util import setup_logger 17 | from rlkit.samplers.data_collector import MdpPathCollector 18 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 19 | 20 | 21 | def experiment(variant): 22 | expl_env = gym.make('CartPole-v0') 23 | eval_env = gym.make('CartPole-v0') 24 | obs_dim = expl_env.observation_space.low.size 25 | action_dim = eval_env.action_space.n 26 | 27 | qf = Mlp( 28 | hidden_sizes=[32, 32], 29 | input_size=obs_dim, 30 | output_size=action_dim, 31 | ) 32 | target_qf = Mlp( 33 | hidden_sizes=[32, 32], 34 | input_size=obs_dim, 35 | output_size=action_dim, 36 | ) 37 | qf_criterion = nn.MSELoss() 38 | eval_policy = ArgmaxDiscretePolicy(qf) 39 | expl_policy = PolicyWrappedWithExplorationStrategy( 40 | EpsilonGreedy(expl_env.action_space), 41 | eval_policy, 42 | ) 43 | eval_path_collector = MdpPathCollector( 44 | eval_env, 45 | eval_policy, 46 | ) 47 | expl_path_collector = MdpPathCollector( 48 | expl_env, 49 | expl_policy, 50 | ) 51 | trainer = DQNTrainer( 52 | qf=qf, 53 | target_qf=target_qf, 54 | qf_criterion=qf_criterion, 55 | **variant['trainer_kwargs'] 56 | ) 57 | replay_buffer = EnvReplayBuffer( 58 | variant['replay_buffer_size'], 59 | expl_env, 60 | ) 61 | algorithm = TorchBatchRLAlgorithm( 62 | trainer=trainer, 63 | exploration_env=expl_env, 64 | evaluation_env=eval_env, 65 | exploration_data_collector=expl_path_collector, 66 | evaluation_data_collector=eval_path_collector, 67 | replay_buffer=replay_buffer, 68 | **variant['algorithm_kwargs'] 69 | ) 70 | algorithm.to(ptu.device) 71 | algorithm.train() 72 | 73 | 74 | 75 | 76 | if __name__ == "__main__": 77 | # noinspection PyTypeChecker 78 | variant = dict( 79 | algorithm="SAC", 80 | version="normal", 81 | layer_size=256, 82 | replay_buffer_size=int(1E6), 83 | algorithm_kwargs=dict( 84 | num_epochs=3000, 85 | num_eval_steps_per_epoch=5000, 86 | num_trains_per_train_loop=1000, 87 | num_expl_steps_per_train_loop=1000, 88 | min_num_steps_before_training=1000, 89 | max_path_length=1000, 90 | batch_size=256, 91 | ), 92 | trainer_kwargs=dict( 93 | discount=0.99, 94 | learning_rate=3E-4, 95 | ), 96 | ) 97 | setup_logger('name-of-experiment', variant=variant) 98 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 99 | experiment(variant) 100 | -------------------------------------------------------------------------------- /rlkit/core/batch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from rlkit.core.rl_algorithm import BaseRLAlgorithm 5 | from rlkit.data_management.replay_buffer import ReplayBuffer 6 | from rlkit.samplers.data_collector import PathCollector 7 | 8 | 9 | class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 10 | def __init__( 11 | self, 12 | trainer, 13 | exploration_env, 14 | evaluation_env, 15 | exploration_data_collector: PathCollector, 16 | evaluation_data_collector: PathCollector, 17 | replay_buffer: ReplayBuffer, 18 | batch_size, 19 | max_path_length, 20 
| num_epochs, 21 | num_eval_steps_per_epoch, 22 | num_expl_steps_per_train_loop, 23 | num_trains_per_train_loop, 24 | num_train_loops_per_epoch=1, 25 | min_num_steps_before_training=0, 26 | ): 27 | super().__init__( 28 | trainer, 29 | exploration_env, 30 | evaluation_env, 31 | exploration_data_collector, 32 | evaluation_data_collector, 33 | replay_buffer, 34 | ) 35 | self.batch_size = batch_size 36 | self.max_path_length = max_path_length 37 | self.num_epochs = num_epochs 38 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch 39 | self.num_trains_per_train_loop = num_trains_per_train_loop 40 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 41 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 42 | self.min_num_steps_before_training = min_num_steps_before_training 43 | 44 | def _train(self): 45 | if self.min_num_steps_before_training > 0: 46 | init_expl_paths = self.expl_data_collector.collect_new_paths( 47 | self.max_path_length, 48 | self.min_num_steps_before_training, 49 | discard_incomplete_paths=False, 50 | ) 51 | self.replay_buffer.add_paths(init_expl_paths) 52 | self.expl_data_collector.end_epoch(-1) 53 | 54 | for epoch in gt.timed_for( 55 | range(self._start_epoch, self.num_epochs), 56 | save_itrs=True, 57 | ): 58 | self.eval_data_collector.collect_new_paths( 59 | self.max_path_length, 60 | self.num_eval_steps_per_epoch, 61 | discard_incomplete_paths=True, 62 | ) 63 | gt.stamp('evaluation sampling') 64 | 65 | for _ in range(self.num_train_loops_per_epoch): 66 | new_expl_paths = self.expl_data_collector.collect_new_paths( 67 | self.max_path_length, 68 | self.num_expl_steps_per_train_loop, 69 | discard_incomplete_paths=False, 70 | ) 71 | gt.stamp('exploration sampling', unique=False) 72 | 73 | self.replay_buffer.add_paths(new_expl_paths) 74 | gt.stamp('data storing', unique=False) 75 | 76 | self.training_mode(True) 77 | for _ in range(self.num_trains_per_train_loop): 78 | train_data = self.replay_buffer.random_batch( 79 | self.batch_size) 80 | self.trainer.train(train_data) 81 | gt.stamp('training', unique=False) 82 | self.training_mode(False) 83 | 84 | self._end_epoch(epoch) 85 | -------------------------------------------------------------------------------- /rlkit/torch/dqn/dqn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | 8 | import rlkit.torch.pytorch_util as ptu 9 | from rlkit.core.eval_util import create_stats_ordered_dict 10 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 11 | 12 | 13 | class DQNTrainer(TorchTrainer): 14 | def __init__( 15 | self, 16 | qf, 17 | target_qf, 18 | learning_rate=1e-3, 19 | soft_target_tau=1e-3, 20 | target_update_period=1, 21 | qf_criterion=None, 22 | 23 | discount=0.99, 24 | reward_scale=1.0, 25 | ): 26 | super().__init__() 27 | self.qf = qf 28 | self.target_qf = target_qf 29 | self.learning_rate = learning_rate 30 | self.soft_target_tau = soft_target_tau 31 | self.target_update_period = target_update_period 32 | self.qf_optimizer = optim.Adam( 33 | self.qf.parameters(), 34 | lr=self.learning_rate, 35 | ) 36 | self.discount = discount 37 | self.reward_scale = reward_scale 38 | self.qf_criterion = qf_criterion or nn.MSELoss() 39 | self.eval_statistics = OrderedDict() 40 | self._n_train_steps_total = 0 41 | self._need_to_update_eval_statistics = True 42 | 43 | def train_from_torch(self, batch): 44 | rewards = 
batch['rewards'] * self.reward_scale 45 | terminals = batch['terminals'] 46 | obs = batch['observations'] 47 | actions = batch['actions'] 48 | next_obs = batch['next_observations'] 49 | 50 | """ 51 | Compute loss 52 | """ 53 | 54 | target_q_values = self.target_qf(next_obs).detach().max( 55 | 1, keepdim=True 56 | )[0] 57 | y_target = rewards + (1. - terminals) * self.discount * target_q_values 58 | y_target = y_target.detach() 59 | # actions is a one-hot vector 60 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) 61 | qf_loss = self.qf_criterion(y_pred, y_target) 62 | 63 | """ 64 | Soft target network updates 65 | """ 66 | self.qf_optimizer.zero_grad() 67 | qf_loss.backward() 68 | self.qf_optimizer.step() 69 | 70 | """ 71 | Soft Updates 72 | """ 73 | if self._n_train_steps_total % self.target_update_period == 0: 74 | ptu.soft_update_from_to( 75 | self.qf, self.target_qf, self.soft_target_tau 76 | ) 77 | 78 | """ 79 | Save some statistics for eval using just one batch. 80 | """ 81 | if self._need_to_update_eval_statistics: 82 | self._need_to_update_eval_statistics = False 83 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) 84 | self.eval_statistics.update(create_stats_ordered_dict( 85 | 'Y Predictions', 86 | ptu.get_numpy(y_pred), 87 | )) 88 | 89 | def get_diagnostics(self): 90 | return self.eval_statistics 91 | 92 | def end_epoch(self, epoch): 93 | self._need_to_update_eval_statistics = True 94 | 95 | @property 96 | def networks(self): 97 | return [ 98 | self.qf, 99 | self.target_qf, 100 | ] 101 | 102 | def get_snapshot(self): 103 | return dict( 104 | qf=self.qf, 105 | target_qf=self.target_qf, 106 | ) 107 | -------------------------------------------------------------------------------- /examples/ddpg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of running PyTorch implementation of DDPG on HalfCheetah. 
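The script builds a FlattenMlp Q-function and a TanhMlpPolicy, wraps the policy
with OU-noise exploration for data collection, and trains both with DDPGTrainer
inside a TorchBatchRLAlgorithm loop.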
3 | """ 4 | import copy 5 | 6 | from gym.envs.mujoco import HalfCheetahEnv 7 | 8 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 9 | from rlkit.envs.wrappers import NormalizedBoxEnv 10 | from rlkit.exploration_strategies.base import ( 11 | PolicyWrappedWithExplorationStrategy 12 | ) 13 | from rlkit.exploration_strategies.ou_strategy import OUStrategy 14 | from rlkit.launchers.launcher_util import setup_logger 15 | from rlkit.samplers.data_collector import MdpPathCollector 16 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 17 | from rlkit.torch.ddpg.ddpg import DDPGTrainer 18 | import rlkit.torch.pytorch_util as ptu 19 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 20 | 21 | 22 | def experiment(variant): 23 | eval_env = NormalizedBoxEnv(HalfCheetahEnv()) 24 | expl_env = NormalizedBoxEnv(HalfCheetahEnv()) 25 | # Or for a specific version: 26 | # import gym 27 | # env = NormalizedBoxEnv(gym.make('HalfCheetah-v1')) 28 | obs_dim = eval_env.observation_space.low.size 29 | action_dim = eval_env.action_space.low.size 30 | qf = FlattenMlp( 31 | input_size=obs_dim + action_dim, 32 | output_size=1, 33 | **variant['qf_kwargs'] 34 | ) 35 | policy = TanhMlpPolicy( 36 | input_size=obs_dim, 37 | output_size=action_dim, 38 | **variant['policy_kwargs'] 39 | ) 40 | target_qf = copy.deepcopy(qf) 41 | target_policy = copy.deepcopy(policy) 42 | eval_path_collector = MdpPathCollector(eval_env, policy) 43 | exploration_policy = PolicyWrappedWithExplorationStrategy( 44 | exploration_strategy=OUStrategy(action_space=expl_env.action_space), 45 | policy=policy, 46 | ) 47 | expl_path_collector = MdpPathCollector(expl_env, exploration_policy) 48 | replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env) 49 | trainer = DDPGTrainer( 50 | qf=qf, 51 | target_qf=target_qf, 52 | policy=policy, 53 | target_policy=target_policy, 54 | **variant['trainer_kwargs'] 55 | ) 56 | algorithm = TorchBatchRLAlgorithm( 57 | trainer=trainer, 58 | exploration_env=expl_env, 59 | evaluation_env=eval_env, 60 | exploration_data_collector=expl_path_collector, 61 | evaluation_data_collector=eval_path_collector, 62 | replay_buffer=replay_buffer, 63 | **variant['algorithm_kwargs'] 64 | ) 65 | algorithm.to(ptu.device) 66 | algorithm.train() 67 | 68 | 69 | if __name__ == "__main__": 70 | # noinspection PyTypeChecker 71 | variant = dict( 72 | algorithm_kwargs=dict( 73 | num_epochs=1000, 74 | num_eval_steps_per_epoch=1000, 75 | num_trains_per_train_loop=1000, 76 | num_expl_steps_per_train_loop=1000, 77 | min_num_steps_before_training=10000, 78 | max_path_length=1000, 79 | batch_size=128, 80 | ), 81 | trainer_kwargs=dict( 82 | use_soft_update=True, 83 | tau=1e-2, 84 | discount=0.99, 85 | qf_learning_rate=1e-3, 86 | policy_learning_rate=1e-4, 87 | ), 88 | qf_kwargs=dict( 89 | hidden_sizes=[400, 300], 90 | ), 91 | policy_kwargs=dict( 92 | hidden_sizes=[400, 300], 93 | ), 94 | replay_buffer_size=int(1E6), 95 | ) 96 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 97 | setup_logger('name-of-experiment', variant=variant) 98 | experiment(variant) 99 | -------------------------------------------------------------------------------- /rlkit/torch/pytorch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def soft_update_from_to(source, target, tau): 6 | for target_param, param in zip(target.parameters(), source.parameters()): 7 | target_param.data.copy_( 8 | target_param.data * (1.0 - 
tau) + param.data * tau 9 | ) 10 | 11 | 12 | def copy_model_params_from_to(source, target): 13 | for target_param, param in zip(target.parameters(), source.parameters()): 14 | target_param.data.copy_(param.data) 15 | 16 | 17 | def fanin_init(tensor): 18 | size = tensor.size() 19 | if len(size) == 2: 20 | fan_in = size[0] 21 | elif len(size) > 2: 22 | fan_in = np.prod(size[1:]) 23 | else: 24 | raise Exception("Shape must be have dimension at least 2.") 25 | bound = 1. / np.sqrt(fan_in) 26 | return tensor.data.uniform_(-bound, bound) 27 | 28 | 29 | def fanin_init_weights_like(tensor): 30 | size = tensor.size() 31 | if len(size) == 2: 32 | fan_in = size[0] 33 | elif len(size) > 2: 34 | fan_in = np.prod(size[1:]) 35 | else: 36 | raise Exception("Shape must be have dimension at least 2.") 37 | bound = 1. / np.sqrt(fan_in) 38 | new_tensor = FloatTensor(tensor.size()) 39 | new_tensor.uniform_(-bound, bound) 40 | return new_tensor 41 | 42 | 43 | """ 44 | GPU wrappers 45 | """ 46 | 47 | _use_gpu = False 48 | device = None 49 | _gpu_id = 0 50 | 51 | 52 | def set_gpu_mode(mode, gpu_id=0): 53 | global _use_gpu 54 | global device 55 | global _gpu_id 56 | _gpu_id = gpu_id 57 | _use_gpu = mode 58 | device = torch.device("cuda:" + str(gpu_id) if _use_gpu else "cpu") 59 | 60 | 61 | def gpu_enabled(): 62 | return _use_gpu 63 | 64 | 65 | def set_device(gpu_id): 66 | torch.cuda.set_device(gpu_id) 67 | 68 | 69 | # noinspection PyPep8Naming 70 | def FloatTensor(*args, torch_device=None, **kwargs): 71 | if torch_device is None: 72 | torch_device = device 73 | return torch.FloatTensor(*args, **kwargs, device=torch_device) 74 | 75 | 76 | def from_numpy(*args, **kwargs): 77 | return torch.from_numpy(*args, **kwargs).float().to(device) 78 | 79 | 80 | def get_numpy(tensor): 81 | return tensor.to('cpu').detach().numpy() 82 | 83 | 84 | def zeros(*sizes, torch_device=None, **kwargs): 85 | if torch_device is None: 86 | torch_device = device 87 | return torch.zeros(*sizes, **kwargs, device=torch_device) 88 | 89 | 90 | def ones(*sizes, torch_device=None, **kwargs): 91 | if torch_device is None: 92 | torch_device = device 93 | return torch.ones(*sizes, **kwargs, device=torch_device) 94 | 95 | 96 | def ones_like(*args, torch_device=None, **kwargs): 97 | if torch_device is None: 98 | torch_device = device 99 | return torch.ones_like(*args, **kwargs, device=torch_device) 100 | 101 | 102 | def randn(*args, torch_device=None, **kwargs): 103 | if torch_device is None: 104 | torch_device = device 105 | return torch.randn(*args, **kwargs, device=torch_device) 106 | 107 | 108 | def zeros_like(*args, torch_device=None, **kwargs): 109 | if torch_device is None: 110 | torch_device = device 111 | return torch.zeros_like(*args, **kwargs, device=torch_device) 112 | 113 | 114 | def tensor(*args, torch_device=None, **kwargs): 115 | if torch_device is None: 116 | torch_device = device 117 | return torch.tensor(*args, **kwargs, device=torch_device) 118 | 119 | 120 | def normal(*args, **kwargs): 121 | return torch.normal(*args, **kwargs).to(device) 122 | -------------------------------------------------------------------------------- /rlkit/data_management/simple_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | 5 | from rlkit.data_management.replay_buffer import ReplayBuffer 6 | 7 | 8 | class SimpleReplayBuffer(ReplayBuffer): 9 | 10 | def __init__( 11 | self, 12 | max_replay_buffer_size, 13 | observation_dim, 14 | 
action_dim, 15 | env_info_sizes, 16 | ): 17 | self._observation_dim = observation_dim 18 | self._action_dim = action_dim 19 | self._max_replay_buffer_size = max_replay_buffer_size 20 | self._observations = np.zeros((max_replay_buffer_size, observation_dim)) 21 | # It's a bit memory inefficient to save the observations twice, 22 | # but it makes the code *much* easier since you no longer have to 23 | # worry about termination conditions. 24 | self._next_obs = np.zeros((max_replay_buffer_size, observation_dim)) 25 | self._actions = np.zeros((max_replay_buffer_size, action_dim)) 26 | # Make everything a 2D np array to make it easier for other code to 27 | # reason about the shape of the data 28 | self._rewards = np.zeros((max_replay_buffer_size, 1)) 29 | # self._terminals[i] = a terminal was received at time i 30 | self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8') 31 | # Define self._env_infos[key][i] to be the return value of env_info[key] 32 | # at time i 33 | self._env_infos = {} 34 | for key, size in env_info_sizes.items(): 35 | self._env_infos[key] = np.zeros((max_replay_buffer_size, size)) 36 | self._env_info_keys = env_info_sizes.keys() 37 | 38 | self._top = 0 39 | self._size = 0 40 | 41 | def add_sample(self, observation, action, reward, next_observation, 42 | terminal, env_info, **kwargs): 43 | self._observations[self._top] = observation 44 | self._actions[self._top] = action 45 | self._rewards[self._top] = reward 46 | self._terminals[self._top] = terminal 47 | self._next_obs[self._top] = next_observation 48 | 49 | for key in self._env_info_keys: 50 | self._env_infos[key][self._top] = env_info[key] 51 | self._advance() 52 | 53 | def terminate_episode(self): 54 | pass 55 | 56 | def _advance(self): 57 | self._top = (self._top + 1) % self._max_replay_buffer_size 58 | if self._size < self._max_replay_buffer_size: 59 | self._size += 1 60 | 61 | def random_batch(self, batch_size): 62 | indices = np.random.randint(0, self._size, batch_size) 63 | batch = dict( 64 | observations=self._observations[indices], 65 | actions=self._actions[indices], 66 | rewards=self._rewards[indices], 67 | terminals=self._terminals[indices], 68 | next_observations=self._next_obs[indices], 69 | ) 70 | for key in self._env_info_keys: 71 | assert key not in batch.keys() 72 | batch[key] = self._env_infos[key][indices] 73 | return batch 74 | 75 | def rebuild_env_info_dict(self, idx): 76 | return { 77 | key: self._env_infos[key][idx] 78 | for key in self._env_info_keys 79 | } 80 | 81 | def batch_env_info_dict(self, indices): 82 | return { 83 | key: self._env_infos[key][indices] 84 | for key in self._env_info_keys 85 | } 86 | 87 | def num_steps_can_sample(self): 88 | return self._size 89 | 90 | def get_diagnostics(self): 91 | return OrderedDict([ 92 | ('size', self._size) 93 | ]) 94 | -------------------------------------------------------------------------------- /examples/sac.py: -------------------------------------------------------------------------------- 1 | from gym.envs.mujoco import HalfCheetahEnv 2 | 3 | import rlkit.torch.pytorch_util as ptu 4 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 5 | from rlkit.envs.wrappers import NormalizedBoxEnv 6 | from rlkit.launchers.launcher_util import setup_logger 7 | from rlkit.samplers.data_collector import MdpPathCollector 8 | from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic 9 | from rlkit.torch.sac.sac import SACTrainer 10 | from rlkit.torch.networks import FlattenMlp 11 | from 
rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 12 | 13 | 14 | def experiment(variant): 15 | expl_env = NormalizedBoxEnv(HalfCheetahEnv()) 16 | eval_env = NormalizedBoxEnv(HalfCheetahEnv()) 17 | obs_dim = expl_env.observation_space.low.size 18 | action_dim = eval_env.action_space.low.size 19 | 20 | M = variant['layer_size'] 21 | qf1 = FlattenMlp( 22 | input_size=obs_dim + action_dim, 23 | output_size=1, 24 | hidden_sizes=[M, M], 25 | ) 26 | qf2 = FlattenMlp( 27 | input_size=obs_dim + action_dim, 28 | output_size=1, 29 | hidden_sizes=[M, M], 30 | ) 31 | target_qf1 = FlattenMlp( 32 | input_size=obs_dim + action_dim, 33 | output_size=1, 34 | hidden_sizes=[M, M], 35 | ) 36 | target_qf2 = FlattenMlp( 37 | input_size=obs_dim + action_dim, 38 | output_size=1, 39 | hidden_sizes=[M, M], 40 | ) 41 | policy = TanhGaussianPolicy( 42 | obs_dim=obs_dim, 43 | action_dim=action_dim, 44 | hidden_sizes=[M, M], 45 | ) 46 | eval_policy = MakeDeterministic(policy) 47 | eval_path_collector = MdpPathCollector( 48 | eval_env, 49 | eval_policy, 50 | ) 51 | expl_path_collector = MdpPathCollector( 52 | expl_env, 53 | policy, 54 | ) 55 | replay_buffer = EnvReplayBuffer( 56 | variant['replay_buffer_size'], 57 | expl_env, 58 | ) 59 | trainer = SACTrainer( 60 | env=eval_env, 61 | policy=policy, 62 | qf1=qf1, 63 | qf2=qf2, 64 | target_qf1=target_qf1, 65 | target_qf2=target_qf2, 66 | **variant['trainer_kwargs'] 67 | ) 68 | algorithm = TorchBatchRLAlgorithm( 69 | trainer=trainer, 70 | exploration_env=expl_env, 71 | evaluation_env=eval_env, 72 | exploration_data_collector=expl_path_collector, 73 | evaluation_data_collector=eval_path_collector, 74 | replay_buffer=replay_buffer, 75 | **variant['algorithm_kwargs'] 76 | ) 77 | algorithm.to(ptu.device) 78 | algorithm.train() 79 | 80 | 81 | 82 | 83 | if __name__ == "__main__": 84 | # noinspection PyTypeChecker 85 | variant = dict( 86 | algorithm="SAC", 87 | version="normal", 88 | layer_size=256, 89 | replay_buffer_size=int(1E6), 90 | algorithm_kwargs=dict( 91 | num_epochs=3000, 92 | num_eval_steps_per_epoch=5000, 93 | num_trains_per_train_loop=1000, 94 | num_expl_steps_per_train_loop=1000, 95 | min_num_steps_before_training=1000, 96 | max_path_length=1000, 97 | batch_size=256, 98 | ), 99 | trainer_kwargs=dict( 100 | discount=0.99, 101 | soft_target_tau=5e-3, 102 | target_update_period=1, 103 | policy_lr=3E-4, 104 | qf_lr=3E-4, 105 | reward_scale=1, 106 | use_automatic_entropy_tuning=True, 107 | ), 108 | ) 109 | setup_logger('name-of-experiment', variant=variant) 110 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 111 | experiment(variant) 112 | -------------------------------------------------------------------------------- /environment/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # We need the CUDA base dockerfile to enable GPU rendering 2 | # on hosts with GPUs. 3 | # The image below is a pinned version of nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04 (from Jan 2018) 4 | # If updating the base image, be sure to test on GPU since it has broken in the past. 
5 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 6 | 7 | 8 | RUN apt-get update -q \ 9 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 10 | curl \ 11 | git \ 12 | libgl1-mesa-dev \ 13 | libgl1-mesa-glx \ 14 | libglew-dev \ 15 | libosmesa6-dev \ 16 | software-properties-common \ 17 | net-tools \ 18 | unzip \ 19 | vim \ 20 | virtualenv \ 21 | wget \ 22 | xpra \ 23 | xserver-xorg-dev \ 24 | && apt-get clean \ 25 | && rm -rf /var/lib/apt/lists/* 26 | 27 | RUN DEBIAN_FRONTEND=noninteractive add-apt-repository --yes ppa:deadsnakes/ppa && apt-get update 28 | RUN DEBIAN_FRONTEND=noninteractive apt-get install --yes python3.5-dev python3.5 python3-pip 29 | RUN virtualenv --python=python3.5 env 30 | 31 | RUN rm /usr/bin/python 32 | RUN ln -s /env/bin/python3.5 /usr/bin/python 33 | RUN ln -s /env/bin/pip3.5 /usr/bin/pip 34 | RUN ln -s /env/bin/pytest /usr/bin/pytest 35 | 36 | RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \ 37 | && chmod +x /usr/local/bin/patchelf 38 | 39 | ENV LANG C.UTF-8 40 | 41 | RUN mkdir -p /root/.mujoco \ 42 | && wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \ 43 | && unzip mujoco.zip -d /root/.mujoco \ 44 | && rm mujoco.zip 45 | COPY ./mjkey.txt /root/.mujoco/ 46 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH} 47 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH} 48 | 49 | COPY vendor/Xdummy /usr/local/bin/Xdummy 50 | RUN chmod +x /usr/local/bin/Xdummy 51 | 52 | # Workaround for https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-375/+bug/1674677 53 | COPY ./vendor/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json 54 | 55 | RUN apt-get update && apt-get install -y libav-tools 56 | 57 | # For some reason this works despite an error showing up... 
58 | RUN DEBIAN_FRONTEND=noninteractive apt-get -qy install nvidia-384; exit 0 59 | ENV LD_LIBRARY_PATH ${LD_LIBRARY_PATH}:/usr/lib/nvidia-384 60 | 61 | RUN mkdir /root/code 62 | WORKDIR /root/code 63 | 64 | WORKDIR /mujoco_py 65 | 66 | # For atari-py 67 | RUN apt-get install -y zlib1g-dev swig cmake 68 | 69 | # Previous versions installed from a requirements.txt, but direct pip 70 | # install seems cleaner 71 | RUN pip install glfw>=1.4.0 72 | RUN pip install numpy>=1.11 73 | RUN pip install Cython>=0.27.2 74 | RUN pip install imageio>=2.1.2 75 | RUN pip install cffi>=1.10 76 | RUN pip install imagehash>=3.4 77 | RUN pip install ipdb 78 | RUN pip install Pillow>=4.0.0 79 | RUN pip install pycparser>=2.17.0 80 | RUN pip install pytest>=3.0.5 81 | RUN pip install pytest-instafail==0.3.0 82 | RUN pip install scipy>=0.18.0 83 | RUN pip install sphinx 84 | RUN pip install sphinx_rtd_theme 85 | RUN pip install numpydoc 86 | RUN pip install cloudpickle==0.5.2 87 | RUN pip install cached-property==1.3.1 88 | RUN pip install gym[all]==0.10.5 89 | RUN pip install gitpython==2.1.7 90 | RUN pip install gtimer==1.0.0b5 91 | RUN pip install awscli==1.11.179 92 | RUN pip install boto3==1.4.8 93 | RUN pip install ray==0.2.2 94 | RUN pip install path.py==10.3.1 95 | RUN pip install http://download.pytorch.org/whl/cu90/torch-0.4.1-cp35-cp35m-linux_x86_64.whl 96 | RUN pip install joblib==0.9.4 97 | RUN pip install opencv-python==3.4.0.12 98 | RUN pip install torchvision==0.2.0 99 | RUN pip install sk-video==1.1.10 100 | RUN pip install git+https://github.com/vitchyr/multiworld.git 101 | -------------------------------------------------------------------------------- /rlkit/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on code from Marcin Andrychowicz 3 | """ 4 | import numpy as np 5 | 6 | 7 | class Normalizer(object): 8 | def __init__( 9 | self, 10 | size, 11 | eps=1e-8, 12 | default_clip_range=np.inf, 13 | mean=0, 14 | std=1, 15 | ): 16 | self.size = size 17 | self.eps = eps 18 | self.default_clip_range = default_clip_range 19 | self.sum = np.zeros(self.size, np.float32) 20 | self.sumsq = np.zeros(self.size, np.float32) 21 | self.count = np.ones(1, np.float32) 22 | self.mean = mean + np.zeros(self.size, np.float32) 23 | self.std = std * np.ones(self.size, np.float32) 24 | self.synchronized = True 25 | 26 | def update(self, v): 27 | if v.ndim == 1: 28 | v = np.expand_dims(v, 0) 29 | assert v.ndim == 2 30 | assert v.shape[1] == self.size 31 | self.sum += v.sum(axis=0) 32 | self.sumsq += (np.square(v)).sum(axis=0) 33 | self.count[0] += v.shape[0] 34 | self.synchronized = False 35 | 36 | def normalize(self, v, clip_range=None): 37 | if not self.synchronized: 38 | self.synchronize() 39 | if clip_range is None: 40 | clip_range = self.default_clip_range 41 | mean, std = self.mean, self.std 42 | if v.ndim == 2: 43 | mean = mean.reshape(1, -1) 44 | std = std.reshape(1, -1) 45 | return np.clip((v - mean) / std, -clip_range, clip_range) 46 | 47 | def denormalize(self, v): 48 | if not self.synchronized: 49 | self.synchronize() 50 | mean, std = self.mean, self.std 51 | if v.ndim == 2: 52 | mean = mean.reshape(1, -1) 53 | std = std.reshape(1, -1) 54 | return mean + v * std 55 | 56 | def synchronize(self): 57 | self.mean[...] = self.sum / self.count[0] 58 | self.std[...] 
= np.sqrt( 59 | np.maximum( 60 | np.square(self.eps), 61 | self.sumsq / self.count[0] - np.square(self.mean) 62 | ) 63 | ) 64 | self.synchronized = True 65 | 66 | 67 | class IdentityNormalizer(object): 68 | def __init__(self, *args, **kwargs): 69 | pass 70 | 71 | def update(self, v): 72 | pass 73 | 74 | def normalize(self, v, clip_range=None): 75 | return v 76 | 77 | def denormalize(self, v): 78 | return v 79 | 80 | 81 | class FixedNormalizer(object): 82 | def __init__( 83 | self, 84 | size, 85 | default_clip_range=np.inf, 86 | mean=0, 87 | std=1, 88 | eps=1e-8, 89 | ): 90 | assert std > 0 91 | std = std + eps 92 | self.size = size 93 | self.default_clip_range = default_clip_range 94 | self.mean = mean + np.zeros(self.size, np.float32) 95 | self.std = std + np.zeros(self.size, np.float32) 96 | self.eps = eps 97 | 98 | def set_mean(self, mean): 99 | self.mean = mean + np.zeros(self.size, np.float32) 100 | 101 | def set_std(self, std): 102 | std = std + self.eps 103 | self.std = std + np.zeros(self.size, np.float32) 104 | 105 | def normalize(self, v, clip_range=None): 106 | if clip_range is None: 107 | clip_range = self.default_clip_range 108 | mean, std = self.mean, self.std 109 | if v.ndim == 2: 110 | mean = mean.reshape(1, -1) 111 | std = std.reshape(1, -1) 112 | return np.clip((v - mean) / std, -clip_range, clip_range) 113 | 114 | def denormalize(self, v): 115 | mean, std = self.mean, self.std 116 | if v.ndim == 2: 117 | mean = mean.reshape(1, -1) 118 | std = std.reshape(1, -1) 119 | return mean + v * std 120 | 121 | def copy_stats(self, other): 122 | self.set_mean(other.mean) 123 | self.set_std(other.std) 124 | -------------------------------------------------------------------------------- /rlkit/torch/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 
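This module currently provides Mlp, FlattenMlp, MlpPolicy, and TanhMlpPolicy.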
5 | """ 6 | import torch 7 | from torch import nn as nn 8 | from torch.nn import functional as F 9 | 10 | from rlkit.policies.base import Policy 11 | from rlkit.torch import pytorch_util as ptu 12 | from rlkit.torch.core import eval_np 13 | from rlkit.torch.data_management.normalizer import TorchFixedNormalizer 14 | from rlkit.torch.modules import LayerNorm 15 | 16 | 17 | def identity(x): 18 | return x 19 | 20 | 21 | class Mlp(nn.Module): 22 | def __init__( 23 | self, 24 | hidden_sizes, 25 | output_size, 26 | input_size, 27 | init_w=3e-3, 28 | hidden_activation=F.relu, 29 | output_activation=identity, 30 | hidden_init=ptu.fanin_init, 31 | b_init_value=0.1, 32 | layer_norm=False, 33 | layer_norm_kwargs=None, 34 | ): 35 | super().__init__() 36 | 37 | if layer_norm_kwargs is None: 38 | layer_norm_kwargs = dict() 39 | 40 | self.input_size = input_size 41 | self.output_size = output_size 42 | self.hidden_activation = hidden_activation 43 | self.output_activation = output_activation 44 | self.layer_norm = layer_norm 45 | self.fcs = [] 46 | self.layer_norms = [] 47 | in_size = input_size 48 | 49 | for i, next_size in enumerate(hidden_sizes): 50 | fc = nn.Linear(in_size, next_size) 51 | in_size = next_size 52 | hidden_init(fc.weight) 53 | fc.bias.data.fill_(b_init_value) 54 | self.__setattr__("fc{}".format(i), fc) 55 | self.fcs.append(fc) 56 | 57 | if self.layer_norm: 58 | ln = LayerNorm(next_size) 59 | self.__setattr__("layer_norm{}".format(i), ln) 60 | self.layer_norms.append(ln) 61 | 62 | self.last_fc = nn.Linear(in_size, output_size) 63 | self.last_fc.weight.data.uniform_(-init_w, init_w) 64 | self.last_fc.bias.data.uniform_(-init_w, init_w) 65 | 66 | def forward(self, input, return_preactivations=False): 67 | h = input 68 | for i, fc in enumerate(self.fcs): 69 | h = fc(h) 70 | if self.layer_norm and i < len(self.fcs) - 1: 71 | h = self.layer_norms[i](h) 72 | h = self.hidden_activation(h) 73 | preactivation = self.last_fc(h) 74 | output = self.output_activation(preactivation) 75 | if return_preactivations: 76 | return output, preactivation 77 | else: 78 | return output 79 | 80 | 81 | class FlattenMlp(Mlp): 82 | """ 83 | Flatten inputs along dimension 1 and then pass through MLP. 84 | """ 85 | 86 | def forward(self, *inputs, **kwargs): 87 | flat_inputs = torch.cat(inputs, dim=1) 88 | return super().forward(flat_inputs, **kwargs) 89 | 90 | 91 | class MlpPolicy(Mlp, Policy): 92 | """ 93 | A simpler interface for creating policies. 94 | """ 95 | 96 | def __init__( 97 | self, 98 | *args, 99 | obs_normalizer: TorchFixedNormalizer = None, 100 | **kwargs 101 | ): 102 | super().__init__(*args, **kwargs) 103 | self.obs_normalizer = obs_normalizer 104 | 105 | def forward(self, obs, **kwargs): 106 | if self.obs_normalizer: 107 | obs = self.obs_normalizer.normalize(obs) 108 | return super().forward(obs, **kwargs) 109 | 110 | def get_action(self, obs_np): 111 | actions = self.get_actions(obs_np[None]) 112 | return actions[0, :], {} 113 | 114 | def get_actions(self, obs): 115 | return eval_np(self, obs) 116 | 117 | 118 | class TanhMlpPolicy(MlpPolicy): 119 | """ 120 | A helper class since most policies have a tanh output activation. 
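    Illustrative construction (sizes are placeholders):

        policy = TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=[400, 300],
        )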
121 | """ 122 | def __init__(self, *args, **kwargs): 123 | super().__init__(*args, output_activation=torch.tanh, **kwargs) 124 | -------------------------------------------------------------------------------- /rlkit/core/online_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from rlkit.core.rl_algorithm import BaseRLAlgorithm 5 | from rlkit.data_management.replay_buffer import ReplayBuffer 6 | from rlkit.samplers.data_collector import ( 7 | PathCollector, 8 | StepCollector, 9 | ) 10 | 11 | 12 | class OnlineRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 13 | def __init__( 14 | self, 15 | trainer, 16 | exploration_env, 17 | evaluation_env, 18 | exploration_data_collector: StepCollector, 19 | evaluation_data_collector: PathCollector, 20 | replay_buffer: ReplayBuffer, 21 | batch_size, 22 | max_path_length, 23 | num_epochs, 24 | num_eval_steps_per_epoch, 25 | num_expl_steps_per_train_loop, 26 | num_trains_per_train_loop, 27 | num_train_loops_per_epoch=1, 28 | min_num_steps_before_training=0, 29 | ): 30 | super().__init__( 31 | trainer, 32 | exploration_env, 33 | evaluation_env, 34 | exploration_data_collector, 35 | evaluation_data_collector, 36 | replay_buffer, 37 | ) 38 | self.batch_size = batch_size 39 | self.max_path_length = max_path_length 40 | self.num_epochs = num_epochs 41 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch 42 | self.num_trains_per_train_loop = num_trains_per_train_loop 43 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 44 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 45 | self.min_num_steps_before_training = min_num_steps_before_training 46 | 47 | assert self.num_trains_per_train_loop >= self.num_expl_steps_per_train_loop, \ 48 | 'Online training presumes num_trains_per_train_loop >= num_expl_steps_per_train_loop' 49 | 50 | def _train(self): 51 | self.training_mode(False) 52 | if self.min_num_steps_before_training > 0: 53 | self.expl_data_collector.collect_new_steps( 54 | self.max_path_length, 55 | self.min_num_steps_before_training, 56 | discard_incomplete_paths=False, 57 | ) 58 | init_expl_paths = self.expl_data_collector.get_epoch_paths() 59 | self.replay_buffer.add_paths(init_expl_paths) 60 | self.expl_data_collector.end_epoch(-1) 61 | 62 | gt.stamp('initial exploration', unique=True) 63 | 64 | num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop 65 | for epoch in gt.timed_for( 66 | range(self._start_epoch, self.num_epochs), 67 | save_itrs=True, 68 | ): 69 | self.eval_data_collector.collect_new_paths( 70 | self.max_path_length, 71 | self.num_eval_steps_per_epoch, 72 | discard_incomplete_paths=True, 73 | ) 74 | gt.stamp('evaluation sampling') 75 | 76 | for _ in range(self.num_train_loops_per_epoch): 77 | for _ in range(self.num_expl_steps_per_train_loop): 78 | self.expl_data_collector.collect_new_steps( 79 | self.max_path_length, 80 | 1, # num steps 81 | discard_incomplete_paths=False, 82 | ) 83 | gt.stamp('exploration sampling', unique=False) 84 | 85 | self.training_mode(True) 86 | for _ in range(num_trains_per_expl_step): 87 | train_data = self.replay_buffer.random_batch( 88 | self.batch_size) 89 | self.trainer.train(train_data) 90 | gt.stamp('training', unique=False) 91 | self.training_mode(False) 92 | 93 | new_expl_paths = self.expl_data_collector.get_epoch_paths() 94 | self.replay_buffer.add_paths(new_expl_paths) 95 | gt.stamp('data storing', unique=False) 96 | 97 | 
self._end_epoch(epoch) 98 | -------------------------------------------------------------------------------- /rlkit/samplers/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rollout(env, agent, max_path_length=np.inf, render=False): 5 | """ 6 | The following value for the following keys will be a 2D array, with the 7 | first dimension corresponding to the time dimension. 8 | - observations 9 | - actions 10 | - rewards 11 | - next_observations 12 | - terminals 13 | 14 | The next two elements will be lists of dictionaries, with the index into 15 | the list being the index into the time 16 | - agent_infos 17 | - env_infos 18 | 19 | :param env: 20 | :param agent: 21 | :param max_path_length: 22 | :param render: 23 | :return: 24 | """ 25 | observations = [] 26 | actions = [] 27 | rewards = [] 28 | terminals = [] 29 | agent_infos = [] 30 | env_infos = [] 31 | o = env.reset() 32 | next_o = None 33 | path_length = 0 34 | if render: 35 | env.render() 36 | while path_length < max_path_length: 37 | a, agent_info = agent.get_action(o) 38 | next_o, r, d, env_info = env.step(a) 39 | observations.append(o) 40 | rewards.append(r) 41 | terminals.append(d) 42 | actions.append(a) 43 | agent_infos.append(agent_info) 44 | env_infos.append(env_info) 45 | path_length += 1 46 | if d: 47 | break 48 | o = next_o 49 | if render: 50 | env.render() 51 | 52 | actions = np.array(actions) 53 | if len(actions.shape) == 1: 54 | actions = np.expand_dims(actions, 1) 55 | observations = np.array(observations) 56 | if len(observations.shape) == 1: 57 | observations = np.expand_dims(observations, 1) 58 | next_o = np.array([next_o]) 59 | next_observations = np.vstack( 60 | ( 61 | observations[1:, :], 62 | np.expand_dims(next_o, 0) 63 | ) 64 | ) 65 | return dict( 66 | observations=observations, 67 | actions=actions, 68 | rewards=np.array(rewards).reshape(-1, 1), 69 | next_observations=next_observations, 70 | terminals=np.array(terminals).reshape(-1, 1), 71 | agent_infos=agent_infos, 72 | env_infos=env_infos, 73 | ) 74 | 75 | 76 | def split_paths(paths): 77 | """ 78 | Stack multiples obs/actions/etc. from different paths 79 | :param paths: List of paths, where one path is something returned from 80 | the rollout functino above. 81 | :return: Tuple. Every element will have shape batch_size X DIM, including 82 | the rewards and terminal flags. 
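        The tuple unpacks in the same order used by split_paths_to_dict below:

            rewards, terminals, obs, actions, next_obs = split_paths(paths)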
83 | """ 84 | rewards = [path["rewards"].reshape(-1, 1) for path in paths] 85 | terminals = [path["terminals"].reshape(-1, 1) for path in paths] 86 | actions = [path["actions"] for path in paths] 87 | obs = [path["observations"] for path in paths] 88 | next_obs = [path["next_observations"] for path in paths] 89 | rewards = np.vstack(rewards) 90 | terminals = np.vstack(terminals) 91 | obs = np.vstack(obs) 92 | actions = np.vstack(actions) 93 | next_obs = np.vstack(next_obs) 94 | assert len(rewards.shape) == 2 95 | assert len(terminals.shape) == 2 96 | assert len(obs.shape) == 2 97 | assert len(actions.shape) == 2 98 | assert len(next_obs.shape) == 2 99 | return rewards, terminals, obs, actions, next_obs 100 | 101 | 102 | def split_paths_to_dict(paths): 103 | rewards, terminals, obs, actions, next_obs = split_paths(paths) 104 | return dict( 105 | rewards=rewards, 106 | terminals=terminals, 107 | observations=obs, 108 | actions=actions, 109 | next_observations=next_obs, 110 | ) 111 | 112 | 113 | def get_stat_in_paths(paths, dict_name, scalar_name): 114 | if len(paths) == 0: 115 | return np.array([[]]) 116 | 117 | if type(paths[0][dict_name]) == dict: 118 | # Support rllab interface 119 | return [path[dict_name][scalar_name] for path in paths] 120 | 121 | return [ 122 | [info[scalar_name] for info in path[dict_name]] 123 | for path in paths 124 | ] -------------------------------------------------------------------------------- /rlkit/core/eval_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common evaluation utilities. 3 | """ 4 | 5 | from collections import OrderedDict 6 | from numbers import Number 7 | 8 | import numpy as np 9 | 10 | import rlkit.pythonplusplus as ppp 11 | 12 | 13 | def get_generic_path_information(paths, stat_prefix=''): 14 | """ 15 | Get an OrderedDict with a bunch of statistic names and values. 
16 | """ 17 | statistics = OrderedDict() 18 | returns = [sum(path["rewards"]) for path in paths] 19 | 20 | rewards = np.vstack([path["rewards"] for path in paths]) 21 | statistics.update(create_stats_ordered_dict('Rewards', rewards, 22 | stat_prefix=stat_prefix)) 23 | statistics.update(create_stats_ordered_dict('Returns', returns, 24 | stat_prefix=stat_prefix)) 25 | actions = [path["actions"] for path in paths] 26 | if len(actions[0].shape) == 1: 27 | actions = np.hstack([path["actions"] for path in paths]) 28 | else: 29 | actions = np.vstack([path["actions"] for path in paths]) 30 | statistics.update(create_stats_ordered_dict( 31 | 'Actions', actions, stat_prefix=stat_prefix 32 | )) 33 | statistics['Num Paths'] = len(paths) 34 | statistics[stat_prefix + 'Average Returns'] = get_average_returns(paths) 35 | 36 | for info_key in ['env_infos', 'agent_infos']: 37 | if info_key in paths[0]: 38 | all_env_infos = [ 39 | ppp.list_of_dicts__to__dict_of_lists(p[info_key]) 40 | for p in paths 41 | ] 42 | for k in all_env_infos[0].keys(): 43 | final_ks = np.array([info[k][-1] for info in all_env_infos]) 44 | first_ks = np.array([info[k][0] for info in all_env_infos]) 45 | all_ks = np.concatenate([info[k] for info in all_env_infos]) 46 | statistics.update(create_stats_ordered_dict( 47 | stat_prefix + k, 48 | final_ks, 49 | stat_prefix='{}/final/'.format(info_key), 50 | )) 51 | statistics.update(create_stats_ordered_dict( 52 | stat_prefix + k, 53 | first_ks, 54 | stat_prefix='{}/initial/'.format(info_key), 55 | )) 56 | statistics.update(create_stats_ordered_dict( 57 | stat_prefix + k, 58 | all_ks, 59 | stat_prefix='{}/'.format(info_key), 60 | )) 61 | 62 | return statistics 63 | 64 | 65 | def get_average_returns(paths): 66 | returns = [sum(path["rewards"]) for path in paths] 67 | return np.mean(returns) 68 | 69 | 70 | def create_stats_ordered_dict( 71 | name, 72 | data, 73 | stat_prefix=None, 74 | always_show_all_stats=True, 75 | exclude_max_min=False, 76 | ): 77 | if stat_prefix is not None: 78 | name = "{}{}".format(stat_prefix, name) 79 | if isinstance(data, Number): 80 | return OrderedDict({name: data}) 81 | 82 | if len(data) == 0: 83 | return OrderedDict() 84 | 85 | if isinstance(data, tuple): 86 | ordered_dict = OrderedDict() 87 | for number, d in enumerate(data): 88 | sub_dict = create_stats_ordered_dict( 89 | "{0}_{1}".format(name, number), 90 | d, 91 | ) 92 | ordered_dict.update(sub_dict) 93 | return ordered_dict 94 | 95 | if isinstance(data, list): 96 | try: 97 | iter(data[0]) 98 | except TypeError: 99 | pass 100 | else: 101 | data = np.concatenate(data) 102 | 103 | if (isinstance(data, np.ndarray) and data.size == 1 104 | and not always_show_all_stats): 105 | return OrderedDict({name: float(data)}) 106 | 107 | stats = OrderedDict([ 108 | (name + ' Mean', np.mean(data)), 109 | (name + ' Std', np.std(data)), 110 | ]) 111 | if not exclude_max_min: 112 | stats[name + ' Max'] = np.max(data) 113 | stats[name + ' Min'] = np.min(data) 114 | return stats 115 | -------------------------------------------------------------------------------- /examples/her/her_dqn_gridworld.py: -------------------------------------------------------------------------------- 1 | """ 2 | This should result in an average return of -20 by the end of training. 3 | 4 | Usually hits -30 around epoch 50. 5 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps.
6 | """ 7 | import gym 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 11 | from rlkit.exploration_strategies.base import \ 12 | PolicyWrappedWithExplorationStrategy 13 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy 14 | from rlkit.launchers.launcher_util import setup_logger 15 | from rlkit.samplers.data_collector import GoalConditionedPathCollector 16 | 17 | from rlkit.policies.argmax import ArgmaxDiscretePolicy 18 | from rlkit.torch.dqn.dqn import DQNTrainer 19 | from rlkit.torch.her.her import HERTrainer 20 | from rlkit.torch.networks import FlattenMlp 21 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 22 | 23 | try: 24 | import multiworld.envs.gridworlds 25 | except ImportError as e: 26 | print("To run this example, you need to install `multiworld`. See " 27 | "https://github.com/vitchyr/multiworld.") 28 | raise e 29 | 30 | 31 | def experiment(variant): 32 | expl_env = gym.make('GoalGridworld-v0') 33 | eval_env = gym.make('GoalGridworld-v0') 34 | 35 | obs_dim = expl_env.observation_space.spaces['observation'].low.size 36 | goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size 37 | action_dim = expl_env.action_space.n 38 | qf = FlattenMlp( 39 | input_size=obs_dim + goal_dim, 40 | output_size=action_dim, 41 | hidden_sizes=[400, 300], 42 | ) 43 | target_qf = FlattenMlp( 44 | input_size=obs_dim + goal_dim, 45 | output_size=action_dim, 46 | hidden_sizes=[400, 300], 47 | ) 48 | eval_policy = ArgmaxDiscretePolicy(qf) 49 | exploration_strategy = EpsilonGreedy( 50 | action_space=expl_env.action_space, 51 | ) 52 | expl_policy = PolicyWrappedWithExplorationStrategy( 53 | exploration_strategy=exploration_strategy, 54 | policy=eval_policy, 55 | ) 56 | 57 | replay_buffer = ObsDictRelabelingBuffer( 58 | env=eval_env, 59 | **variant['replay_buffer_kwargs'] 60 | ) 61 | observation_key = 'observation' 62 | desired_goal_key = 'desired_goal' 63 | eval_path_collector = GoalConditionedPathCollector( 64 | eval_env, 65 | eval_policy, 66 | observation_key=observation_key, 67 | desired_goal_key=desired_goal_key, 68 | ) 69 | expl_path_collector = GoalConditionedPathCollector( 70 | expl_env, 71 | expl_policy, 72 | observation_key=observation_key, 73 | desired_goal_key=desired_goal_key, 74 | ) 75 | trainer = DQNTrainer( 76 | qf=qf, 77 | target_qf=target_qf, 78 | **variant['trainer_kwargs'] 79 | ) 80 | trainer = HERTrainer(trainer) 81 | algorithm = TorchBatchRLAlgorithm( 82 | trainer=trainer, 83 | exploration_env=expl_env, 84 | evaluation_env=eval_env, 85 | exploration_data_collector=expl_path_collector, 86 | evaluation_data_collector=eval_path_collector, 87 | replay_buffer=replay_buffer, 88 | **variant['algo_kwargs'] 89 | ) 90 | algorithm.to(ptu.device) 91 | algorithm.train() 92 | 93 | 94 | if __name__ == "__main__": 95 | variant = dict( 96 | algo_kwargs=dict( 97 | num_epochs=100, 98 | max_path_length=50, 99 | num_eval_steps_per_epoch=1000, 100 | num_expl_steps_per_train_loop=1000, 101 | num_trains_per_train_loop=1000, 102 | min_num_steps_before_training=1000, 103 | batch_size=128, 104 | ), 105 | trainer_kwargs=dict( 106 | discount=0.99, 107 | ), 108 | replay_buffer_kwargs=dict( 109 | max_size=100000, 110 | fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper 111 | fraction_goals_env_goals=0.0, 112 | ), 113 | ) 114 | setup_logger('her-dqn-gridworld-experiment', variant=variant) 115 | experiment(variant) 116 | 
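The replay-buffer comment in the gridworld example above ("fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper") can be made concrete. The sketch below is illustrative arithmetic only, not rlkit code: with k relabeled ("future") goals per original goal, the share of sampled transitions that keep the original rollout goal is 1 / (1 + k), so a fraction of 0.2 corresponds to k = 4.

```python
# Illustrative only -- not part of rlkit.
def rollout_goal_fraction(k):
    """Fraction of samples that keep the original rollout goal when each
    original goal is accompanied by k relabeled (HER "future") goals."""
    return 1.0 / (1.0 + k)

assert abs(rollout_goal_fraction(4) - 0.2) < 1e-9  # k = 4  ->  fraction 0.2
```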
-------------------------------------------------------------------------------- /rlkit/util/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | 5 | import numpy as np 6 | import scipy.misc 7 | import skvideo.io 8 | 9 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 10 | 11 | 12 | def dump_video( 13 | env, 14 | policy, 15 | filename, 16 | rollout_function, 17 | rows=3, 18 | columns=6, 19 | pad_length=0, 20 | pad_color=255, 21 | do_timer=True, 22 | horizon=100, 23 | dirname_to_save_images=None, 24 | subdirname="rollouts", 25 | imsize=84, 26 | num_channels=3, 27 | ): 28 | frames = [] 29 | H = 3 * imsize 30 | W = imsize 31 | N = rows * columns 32 | for i in range(N): 33 | start = time.time() 34 | path = rollout_function( 35 | env, 36 | policy, 37 | max_path_length=horizon, 38 | render=False, 39 | ) 40 | is_vae_env = isinstance(env, VAEWrappedEnv) 41 | l = [] 42 | for d in path['full_observations']: 43 | if is_vae_env: 44 | recon = np.clip(env._reconstruct_img(d['image_observation']), 0, 45 | 1) 46 | else: 47 | recon = d['image_observation'] 48 | l.append( 49 | get_image( 50 | d['image_desired_goal'], 51 | d['image_observation'], 52 | recon, 53 | pad_length=pad_length, 54 | pad_color=pad_color, 55 | imsize=imsize, 56 | ) 57 | ) 58 | frames += l 59 | 60 | if dirname_to_save_images: 61 | rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i)) 62 | os.makedirs(rollout_dir, exist_ok=True) 63 | rollout_frames = frames[-101:] 64 | goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0) 65 | scipy.misc.imsave(rollout_dir + "/goal.png", goal_img) 66 | goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0) 67 | scipy.misc.imsave(rollout_dir + "/z_goal.png", goal_img) 68 | for j in range(0, 101, 1): 69 | img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0) 70 | scipy.misc.imsave(rollout_dir + "/" + str(j) + ".png", img) 71 | if do_timer: 72 | print(i, time.time() - start) 73 | 74 | frames = np.array(frames, dtype=np.uint8) 75 | path_length = frames.size // ( 76 | N * (H + 2 * pad_length) * (W + 2 * pad_length) * num_channels 77 | ) 78 | frames = np.array(frames, dtype=np.uint8).reshape( 79 | (N, path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels) 80 | ) 81 | f1 = [] 82 | for k1 in range(columns): 83 | f2 = [] 84 | for k2 in range(rows): 85 | k = k1 * rows + k2 86 | f2.append(frames[k:k + 1, :, :, :, :].reshape( 87 | (path_length, H + 2 * pad_length, W + 2 * pad_length, 88 | num_channels) 89 | )) 90 | f1.append(np.concatenate(f2, axis=1)) 91 | outputdata = np.concatenate(f1, axis=2) 92 | skvideo.io.vwrite(filename, outputdata) 93 | print("Saved video to ", filename) 94 | 95 | 96 | def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255): 97 | if len(goal.shape) == 1: 98 | goal = goal.reshape(-1, imsize, imsize).transpose() 99 | obs = obs.reshape(-1, imsize, imsize).transpose() 100 | recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose() 101 | img = np.concatenate((goal, obs, recon_obs)) 102 | img = np.uint8(255 * img) 103 | if pad_length > 0: 104 | img = add_border(img, pad_length, pad_color) 105 | return img 106 | 107 | 108 | def add_border(img, pad_length, pad_color, imsize=84): 109 | H = 3 * imsize 110 | W = imsize 111 | img = img.reshape((3 * imsize, imsize, -1)) 112 | img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]), 113 | dtype=np.uint8) * pad_color 114 | img2[pad_length:-pad_length, pad_length:-pad_length, :] 
= img 115 | return img2 116 | -------------------------------------------------------------------------------- /rlkit/torch/skewfit/video_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import uuid 5 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 6 | 7 | filename = str(uuid.uuid4()) 8 | 9 | import skvideo.io 10 | import numpy as np 11 | import time 12 | 13 | import scipy.misc 14 | 15 | def add_border(img, pad_length, pad_color, imsize=84): 16 | H = 3*imsize 17 | W = imsize 18 | img = img.reshape((3*imsize, imsize, -1)) 19 | img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]), dtype=np.uint8) * pad_color 20 | img2[pad_length:-pad_length, pad_length:-pad_length, :] = img 21 | return img2 22 | 23 | 24 | def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255): 25 | if len(goal.shape) == 1: 26 | goal = goal.reshape(-1, imsize, imsize).transpose(2, 1, 0) 27 | obs = obs.reshape(-1, imsize, imsize).transpose(2,1,0) 28 | recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose(2,1,0) 29 | img = np.concatenate((goal, obs, recon_obs)) 30 | img = np.uint8(255 * img) 31 | if pad_length > 0: 32 | img = add_border(img, pad_length, pad_color) 33 | return img 34 | 35 | 36 | def dump_video( 37 | env, 38 | policy, 39 | filename, 40 | rollout_function, 41 | rows=3, 42 | columns=6, 43 | pad_length=0, 44 | pad_color=255, 45 | do_timer=True, 46 | horizon=100, 47 | dirname_to_save_images=None, 48 | subdirname="rollouts", 49 | imsize=84, 50 | ): 51 | # num_channels = env.vae.input_channels 52 | num_channels = 1 if env.grayscale else 3 53 | frames = [] 54 | H = 3*imsize 55 | W=imsize 56 | N = rows * columns 57 | for i in range(N): 58 | start = time.time() 59 | path = rollout_function( 60 | env, 61 | policy, 62 | max_path_length=horizon, 63 | render=False, 64 | ) 65 | is_vae_env = isinstance(env, VAEWrappedEnv) 66 | l = [] 67 | for d in path['full_observations']: 68 | if is_vae_env: 69 | recon = np.clip(env._reconstruct_img(d['image_observation']), 0, 1) 70 | else: 71 | recon = d['image_observation'] 72 | l.append( 73 | get_image( 74 | d['image_desired_goal'], 75 | d['image_observation'], 76 | recon, 77 | pad_length=pad_length, 78 | pad_color=pad_color, 79 | imsize=imsize, 80 | ) 81 | ) 82 | frames += l 83 | 84 | if dirname_to_save_images: 85 | rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i)) 86 | os.makedirs(rollout_dir, exist_ok=True) 87 | rollout_frames = frames[-101:] 88 | goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0) 89 | scipy.misc.imsave(rollout_dir+"/goal.png", goal_img) 90 | goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0) 91 | scipy.misc.imsave(rollout_dir+"/z_goal.png", goal_img) 92 | for j in range(0, 101, 1): 93 | img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0) 94 | scipy.misc.imsave(rollout_dir+"/"+str(j)+".png", img) 95 | if do_timer: 96 | print(i, time.time() - start) 97 | 98 | frames = np.array(frames, dtype=np.uint8) 99 | path_length = frames.size // ( 100 | N * (H + 2*pad_length) * (W + 2*pad_length) * num_channels 101 | ) 102 | frames = np.array(frames, dtype=np.uint8).reshape( 103 | (N, path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels) 104 | ) 105 | f1 = [] 106 | for k1 in range(columns): 107 | f2 = [] 108 | for k2 in range(rows): 109 | k = k1 * rows + k2 110 | f2.append(frames[k:k+1, :, :, :, :].reshape( 111 | (path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels) 112 | )) 113 | 
f1.append(np.concatenate(f2, axis=1)) 114 | outputdata = np.concatenate(f1, axis=2) 115 | skvideo.io.vwrite(filename, outputdata) 116 | print("Saved video to ", filename) 117 | -------------------------------------------------------------------------------- /examples/td3.py: -------------------------------------------------------------------------------- 1 | """ 2 | This should result in an average return of ~3000 by the end of training. 3 | 4 | Usually hits 3000 around epoch 80-100. Within a seed, the performance will be 5 | a bit noisy from one epoch to the next (occasionally dips down to ~2000). 6 | 7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps. 8 | """ 9 | from gym.envs.mujoco import HalfCheetahEnv 10 | 11 | import rlkit.torch.pytorch_util as ptu 12 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer 13 | from rlkit.envs.wrappers import NormalizedBoxEnv 14 | from rlkit.exploration_strategies.base import \ 15 | PolicyWrappedWithExplorationStrategy 16 | from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy 17 | from rlkit.launchers.launcher_util import setup_logger 18 | from rlkit.samplers.data_collector import MdpPathCollector 19 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 20 | from rlkit.torch.td3.td3 import TD3Trainer 21 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 22 | 23 | 24 | def experiment(variant): 25 | expl_env = NormalizedBoxEnv(HalfCheetahEnv()) 26 | eval_env = NormalizedBoxEnv(HalfCheetahEnv()) 27 | obs_dim = expl_env.observation_space.low.size 28 | action_dim = expl_env.action_space.low.size 29 | qf1 = FlattenMlp( 30 | input_size=obs_dim + action_dim, 31 | output_size=1, 32 | **variant['qf_kwargs'] 33 | ) 34 | qf2 = FlattenMlp( 35 | input_size=obs_dim + action_dim, 36 | output_size=1, 37 | **variant['qf_kwargs'] 38 | ) 39 | target_qf1 = FlattenMlp( 40 | input_size=obs_dim + action_dim, 41 | output_size=1, 42 | **variant['qf_kwargs'] 43 | ) 44 | target_qf2 = FlattenMlp( 45 | input_size=obs_dim + action_dim, 46 | output_size=1, 47 | **variant['qf_kwargs'] 48 | ) 49 | policy = TanhMlpPolicy( 50 | input_size=obs_dim, 51 | output_size=action_dim, 52 | **variant['policy_kwargs'] 53 | ) 54 | target_policy = TanhMlpPolicy( 55 | input_size=obs_dim, 56 | output_size=action_dim, 57 | **variant['policy_kwargs'] 58 | ) 59 | es = GaussianStrategy( 60 | action_space=expl_env.action_space, 61 | max_sigma=0.1, 62 | min_sigma=0.1, # Constant sigma 63 | ) 64 | exploration_policy = PolicyWrappedWithExplorationStrategy( 65 | exploration_strategy=es, 66 | policy=policy, 67 | ) 68 | eval_path_collector = MdpPathCollector( 69 | eval_env, 70 | policy, 71 | ) 72 | expl_path_collector = MdpPathCollector( 73 | expl_env, 74 | exploration_policy, 75 | ) 76 | replay_buffer = EnvReplayBuffer( 77 | variant['replay_buffer_size'], 78 | expl_env, 79 | ) 80 | trainer = TD3Trainer( 81 | policy=policy, 82 | qf1=qf1, 83 | qf2=qf2, 84 | target_qf1=target_qf1, 85 | target_qf2=target_qf2, 86 | target_policy=target_policy, 87 | **variant['trainer_kwargs'] 88 | ) 89 | algorithm = TorchBatchRLAlgorithm( 90 | trainer=trainer, 91 | exploration_env=expl_env, 92 | evaluation_env=eval_env, 93 | exploration_data_collector=expl_path_collector, 94 | evaluation_data_collector=eval_path_collector, 95 | replay_buffer=replay_buffer, 96 | **variant['algorithm_kwargs'] 97 | ) 98 | algorithm.to(ptu.device) 99 | algorithm.train() 100 | 101 | 102 | if __name__ == "__main__": 103 | variant = dict( 104 | algorithm_kwargs=dict( 105 |
num_epochs=3000, 106 | num_eval_steps_per_epoch=5000, 107 | num_trains_per_train_loop=1000, 108 | num_expl_steps_per_train_loop=1000, 109 | min_num_steps_before_training=1000, 110 | max_path_length=1000, 111 | batch_size=256, 112 | ), 113 | trainer_kwargs=dict( 114 | discount=0.99, 115 | ), 116 | qf_kwargs=dict( 117 | hidden_sizes=[400, 300], 118 | ), 119 | policy_kwargs=dict( 120 | hidden_sizes=[400, 300], 121 | ), 122 | replay_buffer_size=int(1E6), 123 | ) 124 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False) 125 | setup_logger('rlkit-post-refactor-td3-half-cheetah', variant=variant) 126 | experiment(variant) 127 | -------------------------------------------------------------------------------- /rlkit/torch/vae/vae_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import abc 4 | from torch.distributions import Normal 5 | from torch.nn import functional as F 6 | from rlkit.torch import pytorch_util as ptu 7 | 8 | 9 | class VAEBase(torch.nn.Module, metaclass=abc.ABCMeta): 10 | def __init__( 11 | self, 12 | representation_size, 13 | ): 14 | super().__init__() 15 | self.representation_size = representation_size 16 | 17 | @abc.abstractmethod 18 | def encode(self, input): 19 | """ 20 | :param input: 21 | :return: latent_distribution_params 22 | """ 23 | raise NotImplementedError() 24 | 25 | @abc.abstractmethod 26 | def rsample(self, latent_distribution_params): 27 | """ 28 | 29 | :param latent_distribution_params: 30 | :return: latents 31 | """ 32 | raise NotImplementedError() 33 | 34 | @abc.abstractmethod 35 | def reparameterize(self, latent_distribution_params): 36 | """ 37 | 38 | :param latent_distribution_params: 39 | :return: latents 40 | """ 41 | raise NotImplementedError() 42 | 43 | @abc.abstractmethod 44 | def decode(self, latents): 45 | """ 46 | :param latents: 47 | :return: reconstruction, obs_distribution_params 48 | """ 49 | raise NotImplementedError() 50 | 51 | @abc.abstractmethod 52 | def logprob(self, inputs, obs_distribution_params): 53 | """ 54 | :param inputs: 55 | :param obs_distribution_params: 56 | :return: log probability of input under decoder 57 | """ 58 | raise NotImplementedError() 59 | 60 | @abc.abstractmethod 61 | def kl_divergence(self, latent_distribution_params): 62 | """ 63 | :param latent_distribution_params: 64 | :return: kl div between latent_distribution_params and prior on latent space 65 | """ 66 | raise NotImplementedError() 67 | 68 | @abc.abstractmethod 69 | def get_encoding_from_latent_distribution_params(self, latent_distribution_params): 70 | """ 71 | 72 | :param latent_distribution_params: 73 | :return: get latents from latent distribution params 74 | """ 75 | raise NotImplementedError() 76 | 77 | def forward(self, input): 78 | """ 79 | :param input: 80 | :return: reconstructed input, obs_distribution_params, latent_distribution_params 81 | """ 82 | latent_distribution_params = self.encode(input) 83 | latents = self.reparameterize(latent_distribution_params) 84 | reconstructions, obs_distribution_params = self.decode(latents) 85 | return reconstructions, obs_distribution_params, latent_distribution_params 86 | 87 | 88 | class GaussianLatentVAE(VAEBase): 89 | def __init__( 90 | self, 91 | representation_size, 92 | ): 93 | super().__init__(representation_size) 94 | self.dist_mu = np.zeros(self.representation_size) 95 | self.dist_std = np.ones(self.representation_size) 96 | 97 | def rsample(self, latent_distribution_params): 98 | mu, logvar = 
latent_distribution_params 99 | stds = (0.5 * logvar).exp() 100 | epsilon = ptu.randn(*mu.size()) 101 | latents = epsilon * stds + mu 102 | return latents 103 | 104 | def reparameterize(self, latent_distribution_params): 105 | if self.training: 106 | return self.rsample(latent_distribution_params) 107 | else: 108 | return latent_distribution_params[0] 109 | 110 | def kl_divergence(self, latent_distribution_params): 111 | mu, logvar = latent_distribution_params 112 | return - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean() 113 | 114 | def get_encoding_from_latent_distribution_params(self, latent_distribution_params): 115 | return latent_distribution_params[0].cpu() 116 | 117 | 118 | def compute_bernoulli_log_prob(x, reconstruction_of_x): 119 | return -1 * F.binary_cross_entropy( 120 | reconstruction_of_x, 121 | x, 122 | reduction='elementwise_mean' 123 | ) 124 | 125 | 126 | def compute_gaussian_log_prob(input, dec_mu, dec_var): 127 | decoder_dist = Normal(dec_mu, dec_var.pow(0.5)) 128 | log_probs = decoder_dist.log_prob(input) 129 | vals = log_probs.sum(dim=1, keepdim=True) 130 | return vals.mean() 131 | 132 | -------------------------------------------------------------------------------- /examples/her/her_sac_gym_fetch_reach.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import rlkit.torch.pytorch_util as ptu 4 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 5 | from rlkit.launchers.launcher_util import setup_logger 6 | from rlkit.samplers.data_collector import GoalConditionedPathCollector 7 | from rlkit.torch.her.her import HERTrainer 8 | from rlkit.torch.networks import FlattenMlp 9 | from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy 10 | from rlkit.torch.sac.sac import SACTrainer 11 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 12 | 13 | 14 | def experiment(variant): 15 | eval_env = gym.make('FetchReach-v1') 16 | expl_env = gym.make('FetchReach-v1') 17 | 18 | observation_key = 'observation' 19 | desired_goal_key = 'desired_goal' 20 | 21 | achieved_goal_key = desired_goal_key.replace("desired", "achieved") 22 | replay_buffer = ObsDictRelabelingBuffer( 23 | env=eval_env, 24 | observation_key=observation_key, 25 | desired_goal_key=desired_goal_key, 26 | achieved_goal_key=achieved_goal_key, 27 | **variant['replay_buffer_kwargs'] 28 | ) 29 | obs_dim = eval_env.observation_space.spaces['observation'].low.size 30 | action_dim = eval_env.action_space.low.size 31 | goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size 32 | qf1 = FlattenMlp( 33 | input_size=obs_dim + action_dim + goal_dim, 34 | output_size=1, 35 | **variant['qf_kwargs'] 36 | ) 37 | qf2 = FlattenMlp( 38 | input_size=obs_dim + action_dim + goal_dim, 39 | output_size=1, 40 | **variant['qf_kwargs'] 41 | ) 42 | target_qf1 = FlattenMlp( 43 | input_size=obs_dim + action_dim + goal_dim, 44 | output_size=1, 45 | **variant['qf_kwargs'] 46 | ) 47 | target_qf2 = FlattenMlp( 48 | input_size=obs_dim + action_dim + goal_dim, 49 | output_size=1, 50 | **variant['qf_kwargs'] 51 | ) 52 | policy = TanhGaussianPolicy( 53 | obs_dim=obs_dim + goal_dim, 54 | action_dim=action_dim, 55 | **variant['policy_kwargs'] 56 | ) 57 | eval_policy = MakeDeterministic(policy) 58 | trainer = SACTrainer( 59 | env=eval_env, 60 | policy=policy, 61 | qf1=qf1, 62 | qf2=qf2, 63 | target_qf1=target_qf1, 64 | target_qf2=target_qf2, 65 | **variant['sac_trainer_kwargs'] 66 | ) 67 | trainer = 
HERTrainer(trainer) 68 | eval_path_collector = GoalConditionedPathCollector( 69 | eval_env, 70 | eval_policy, 71 | observation_key=observation_key, 72 | desired_goal_key=desired_goal_key, 73 | ) 74 | expl_path_collector = GoalConditionedPathCollector( 75 | expl_env, 76 | policy, 77 | observation_key=observation_key, 78 | desired_goal_key=desired_goal_key, 79 | ) 80 | algorithm = TorchBatchRLAlgorithm( 81 | trainer=trainer, 82 | exploration_env=expl_env, 83 | evaluation_env=eval_env, 84 | exploration_data_collector=expl_path_collector, 85 | evaluation_data_collector=eval_path_collector, 86 | replay_buffer=replay_buffer, 87 | **variant['algo_kwargs'] 88 | ) 89 | algorithm.to(ptu.device) 90 | algorithm.train() 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | variant = dict( 96 | algorithm='HER-SAC', 97 | version='normal', 98 | algo_kwargs=dict( 99 | batch_size=128, 100 | num_epochs=100, 101 | num_eval_steps_per_epoch=5000, 102 | num_expl_steps_per_train_loop=1000, 103 | num_trains_per_train_loop=1000, 104 | min_num_steps_before_training=1000, 105 | max_path_length=50, 106 | ), 107 | sac_trainer_kwargs=dict( 108 | discount=0.99, 109 | soft_target_tau=5e-3, 110 | target_update_period=1, 111 | policy_lr=3E-4, 112 | qf_lr=3E-4, 113 | reward_scale=1, 114 | use_automatic_entropy_tuning=True, 115 | ), 116 | replay_buffer_kwargs=dict( 117 | max_size=int(1E6), 118 | fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper 119 | fraction_goals_env_goals=0, 120 | ), 121 | qf_kwargs=dict( 122 | hidden_sizes=[400, 300], 123 | ), 124 | policy_kwargs=dict( 125 | hidden_sizes=[400, 300], 126 | ), 127 | ) 128 | setup_logger('her-sac-fetch-experiment', variant=variant) 129 | experiment(variant) 130 | -------------------------------------------------------------------------------- /rlkit/launchers/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy this file to config.py and modify as needed. 3 | """ 4 | import os 5 | from os.path import join 6 | import rlkit 7 | 8 | """ 9 | `doodad.mount.MountLocal` by default ignores directories called "data" 10 | If you're going to rename this directory and use EC2, then change 11 | `doodad.mount.MountLocal.filter_dir` 12 | """ 13 | # The directory of the project, not source 14 | rlkit_project_dir = join(os.path.dirname(rlkit.__file__), os.pardir) 15 | LOCAL_LOG_DIR = join(rlkit_project_dir, 'data') 16 | 17 | """ 18 | ******************************************************************************** 19 | ******************************************************************************** 20 | ******************************************************************************** 21 | 22 | You probably don't need to set all of the configurations below this line, 23 | unless you use AWS, GCP, Slurm, and/or Slurm on a remote server. I recommend 24 | ignoring most of these things and only using them on an as-needed basis. 25 | 26 | ******************************************************************************** 27 | ******************************************************************************** 28 | ******************************************************************************** 29 | """ 30 | 31 | """ 32 | General doodad settings. 
33 | """ 34 | CODE_DIRS_TO_MOUNT = [ 35 | rlkit_project_dir, 36 | # '/home/user/python/module/one', Add more paths as needed 37 | ] 38 | 39 | HOME = os.getenv('HOME') if os.getenv('HOME') is not None else os.getenv("USERPROFILE") 40 | 41 | DIR_AND_MOUNT_POINT_MAPPINGS = [ 42 | dict( 43 | local_dir=join(HOME, '.mujoco/'), 44 | mount_point='/root/.mujoco', 45 | ), 46 | ] 47 | RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 48 | join(rlkit_project_dir, 'scripts', 'run_experiment_from_doodad.py') 49 | # '/home/user/path/to/rlkit/scripts/run_experiment_from_doodad.py' 50 | ) 51 | """ 52 | AWS Settings 53 | """ 54 | # If not set, default will be chosen by doodad 55 | # AWS_S3_PATH = 's3://bucket/directory 56 | 57 | # The docker image is looked up on dockerhub.com. 58 | DOODAD_DOCKER_IMAGE = "TODO" 59 | INSTANCE_TYPE = 'c4.large' 60 | SPOT_PRICE = 0.03 61 | 62 | GPU_DOODAD_DOCKER_IMAGE = 'TODO' 63 | GPU_INSTANCE_TYPE = 'g2.2xlarge' 64 | GPU_SPOT_PRICE = 0.5 65 | 66 | # You can use AMI images with the docker images already installed. 67 | REGION_TO_GPU_AWS_IMAGE_ID = { 68 | 'us-west-1': "TODO", 69 | 'us-east-1': "TODO", 70 | } 71 | 72 | REGION_TO_GPU_AWS_AVAIL_ZONE = { 73 | 'us-east-1': "us-east-1b", 74 | } 75 | 76 | # This really shouldn't matter and in theory could be whatever 77 | OUTPUT_DIR_FOR_DOODAD_TARGET = '/tmp/doodad-output/' 78 | 79 | 80 | """ 81 | Slurm Settings 82 | """ 83 | SINGULARITY_IMAGE = '/home/PATH/TO/IMAGE.img' 84 | # This assumes you saved mujoco to $HOME/.mujoco 85 | SINGULARITY_PRE_CMDS = [ 86 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mjpro150/bin' 87 | ] 88 | SLURM_CPU_CONFIG = dict( 89 | account_name='TODO', 90 | partition='savio', 91 | nodes=1, 92 | n_tasks=1, 93 | n_gpus=1, 94 | ) 95 | SLURM_GPU_CONFIG = dict( 96 | account_name='TODO', 97 | partition='savio2_1080ti', 98 | nodes=1, 99 | n_tasks=1, 100 | n_gpus=1, 101 | ) 102 | 103 | 104 | """ 105 | Slurm Script Settings 106 | 107 | These are basically the same settings as above, but for the remote machine 108 | where you will be running the generated script. 
109 | """ 110 | SSS_CODE_DIRS_TO_MOUNT = [ 111 | ] 112 | SSS_DIR_AND_MOUNT_POINT_MAPPINGS = [ 113 | dict( 114 | local_dir='/global/home/users/USERNAME/.mujoco', 115 | mount_point='/root/.mujoco', 116 | ), 117 | ] 118 | SSS_LOG_DIR = '/global/scratch/USERNAME/doodad-log' 119 | 120 | SSS_IMAGE = '/global/scratch/USERNAME/TODO.img' 121 | SSS_RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 122 | '/global/home/users/USERNAME/path/to/rlkit/scripts' 123 | '/run_experiment_from_doodad.py' 124 | ) 125 | SSS_PRE_CMDS = [ 126 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/global/home/users/USERNAME' 127 | '/.mujoco/mjpro150/bin' 128 | ] 129 | 130 | """ 131 | GCP Settings 132 | """ 133 | GCP_IMAGE_NAME = 'TODO' 134 | GCP_GPU_IMAGE_NAME = 'TODO' 135 | GCP_BUCKET_NAME = 'TODO' 136 | 137 | GCP_DEFAULT_KWARGS = dict( 138 | zone='us-west2-c', 139 | instance_type='n1-standard-4', 140 | image_project='TODO', 141 | terminate=True, 142 | preemptible=True, 143 | gpu_kwargs=dict( 144 | gpu_model='nvidia-tesla-p4', 145 | num_gpu=1, 146 | ) 147 | ) 148 | 149 | try: 150 | from rlkit.launchers.conf_private import * 151 | except ImportError: 152 | print("No personal conf_private.py found.") 153 | -------------------------------------------------------------------------------- /rlkit/samplers/rollout_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def multitask_rollout( 5 | env, 6 | agent, 7 | max_path_length=np.inf, 8 | render=False, 9 | render_kwargs=None, 10 | observation_key=None, 11 | desired_goal_key=None, 12 | get_action_kwargs=None, 13 | return_dict_obs=False, 14 | ): 15 | if render_kwargs is None: 16 | render_kwargs = {} 17 | if get_action_kwargs is None: 18 | get_action_kwargs = {} 19 | dict_obs = [] 20 | dict_next_obs = [] 21 | observations = [] 22 | actions = [] 23 | rewards = [] 24 | terminals = [] 25 | agent_infos = [] 26 | env_infos = [] 27 | next_observations = [] 28 | path_length = 0 29 | agent.reset() 30 | o = env.reset() 31 | if render: 32 | env.render(**render_kwargs) 33 | goal = o[desired_goal_key] 34 | while path_length < max_path_length: 35 | dict_obs.append(o) 36 | if observation_key: 37 | o = o[observation_key] 38 | new_obs = np.hstack((o, goal)) 39 | a, agent_info = agent.get_action(new_obs, **get_action_kwargs) 40 | next_o, r, d, env_info = env.step(a) 41 | if render: 42 | env.render(**render_kwargs) 43 | observations.append(o) 44 | rewards.append(r) 45 | terminals.append(d) 46 | actions.append(a) 47 | next_observations.append(next_o) 48 | dict_next_obs.append(next_o) 49 | agent_infos.append(agent_info) 50 | env_infos.append(env_info) 51 | path_length += 1 52 | if d: 53 | break 54 | o = next_o 55 | actions = np.array(actions) 56 | if len(actions.shape) == 1: 57 | actions = np.expand_dims(actions, 1) 58 | observations = np.array(observations) 59 | next_observations = np.array(next_observations) 60 | if return_dict_obs: 61 | observations = dict_obs 62 | next_observations = dict_next_obs 63 | return dict( 64 | observations=observations, 65 | actions=actions, 66 | rewards=np.array(rewards).reshape(-1, 1), 67 | next_observations=next_observations, 68 | terminals=np.array(terminals).reshape(-1, 1), 69 | agent_infos=agent_infos, 70 | env_infos=env_infos, 71 | goals=np.repeat(goal[None], path_length, 0), 72 | full_observations=dict_obs, 73 | ) 74 | 75 | 76 | def rollout( 77 | env, 78 | agent, 79 | max_path_length=np.inf, 80 | render=False, 81 | render_kwargs=None, 82 | ): 83 | """ 84 | The following value for the following keys 
will be a 2D array, with the 85 | first dimension corresponding to the time dimension. 86 | - observations 87 | - actions 88 | - rewards 89 | - next_observations 90 | - terminals 91 | 92 | The next two elements will be lists of dictionaries, with the index into 93 | the list being the index into the time 94 | - agent_infos 95 | - env_infos 96 | """ 97 | if render_kwargs is None: 98 | render_kwargs = {} 99 | observations = [] 100 | actions = [] 101 | rewards = [] 102 | terminals = [] 103 | agent_infos = [] 104 | env_infos = [] 105 | o = env.reset() 106 | agent.reset() 107 | next_o = None 108 | path_length = 0 109 | if render: 110 | env.render(**render_kwargs) 111 | while path_length < max_path_length: 112 | a, agent_info = agent.get_action(o) 113 | next_o, r, d, env_info = env.step(a) 114 | observations.append(o) 115 | rewards.append(r) 116 | terminals.append(d) 117 | actions.append(a) 118 | agent_infos.append(agent_info) 119 | env_infos.append(env_info) 120 | path_length += 1 121 | if d: 122 | break 123 | o = next_o 124 | if render: 125 | env.render(**render_kwargs) 126 | 127 | actions = np.array(actions) 128 | if len(actions.shape) == 1: 129 | actions = np.expand_dims(actions, 1) 130 | observations = np.array(observations) 131 | if len(observations.shape) == 1: 132 | observations = np.expand_dims(observations, 1) 133 | next_o = np.array([next_o]) 134 | next_observations = np.vstack( 135 | ( 136 | observations[1:, :], 137 | np.expand_dims(next_o, 0) 138 | ) 139 | ) 140 | return dict( 141 | observations=observations, 142 | actions=actions, 143 | rewards=np.array(rewards).reshape(-1, 1), 144 | next_observations=next_observations, 145 | terminals=np.array(terminals).reshape(-1, 1), 146 | agent_infos=agent_infos, 147 | env_infos=env_infos, 148 | ) 149 | -------------------------------------------------------------------------------- /rlkit/torch/sac/policies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | 5 | from rlkit.policies.base import ExplorationPolicy, Policy 6 | from rlkit.torch.core import eval_np 7 | from rlkit.torch.distributions import TanhNormal 8 | from rlkit.torch.networks import Mlp 9 | 10 | 11 | LOG_SIG_MAX = 2 12 | LOG_SIG_MIN = -20 13 | 14 | 15 | class TanhGaussianPolicy(Mlp, ExplorationPolicy): 16 | """ 17 | Usage: 18 | 19 | ``` 20 | policy = TanhGaussianPolicy(...) 21 | action, mean, log_std, _ = policy(obs) 22 | action, mean, log_std, _ = policy(obs, deterministic=True) 23 | action, mean, log_std, log_prob = policy(obs, return_log_prob=True) 24 | ``` 25 | 26 | Here, mean and log_std are the mean and log_std of the Gaussian that is 27 | sampled from. 28 | 29 | If deterministic is True, action = tanh(mean). 30 | If return_log_prob is False (default), log_prob = None 31 | This is done because computing the log_prob can be a bit expensive. 
32 | """ 33 | def __init__( 34 | self, 35 | hidden_sizes, 36 | obs_dim, 37 | action_dim, 38 | std=None, 39 | init_w=1e-3, 40 | **kwargs 41 | ): 42 | super().__init__( 43 | hidden_sizes, 44 | input_size=obs_dim, 45 | output_size=action_dim, 46 | init_w=init_w, 47 | **kwargs 48 | ) 49 | self.log_std = None 50 | self.std = std 51 | if std is None: 52 | last_hidden_size = obs_dim 53 | if len(hidden_sizes) > 0: 54 | last_hidden_size = hidden_sizes[-1] 55 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim) 56 | self.last_fc_log_std.weight.data.uniform_(-init_w, init_w) 57 | self.last_fc_log_std.bias.data.uniform_(-init_w, init_w) 58 | else: 59 | self.log_std = np.log(std) 60 | assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX 61 | 62 | def get_action(self, obs_np, deterministic=False): 63 | actions = self.get_actions(obs_np[None], deterministic=deterministic) 64 | return actions[0, :], {} 65 | 66 | def get_actions(self, obs_np, deterministic=False): 67 | return eval_np(self, obs_np, deterministic=deterministic)[0] 68 | 69 | def forward( 70 | self, 71 | obs, 72 | reparameterize=True, 73 | deterministic=False, 74 | return_log_prob=False, 75 | ): 76 | """ 77 | :param obs: Observation 78 | :param deterministic: If True, do not sample 79 | :param return_log_prob: If True, return a sample and its log probability 80 | """ 81 | h = obs 82 | for i, fc in enumerate(self.fcs): 83 | h = self.hidden_activation(fc(h)) 84 | mean = self.last_fc(h) 85 | if self.std is None: 86 | log_std = self.last_fc_log_std(h) 87 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 88 | std = torch.exp(log_std) 89 | else: 90 | std = self.std 91 | log_std = self.log_std 92 | 93 | log_prob = None 94 | entropy = None 95 | mean_action_log_prob = None 96 | pre_tanh_value = None 97 | if deterministic: 98 | action = torch.tanh(mean) 99 | else: 100 | tanh_normal = TanhNormal(mean, std) 101 | if return_log_prob: 102 | if reparameterize is True: 103 | action, pre_tanh_value = tanh_normal.rsample( 104 | return_pretanh_value=True 105 | ) 106 | else: 107 | action, pre_tanh_value = tanh_normal.sample( 108 | return_pretanh_value=True 109 | ) 110 | log_prob = tanh_normal.log_prob( 111 | action, 112 | pre_tanh_value=pre_tanh_value 113 | ) 114 | log_prob = log_prob.sum(dim=1, keepdim=True) 115 | else: 116 | if reparameterize is True: 117 | action = tanh_normal.rsample() 118 | else: 119 | action = tanh_normal.sample() 120 | 121 | return ( 122 | action, mean, log_std, log_prob, entropy, std, 123 | mean_action_log_prob, pre_tanh_value, 124 | ) 125 | 126 | 127 | class MakeDeterministic(Policy): 128 | def __init__(self, stochastic_policy): 129 | self.stochastic_policy = stochastic_policy 130 | 131 | def get_action(self, observation): 132 | return self.stochastic_policy.get_action(observation, 133 | deterministic=True) 134 | -------------------------------------------------------------------------------- /rlkit/envs/goal_generation/pickup_goal_dataset.py: -------------------------------------------------------------------------------- 1 | from multiworld.envs.mujoco.sawyer_xyz.sawyer_pick_and_place import ( 2 | get_image_presampled_goals 3 | ) 4 | import numpy as np 5 | import cv2 6 | import os.path as osp 7 | import random 8 | 9 | from rlkit.util.io import local_path_from_s3_or_local_path 10 | 11 | 12 | def setup_pickup_image_env(image_env, num_presampled_goals): 13 | """ 14 | Image env and pickup env will have presampled goals. VAE wrapper should 15 | encode whatever presampled goal is sampled. 
16 | """ 17 | presampled_goals = get_image_presampled_goals(image_env, 18 | num_presampled_goals) 19 | image_env._presampled_goals = presampled_goals 20 | image_env.num_goals_presampled = \ 21 | presampled_goals[random.choice(list(presampled_goals))].shape[0] 22 | 23 | 24 | def get_image_presampled_goals_from_vae_env(env, num_presampled_goals, 25 | env_id=None): 26 | image_env = env.wrapped_env 27 | return get_image_presampled_goals(image_env, num_presampled_goals) 28 | 29 | 30 | def get_image_presampled_goals_from_image_env(env, num_presampled_goals, 31 | env_id=None): 32 | return get_image_presampled_goals(env, num_presampled_goals) 33 | 34 | 35 | def generate_vae_dataset(variant): 36 | return generate_vae_dataset_from_params(**variant) 37 | 38 | 39 | def generate_vae_dataset_from_params( 40 | env_class=None, 41 | env_kwargs=None, 42 | env_id=None, 43 | N=10000, 44 | test_p=0.9, 45 | use_cached=True, 46 | imsize=84, 47 | num_channels=1, 48 | show=False, 49 | init_camera=None, 50 | dataset_path=None, 51 | oracle_dataset=False, 52 | n_random_steps=100, 53 | vae_dataset_specific_env_kwargs=None, 54 | save_file_prefix=None, 55 | ): 56 | from multiworld.core.image_env import ImageEnv, unormalize_image 57 | import time 58 | 59 | assert oracle_dataset == True 60 | 61 | if env_kwargs is None: 62 | env_kwargs = {} 63 | if save_file_prefix is None: 64 | save_file_prefix = env_id 65 | if save_file_prefix is None: 66 | save_file_prefix = env_class.__name__ 67 | filename = "/tmp/{}_N{}_{}_imsize{}_oracle{}.npy".format( 68 | save_file_prefix, 69 | str(N), 70 | init_camera.__name__ if init_camera else '', 71 | imsize, 72 | oracle_dataset, 73 | ) 74 | info = {} 75 | if dataset_path is not None: 76 | filename = local_path_from_s3_or_local_path(dataset_path) 77 | dataset = np.load(filename) 78 | np.random.shuffle(dataset) 79 | N = dataset.shape[0] 80 | elif use_cached and osp.isfile(filename): 81 | dataset = np.load(filename) 82 | np.random.shuffle(dataset) 83 | print("loaded data from saved file", filename) 84 | else: 85 | now = time.time() 86 | 87 | if env_id is not None: 88 | import gym 89 | import multiworld 90 | multiworld.register_all_envs() 91 | env = gym.make(env_id) 92 | else: 93 | if vae_dataset_specific_env_kwargs is None: 94 | vae_dataset_specific_env_kwargs = {} 95 | for key, val in env_kwargs.items(): 96 | if key not in vae_dataset_specific_env_kwargs: 97 | vae_dataset_specific_env_kwargs[key] = val 98 | env = env_class(**vae_dataset_specific_env_kwargs) 99 | if not isinstance(env, ImageEnv): 100 | env = ImageEnv( 101 | env, 102 | imsize, 103 | init_camera=init_camera, 104 | transpose=True, 105 | normalize=True, 106 | ) 107 | setup_pickup_image_env(env, num_presampled_goals=N) 108 | env.reset() 109 | info['env'] = env 110 | 111 | dataset = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8) 112 | for i in range(N): 113 | img = env._presampled_goals['image_desired_goal'][i] 114 | dataset[i, :] = unormalize_image(img) 115 | if show: 116 | img = img.reshape(3, imsize, imsize).transpose() 117 | img = img[::-1, :, ::-1] 118 | cv2.imshow('img', img) 119 | cv2.waitKey(1) 120 | time.sleep(.2) 121 | # radius = input('waiting...') 122 | print("done making training data", filename, time.time() - now) 123 | np.random.shuffle(dataset) 124 | np.save(filename, dataset) 125 | 126 | n = int(N * test_p) 127 | train_dataset = dataset[:n, :] 128 | test_dataset = dataset[n:, :] 129 | return train_dataset, test_dataset, info 130 | 
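A minimal usage sketch for the dataset generator above. This is an assumption-laden illustration, not code from the repository: the env_id is a placeholder for whichever goal-conditioned multiworld environment you have registered, and the keyword values simply mirror the parameters visible in the function signature.

```python
# Hypothetical usage sketch; 'SawyerPickupEnv-v0' is assumed to be a registered
# multiworld env id and is not taken from this repository.
train_data, test_data, info = generate_vae_dataset_from_params(
    env_id='SawyerPickupEnv-v0',
    N=1000,                 # number of presampled goal images to collect
    test_p=0.9,             # note: 0.9 here yields a 90/10 train/test split
    use_cached=False,
    imsize=84,
    num_channels=3,
    oracle_dataset=True,    # the function asserts oracle_dataset == True
)
# train_data / test_data are uint8 arrays of flattened images with shape
# (n, imsize * imsize * num_channels).
```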
-------------------------------------------------------------------------------- /rlkit/data_management/shared_obs_dict_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 4 | 5 | import torch.multiprocessing as mp 6 | import ctypes 7 | 8 | 9 | class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer): 10 | """ 11 | Same as an ObsDictRelabelingBuffer but the obs and next_obs are backed 12 | by multiprocessing arrays. The replay buffer size is also shared. The 13 | intended use case is for if one wants obs/next_obs to be shared between 14 | processes. Accesses are synchronized internally by locks (mp takes care 15 | of that). Technically, putting such large arrays in shared memory/requiring 16 | synchronized access can be extremely slow, but it seems ok empirically. 17 | 18 | This code also breaks a lot of functionality for the subprocess. For example, 19 | random_batch is incorrect as actions and _idx_to_future_obs_idx are not 20 | shared. If the subprocess needs all of the functionality, a mp.Array 21 | must be used for all numpy arrays in the replay buffer. 22 | 23 | """ 24 | 25 | def __init__( 26 | self, 27 | *args, 28 | **kwargs 29 | ): 30 | self._shared_size = mp.Value(ctypes.c_long, 0) 31 | ObsDictRelabelingBuffer.__init__(self, *args, **kwargs) 32 | 33 | self._mp_array_info = {} 34 | self._shared_obs_info = {} 35 | self._shared_next_obs_info = {} 36 | 37 | for obs_key, obs_arr in self._obs.items(): 38 | ctype = ctypes.c_double 39 | if obs_arr.dtype == np.uint8: 40 | ctype = ctypes.c_uint8 41 | 42 | self._shared_obs_info[obs_key] = ( 43 | mp.Array(ctype, obs_arr.size), 44 | obs_arr.dtype, 45 | obs_arr.shape, 46 | ) 47 | self._shared_next_obs_info[obs_key] = ( 48 | mp.Array(ctype, obs_arr.size), 49 | obs_arr.dtype, 50 | obs_arr.shape, 51 | ) 52 | 53 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key]) 54 | self._next_obs[obs_key] = to_np( 55 | *self._shared_next_obs_info[obs_key]) 56 | self._register_mp_array("_actions") 57 | self._register_mp_array("_terminals") 58 | 59 | def _register_mp_array(self, arr_instance_var_name): 60 | """ 61 | Use this function to register an array to be shared. This will wipe arr. 62 | """ 63 | assert hasattr(self, arr_instance_var_name), arr_instance_var_name 64 | arr = getattr(self, arr_instance_var_name) 65 | 66 | ctype = ctypes.c_double 67 | if arr.dtype == np.uint8: 68 | ctype = ctypes.c_uint8 69 | 70 | self._mp_array_info[arr_instance_var_name] = ( 71 | mp.Array(ctype, arr.size), arr.dtype, arr.shape, 72 | ) 73 | setattr( 74 | self, 75 | arr_instance_var_name, 76 | to_np(*self._mp_array_info[arr_instance_var_name]) 77 | ) 78 | 79 | def init_from_mp_info( 80 | self, 81 | mp_info, 82 | ): 83 | """ 84 | The intended use is to have a subprocess serialize/copy a 85 | SharedObsDictRelabelingBuffer instance and call init_from on the 86 | instance's shared variables. This can't be done during serialization 87 | since multiprocessing shared objects can't be serialized and must be 88 | passed directly to the subprocess as an argument to the fork call. 
89 | """ 90 | shared_obs_info, shared_next_obs_info, mp_array_info, shared_size = mp_info 91 | 92 | self._shared_obs_info = shared_obs_info 93 | self._shared_next_obs_info = shared_next_obs_info 94 | self._mp_array_info = mp_array_info 95 | for obs_key in self._shared_obs_info.keys(): 96 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key]) 97 | self._next_obs[obs_key] = to_np( 98 | *self._shared_next_obs_info[obs_key]) 99 | 100 | for arr_instance_var_name in self._mp_array_info.keys(): 101 | setattr( 102 | self, 103 | arr_instance_var_name, 104 | to_np(*self._mp_array_info[arr_instance_var_name]) 105 | ) 106 | self._shared_size = shared_size 107 | 108 | def get_mp_info(self): 109 | return ( 110 | self._shared_obs_info, 111 | self._shared_next_obs_info, 112 | self._mp_array_info, 113 | self._shared_size, 114 | ) 115 | 116 | @property 117 | def _size(self): 118 | return self._shared_size.value 119 | 120 | @_size.setter 121 | def _size(self, size): 122 | self._shared_size.value = size 123 | 124 | 125 | def to_np(shared_arr, np_dtype, shape): 126 | return np.frombuffer(shared_arr.get_obj(), dtype=np_dtype).reshape(shape) 127 | -------------------------------------------------------------------------------- /rlkit/core/rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | import gtimer as gt 5 | 6 | from rlkit.core import logger, eval_util 7 | from rlkit.data_management.replay_buffer import ReplayBuffer 8 | from rlkit.samplers.data_collector import DataCollector 9 | 10 | 11 | def _get_epoch_timings(): 12 | times_itrs = gt.get_times().stamps.itrs 13 | times = OrderedDict() 14 | epoch_time = 0 15 | for key in sorted(times_itrs): 16 | time = times_itrs[key][-1] 17 | epoch_time += time 18 | times['time/{} (s)'.format(key)] = time 19 | times['time/epoch (s)'] = epoch_time 20 | times['time/total (s)'] = gt.get_times().total 21 | return times 22 | 23 | 24 | class BaseRLAlgorithm(object, metaclass=abc.ABCMeta): 25 | def __init__( 26 | self, 27 | trainer, 28 | exploration_env, 29 | evaluation_env, 30 | exploration_data_collector: DataCollector, 31 | evaluation_data_collector: DataCollector, 32 | replay_buffer: ReplayBuffer, 33 | ): 34 | self.trainer = trainer 35 | self.expl_env = exploration_env 36 | self.eval_env = evaluation_env 37 | self.expl_data_collector = exploration_data_collector 38 | self.eval_data_collector = evaluation_data_collector 39 | self.replay_buffer = replay_buffer 40 | self._start_epoch = 0 41 | 42 | self.post_epoch_funcs = [] 43 | 44 | def train(self, start_epoch=0): 45 | self._start_epoch = start_epoch 46 | self._train() 47 | 48 | def _train(self): 49 | """ 50 | Train model. 
51 | """ 52 | raise NotImplementedError('_train must be implemented by inherited class') 53 | 54 | def _end_epoch(self, epoch): 55 | snapshot = self._get_snapshot() 56 | logger.save_itr_params(epoch, snapshot) 57 | gt.stamp('saving') 58 | self._log_stats(epoch) 59 | 60 | self.expl_data_collector.end_epoch(epoch) 61 | self.eval_data_collector.end_epoch(epoch) 62 | self.replay_buffer.end_epoch(epoch) 63 | self.trainer.end_epoch(epoch) 64 | 65 | for post_epoch_func in self.post_epoch_funcs: 66 | post_epoch_func(self, epoch) 67 | 68 | def _get_snapshot(self): 69 | snapshot = {} 70 | for k, v in self.trainer.get_snapshot().items(): 71 | snapshot['trainer/' + k] = v 72 | for k, v in self.expl_data_collector.get_snapshot().items(): 73 | snapshot['exploration/' + k] = v 74 | for k, v in self.eval_data_collector.get_snapshot().items(): 75 | snapshot['evaluation/' + k] = v 76 | for k, v in self.replay_buffer.get_snapshot().items(): 77 | snapshot['replay_buffer/' + k] = v 78 | return snapshot 79 | 80 | def _log_stats(self, epoch): 81 | logger.log("Epoch {} finished".format(epoch), with_timestamp=True) 82 | 83 | """ 84 | Replay Buffer 85 | """ 86 | logger.record_dict( 87 | self.replay_buffer.get_diagnostics(), 88 | prefix='replay_buffer/' 89 | ) 90 | 91 | """ 92 | Trainer 93 | """ 94 | logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/') 95 | 96 | """ 97 | Exploration 98 | """ 99 | logger.record_dict( 100 | self.expl_data_collector.get_diagnostics(), 101 | prefix='exploration/' 102 | ) 103 | expl_paths = self.expl_data_collector.get_epoch_paths() 104 | if hasattr(self.expl_env, 'get_diagnostics'): 105 | logger.record_dict( 106 | self.expl_env.get_diagnostics(expl_paths), 107 | prefix='exploration/', 108 | ) 109 | logger.record_dict( 110 | eval_util.get_generic_path_information(expl_paths), 111 | prefix="exploration/", 112 | ) 113 | """ 114 | Evaluation 115 | """ 116 | logger.record_dict( 117 | self.eval_data_collector.get_diagnostics(), 118 | prefix='evaluation/', 119 | ) 120 | eval_paths = self.eval_data_collector.get_epoch_paths() 121 | if hasattr(self.eval_env, 'get_diagnostics'): 122 | logger.record_dict( 123 | self.eval_env.get_diagnostics(eval_paths), 124 | prefix='evaluation/', 125 | ) 126 | logger.record_dict( 127 | eval_util.get_generic_path_information(eval_paths), 128 | prefix="evaluation/", 129 | ) 130 | 131 | """ 132 | Misc 133 | """ 134 | gt.stamp('logging') 135 | logger.record_dict(_get_epoch_timings()) 136 | logger.record_tabular('Epoch', epoch) 137 | logger.dump_tabular(with_prefix=False, with_timestamp=False) 138 | 139 | @abc.abstractmethod 140 | def training_mode(self, mode): 141 | """ 142 | Set training mode to `mode`. 143 | :param mode: If True, training will happen (e.g. set the dropout 144 | probabilities to not all ones).
145 | """ 146 | pass 147 | -------------------------------------------------------------------------------- /examples/skewfit/sawyer_door.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import multiworld.envs.mujoco as mwmj 3 | import rlkit.util.hyperparameter as hyp 4 | from multiworld.envs.mujoco.cameras import sawyer_door_env_camera_v0 5 | from rlkit.launchers.launcher_util import run_experiment 6 | import rlkit.torch.vae.vae_schedules as vae_schedules 7 | from rlkit.launchers.skewfit_experiments import \ 8 | skewfit_full_experiment 9 | from rlkit.torch.vae.conv_vae import imsize48_default_architecture 10 | 11 | 12 | if __name__ == "__main__": 13 | variant = dict( 14 | algorithm='Skew-Fit-SAC', 15 | double_algo=False, 16 | online_vae_exploration=False, 17 | imsize=48, 18 | env_id='SawyerDoorHookResetFreeEnv-v1', 19 | init_camera=sawyer_door_env_camera_v0, 20 | skewfit_variant=dict( 21 | save_video=True, 22 | custom_goal_sampler='replay_buffer', 23 | online_vae_trainer_kwargs=dict( 24 | beta=20, 25 | lr=1e-3, 26 | ), 27 | save_video_period=50, 28 | qf_kwargs=dict( 29 | hidden_sizes=[400, 300], 30 | ), 31 | policy_kwargs=dict( 32 | hidden_sizes=[400, 300], 33 | ), 34 | twin_sac_trainer_kwargs=dict( 35 | reward_scale=1, 36 | discount=0.99, 37 | soft_target_tau=1e-3, 38 | target_update_period=1, 39 | use_automatic_entropy_tuning=True, 40 | ), 41 | max_path_length=100, 42 | algo_kwargs=dict( 43 | batch_size=1024, 44 | num_epochs=170, 45 | num_eval_steps_per_epoch=500, 46 | num_expl_steps_per_train_loop=500, 47 | num_trains_per_train_loop=1000, 48 | min_num_steps_before_training=10000, 49 | vae_training_schedule=vae_schedules.custom_schedule, 50 | oracle_data=False, 51 | vae_save_period=50, 52 | parallel_vae_train=False, 53 | ), 54 | replay_buffer_kwargs=dict( 55 | start_skew_epoch=10, 56 | max_size=int(100000), 57 | fraction_goals_rollout_goals=0.2, 58 | fraction_goals_env_goals=0.5, 59 | exploration_rewards_type='None', 60 | vae_priority_type='vae_prob', 61 | priority_function_kwargs=dict( 62 | sampling_method='importance_sampling', 63 | decoder_distribution='gaussian_identity_variance', 64 | num_latents_to_sample=10, 65 | ), 66 | power=-0.5, 67 | relabeling_goal_sampling_mode='custom_goal_sampler', 68 | ), 69 | exploration_goal_sampling_mode='custom_goal_sampler', 70 | evaluation_goal_sampling_mode='presampled', 71 | training_mode='train', 72 | testing_mode='test', 73 | reward_params=dict( 74 | type='latent_distance', 75 | ), 76 | observation_key='latent_observation', 77 | desired_goal_key='latent_desired_goal', 78 | presampled_goals_path=osp.join( 79 | osp.dirname(mwmj.__file__), 80 | "goals", 81 | "door_goals.npy", 82 | ), 83 | presample_goals=True, 84 | vae_wrapped_env_kwargs=dict( 85 | sample_from_true_prior=True, 86 | ), 87 | ), 88 | train_vae_variant=dict( 89 | representation_size=16, 90 | beta=20, 91 | num_epochs=0, 92 | dump_skew_debug_plots=False, 93 | decoder_activation='gaussian', 94 | generate_vae_dataset_kwargs=dict( 95 | N=2, 96 | test_p=.9, 97 | use_cached=True, 98 | show=False, 99 | oracle_dataset=False, 100 | n_random_steps=1, 101 | non_presampled_goal_img_is_garbage=True, 102 | ), 103 | vae_kwargs=dict( 104 | decoder_distribution='gaussian_identity_variance', 105 | input_channels=3, 106 | architecture=imsize48_default_architecture, 107 | ), 108 | algo_kwargs=dict( 109 | lr=1e-3, 110 | ), 111 | save_period=1, 112 | ), 113 | ) 114 | 115 | search_space = { 116 | } 117 | sweeper = 
hyp.DeterministicHyperparameterSweeper( 118 | search_space, default_parameters=variant, 119 | ) 120 | 121 | n_seeds = 1 122 | mode = 'local' 123 | exp_prefix = 'dev-{}'.format( 124 | __file__.replace('/', '-').replace('_', '-').split('.')[0] 125 | ) 126 | 127 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): 128 | for _ in range(n_seeds): 129 | run_experiment( 130 | skewfit_full_experiment, 131 | exp_prefix=exp_prefix, 132 | mode=mode, 133 | variant=variant, 134 | use_gpu=True, 135 | ) 136 | 137 | -------------------------------------------------------------------------------- /examples/her/her_td3_multiworld_sawyer_reach.py: -------------------------------------------------------------------------------- 1 | """ 2 | This should result in an average return of ~3000 by the end of training. 3 | 4 | Usually hits 3000 around epoch 80-100. Within a seed, the performance will be 5 | a bit noisy from one epoch to the next (occasionally dips down to ~2000). 6 | 7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps. 8 | """ 9 | import gym 10 | 11 | import rlkit.torch.pytorch_util as ptu 12 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer 13 | from rlkit.exploration_strategies.base import \ 14 | PolicyWrappedWithExplorationStrategy 15 | from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import ( 16 | GaussianAndEpislonStrategy 17 | ) 18 | from rlkit.launchers.launcher_util import setup_logger 19 | from rlkit.samplers.data_collector import GoalConditionedPathCollector 20 | from rlkit.torch.her.her import HERTrainer 21 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 22 | from rlkit.torch.td3.td3 import TD3Trainer 23 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm 24 | 25 | 26 | def experiment(variant): 27 | import multiworld 28 | multiworld.register_all_envs() 29 | eval_env = gym.make('SawyerReachXYZEnv-v0') 30 | expl_env = gym.make('SawyerReachXYZEnv-v0') 31 | observation_key = 'state_observation' 32 | desired_goal_key = 'state_desired_goal' 33 | achieved_goal_key = desired_goal_key.replace("desired", "achieved") 34 | es = GaussianAndEpislonStrategy( 35 | action_space=expl_env.action_space, 36 | max_sigma=.2, 37 | min_sigma=.2, # constant sigma 38 | epsilon=.3, 39 | ) 40 | obs_dim = expl_env.observation_space.spaces['observation'].low.size 41 | goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size 42 | action_dim = expl_env.action_space.low.size 43 | qf1 = FlattenMlp( 44 | input_size=obs_dim + goal_dim + action_dim, 45 | output_size=1, 46 | **variant['qf_kwargs'] 47 | ) 48 | qf2 = FlattenMlp( 49 | input_size=obs_dim + goal_dim + action_dim, 50 | output_size=1, 51 | **variant['qf_kwargs'] 52 | ) 53 | target_qf1 = FlattenMlp( 54 | input_size=obs_dim + goal_dim + action_dim, 55 | output_size=1, 56 | **variant['qf_kwargs'] 57 | ) 58 | target_qf2 = FlattenMlp( 59 | input_size=obs_dim + goal_dim + action_dim, 60 | output_size=1, 61 | **variant['qf_kwargs'] 62 | ) 63 | policy = TanhMlpPolicy( 64 | input_size=obs_dim + goal_dim, 65 | output_size=action_dim, 66 | **variant['policy_kwargs'] 67 | ) 68 | target_policy = TanhMlpPolicy( 69 | input_size=obs_dim + goal_dim, 70 | output_size=action_dim, 71 | **variant['policy_kwargs'] 72 | ) 73 | expl_policy = PolicyWrappedWithExplorationStrategy( 74 | exploration_strategy=es, 75 | policy=policy, 76 | ) 77 | replay_buffer = ObsDictRelabelingBuffer( 78 | env=eval_env, 79 | observation_key=observation_key, 80 |
desired_goal_key=desired_goal_key, 81 | achieved_goal_key=achieved_goal_key, 82 | **variant['replay_buffer_kwargs'] 83 | ) 84 | trainer = TD3Trainer( 85 | policy=policy, 86 | qf1=qf1, 87 | qf2=qf2, 88 | target_qf1=target_qf1, 89 | target_qf2=target_qf2, 90 | target_policy=target_policy, 91 | **variant['trainer_kwargs'] 92 | ) 93 | trainer = HERTrainer(trainer) 94 | eval_path_collector = GoalConditionedPathCollector( 95 | eval_env, 96 | policy, 97 | observation_key=observation_key, 98 | desired_goal_key=desired_goal_key, 99 | ) 100 | expl_path_collector = GoalConditionedPathCollector( 101 | expl_env, 102 | expl_policy, 103 | observation_key=observation_key, 104 | desired_goal_key=desired_goal_key, 105 | ) 106 | algorithm = TorchBatchRLAlgorithm( 107 | trainer=trainer, 108 | exploration_env=expl_env, 109 | evaluation_env=eval_env, 110 | exploration_data_collector=expl_path_collector, 111 | evaluation_data_collector=eval_path_collector, 112 | replay_buffer=replay_buffer, 113 | **variant['algo_kwargs'] 114 | ) 115 | algorithm.to(ptu.device) 116 | algorithm.train() 117 | 118 | 119 | if __name__ == "__main__": 120 | variant = dict( 121 | algo_kwargs=dict( 122 | num_epochs=100, 123 | max_path_length=50, 124 | batch_size=128, 125 | num_eval_steps_per_epoch=1000, 126 | num_expl_steps_per_train_loop=1000, 127 | num_trains_per_train_loop=1000, 128 | ), 129 | trainer_kwargs=dict( 130 | discount=0.99, 131 | ), 132 | replay_buffer_kwargs=dict( 133 | max_size=100000, 134 | fraction_goals_rollout_goals=0.2, 135 | fraction_goals_env_goals=0.0, 136 | ), 137 | qf_kwargs=dict( 138 | hidden_sizes=[400, 300], 139 | ), 140 | policy_kwargs=dict( 141 | hidden_sizes=[400, 300], 142 | ), 143 | ) 144 | setup_logger('her-td3-sawyer-experiment', variant=variant) 145 | experiment(variant) 146 | -------------------------------------------------------------------------------- /rlkit/envs/assets/reacher_7dof.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 84 | -------------------------------------------------------------------------------- /rlkit/envs/assets/low_gear_ratio_ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 85 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco_image_env.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from collections.__init__ import deque 6 | 7 | from gym import Env 8 | from gym.spaces import Box 9 | 10 | from rlkit.envs.wrappers import ProxyEnv 11 | 12 | 13 | class ImageMujocoEnv(ProxyEnv, Env): 14 | def __init__(self, 15 | wrapped_env, 16 | imsize=32, 17 | keep_prev=0, 18 | init_camera=None, 19 | camera_name=None, 20 | transpose=False, 21 | grayscale=False, 22 | normalize=False, 23 | ): 24 | import mujoco_py 25 | super().__init__(wrapped_env) 26 | 27 | self.imsize = imsize 28 | if grayscale: 29 | self.image_length = self.imsize * self.imsize 30 | else: 31 | self.image_length = 3 * self.imsize * self.imsize 32 | # This is torch format rather than PIL image 33 | self.image_shape = (self.imsize, self.imsize) 34 | # Flattened past image queue 35 | self.history_length = keep_prev + 1 36 | self.history = deque(maxlen=self.history_length) 37 | # init camera 38 | if init_camera is not None: 39 | sim = self._wrapped_env.sim 40 | viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1) 41 | init_camera(viewer.cam) 42 | 
sim.add_render_context(viewer) 43 | self.camera_name = camera_name # None means default camera 44 | self.transpose = transpose 45 | self.grayscale = grayscale 46 | self.normalize = normalize 47 | self._render_local = False 48 | 49 | self.observation_space = Box(low=0.0, 50 | high=1.0, 51 | shape=( 52 | self.image_length * self.history_length,)) 53 | 54 | def step(self, action): 55 | # image observations are returned as a flattened 1D array 56 | true_state, reward, done, info = super().step(action) 57 | 58 | observation = self._image_observation() 59 | self.history.append(observation) 60 | history = self._get_history().flatten() 61 | full_obs = self._get_obs(history, true_state) 62 | return full_obs, reward, done, info 63 | 64 | def reset(self, **kwargs): 65 | true_state = super().reset(**kwargs) 66 | self.history = deque(maxlen=self.history_length) 67 | 68 | observation = self._image_observation() 69 | self.history.append(observation) 70 | history = self._get_history().flatten() 71 | full_obs = self._get_obs(history, true_state) 72 | return full_obs 73 | 74 | def get_image(self): 75 | """TODO: this should probably consider history""" 76 | return self._image_observation() 77 | 78 | def _get_obs(self, history_flat, true_state): 79 | # adds extra information from true_state into the image observation. 80 | # Used in ImageMujocoWithObsEnv. 81 | return history_flat 82 | 83 | def _image_observation(self): 84 | # returns the image as a torch format np array 85 | image_obs = self._wrapped_env.sim.render(width=self.imsize, 86 | height=self.imsize, 87 | camera_name=self.camera_name) 88 | if self._render_local: 89 | cv2.imshow('env', image_obs) 90 | cv2.waitKey(1) 91 | if self.grayscale: 92 | image_obs = Image.fromarray(image_obs).convert('L') 93 | image_obs = np.array(image_obs) 94 | if self.normalize: 95 | image_obs = image_obs / 255.0 96 | if self.transpose: 97 | image_obs = image_obs.transpose() 98 | return image_obs 99 | 100 | def _get_history(self): 101 | observations = list(self.history) 102 | 103 | obs_count = len(observations) 104 | for _ in range(self.history_length - obs_count): 105 | dummy = np.zeros(self.image_shape) 106 | observations.append(dummy) 107 | return np.c_[observations] 108 | 109 | def retrieve_images(self): 110 | # returns images in unflattened PIL format 111 | images = [] 112 | for image_obs in self.history: 113 | pil_image = self.torch_to_pil(torch.Tensor(image_obs)) 114 | images.append(pil_image) 115 | return images 116 | 117 | def split_obs(self, obs): 118 | # splits observation into image input and true observation input 119 | imlength = self.image_length * self.history_length 120 | obs_length = self.observation_space.low.size 121 | obs = obs.view(-1, obs_length) 122 | image_obs = obs.narrow(start=0, 123 | length=imlength, 124 | dimension=1) 125 | if obs_length == imlength: 126 | return image_obs, None 127 | 128 | fc_obs = obs.narrow(start=imlength, 129 | length=obs.shape[1] - imlength, 130 | dimension=1) 131 | return image_obs, fc_obs 132 | 133 | def enable_render(self): 134 | self._render_local = True 135 | 136 | 137 | class ImageMujocoWithObsEnv(ImageMujocoEnv): 138 | def __init__(self, env, **kwargs): 139 | super().__init__(env, **kwargs) 140 | self.observation_space = Box(low=0.0, 141 | high=1.0, 142 | shape=( 143 | self.image_length * self.history_length + 144 | self.wrapped_env.obs_dim,)) 145 | 146 | def _get_obs(self, history_flat, true_state): 147 | return np.concatenate([history_flat, 148 | true_state])
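# A minimal usage sketch for ImageMujocoEnv, assuming a MuJoCo-backed gym env
# that exposes a `.sim` handle; the env id 'Reacher-v2' and the wrapper settings
# below are illustrative, not taken from this repository's examples.
#
#     import gym
#     from rlkit.envs.mujoco_image_env import ImageMujocoEnv
#
#     image_env = ImageMujocoEnv(
#         gym.make('Reacher-v2'),
#         imsize=48,
#         keep_prev=1,      # keep one previous frame, so history_length == 2
#         normalize=True,   # scale pixels to [0, 1]
#         transpose=True,   # reverse axes into channel-first (torch) order
#     )
#     obs = image_env.reset()  # flat array of length 2 * 3 * 48 * 48
#     obs, reward, done, info = image_env.step(image_env.action_space.sample())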
-------------------------------------------------------------------------------- /rlkit/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | from gym import Env 4 | from gym.spaces import Box 5 | from gym.spaces import Discrete 6 | 7 | from collections import deque 8 | 9 | 10 | class ProxyEnv(Env): 11 | def __init__(self, wrapped_env): 12 | self._wrapped_env = wrapped_env 13 | self.action_space = self._wrapped_env.action_space 14 | self.observation_space = self._wrapped_env.observation_space 15 | 16 | @property 17 | def wrapped_env(self): 18 | return self._wrapped_env 19 | 20 | def reset(self, **kwargs): 21 | return self._wrapped_env.reset(**kwargs) 22 | 23 | def step(self, action): 24 | return self._wrapped_env.step(action) 25 | 26 | def render(self, *args, **kwargs): 27 | return self._wrapped_env.render(*args, **kwargs) 28 | 29 | @property 30 | def horizon(self): 31 | return self._wrapped_env.horizon 32 | 33 | def terminate(self): 34 | if hasattr(self.wrapped_env, "terminate"): 35 | self.wrapped_env.terminate() 36 | 37 | def __getattr__(self, attr): 38 | if attr == '_wrapped_env': 39 | raise AttributeError() 40 | return getattr(self._wrapped_env, attr) 41 | 42 | def __getstate__(self): 43 | """ 44 | This is useful to override in case the wrapped env has some funky 45 | __getstate__ that doesn't play well with overriding __getattr__. 46 | 47 | The main problematic case is/was gym's EzPickle serialization scheme. 48 | :return: 49 | """ 50 | return self.__dict__ 51 | 52 | def __setstate__(self, state): 53 | self.__dict__.update(state) 54 | 55 | def __str__(self): 56 | return '{}({})'.format(type(self).__name__, self.wrapped_env) 57 | 58 | 59 | class HistoryEnv(ProxyEnv, Env): 60 | def __init__(self, wrapped_env, history_len): 61 | super().__init__(wrapped_env) 62 | self.history_len = history_len 63 | 64 | high = np.inf * np.ones( 65 | self.history_len * self.observation_space.low.size) 66 | low = -high 67 | self.observation_space = Box(low=low, 68 | high=high, 69 | ) 70 | self.history = deque(maxlen=self.history_len) 71 | 72 | def step(self, action): 73 | state, reward, done, info = super().step(action) 74 | self.history.append(state) 75 | flattened_history = self._get_history().flatten() 76 | return flattened_history, reward, done, info 77 | 78 | def reset(self, **kwargs): 79 | state = super().reset() 80 | self.history = deque(maxlen=self.history_len) 81 | self.history.append(state) 82 | flattened_history = self._get_history().flatten() 83 | return flattened_history 84 | 85 | def _get_history(self): 86 | observations = list(self.history) 87 | 88 | obs_count = len(observations) 89 | for _ in range(self.history_len - obs_count): 90 | dummy = np.zeros(self._wrapped_env.observation_space.low.size) 91 | observations.append(dummy) 92 | return np.c_[observations] 93 | 94 | 95 | class DiscretizeEnv(ProxyEnv, Env): 96 | def __init__(self, wrapped_env, num_bins): 97 | super().__init__(wrapped_env) 98 | low = self.wrapped_env.action_space.low 99 | high = self.wrapped_env.action_space.high 100 | action_ranges = [ 101 | np.linspace(low[i], high[i], num_bins) 102 | for i in range(len(low)) 103 | ] 104 | self.idx_to_continuous_action = [ 105 | np.array(x) for x in itertools.product(*action_ranges) 106 | ] 107 | self.action_space = Discrete(len(self.idx_to_continuous_action)) 108 | 109 | def step(self, action): 110 | continuous_action = self.idx_to_continuous_action[action] 111 | return 
super().step(continuous_action) 112 | 113 | 114 | class NormalizedBoxEnv(ProxyEnv): 115 | """ 116 | Normalize actions to [-1, 1]. 117 | 118 | Optionally normalize observations and scale reward. 119 | """ 120 | 121 | def __init__( 122 | self, 123 | env, 124 | reward_scale=1., 125 | obs_mean=None, 126 | obs_std=None, 127 | ): 128 | ProxyEnv.__init__(self, env) 129 | self._should_normalize = not (obs_mean is None and obs_std is None) 130 | if self._should_normalize: 131 | if obs_mean is None: 132 | obs_mean = np.zeros_like(env.observation_space.low) 133 | else: 134 | obs_mean = np.array(obs_mean) 135 | if obs_std is None: 136 | obs_std = np.ones_like(env.observation_space.low) 137 | else: 138 | obs_std = np.array(obs_std) 139 | self._reward_scale = reward_scale 140 | self._obs_mean = obs_mean 141 | self._obs_std = obs_std 142 | ub = np.ones(self._wrapped_env.action_space.shape) 143 | self.action_space = Box(-1 * ub, ub) 144 | 145 | def estimate_obs_stats(self, obs_batch, override_values=False): 146 | if self._obs_mean is not None and not override_values: 147 | raise Exception("Observation mean and std already set. To " 148 | "override, set override_values to True.") 149 | self._obs_mean = np.mean(obs_batch, axis=0) 150 | self._obs_std = np.std(obs_batch, axis=0) 151 | 152 | def _apply_normalize_obs(self, obs): 153 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 154 | 155 | def step(self, action): 156 | lb = self._wrapped_env.action_space.low 157 | ub = self._wrapped_env.action_space.high 158 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb) 159 | scaled_action = np.clip(scaled_action, lb, ub) 160 | 161 | wrapped_step = self._wrapped_env.step(scaled_action) 162 | next_obs, reward, done, info = wrapped_step 163 | if self._should_normalize: 164 | next_obs = self._apply_normalize_obs(next_obs) 165 | return next_obs, reward * self._reward_scale, done, info 166 | 167 | def __str__(self): 168 | return "Normalized: %s" % self._wrapped_env 169 | 170 | -------------------------------------------------------------------------------- /examples/skewfit/sawyer_push.py: -------------------------------------------------------------------------------- 1 | import rlkit.util.hyperparameter as hyp 2 | from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in 3 | from rlkit.launchers.launcher_util import run_experiment 4 | import rlkit.torch.vae.vae_schedules as vae_schedules 5 | from rlkit.launchers.skewfit_experiments import skewfit_full_experiment 6 | from rlkit.torch.vae.conv_vae import imsize48_default_architecture 7 | 8 | 9 | if __name__ == "__main__": 10 | variant = dict( 11 | algorithm='Skew-Fit', 12 | double_algo=False, 13 | online_vae_exploration=False, 14 | imsize=48, 15 | init_camera=sawyer_init_camera_zoomed_in, 16 | env_id='SawyerPushNIPSEasy-v0', 17 | skewfit_variant=dict( 18 | save_video=True, 19 | custom_goal_sampler='replay_buffer', 20 | online_vae_trainer_kwargs=dict( 21 | beta=20, 22 | lr=1e-3, 23 | ), 24 | save_video_period=100, 25 | qf_kwargs=dict( 26 | hidden_sizes=[400, 300], 27 | ), 28 | policy_kwargs=dict( 29 | hidden_sizes=[400, 300], 30 | ), 31 | vf_kwargs=dict( 32 | hidden_sizes=[400, 300], 33 | ), 34 | max_path_length=50, 35 | algo_kwargs=dict( 36 | batch_size=1024, 37 | num_epochs=1000, 38 | num_eval_steps_per_epoch=500, 39 | num_expl_steps_per_train_loop=500, 40 | num_trains_per_train_loop=1000, 41 | min_num_steps_before_training=10000, 42 | vae_training_schedule=vae_schedules.custom_schedule_2, 43 | oracle_data=False, 44 | vae_save_period=50, 45 |
parallel_vae_train=False, 46 | ), 47 | twin_sac_trainer_kwargs=dict( 48 | discount=0.99, 49 | reward_scale=1, 50 | soft_target_tau=1e-3, 51 | target_update_period=1, # 1 52 | use_automatic_entropy_tuning=True, 53 | ), 54 | replay_buffer_kwargs=dict( 55 | start_skew_epoch=10, 56 | max_size=int(100000), 57 | fraction_goals_rollout_goals=0.2, 58 | fraction_goals_env_goals=0.5, 59 | exploration_rewards_type='None', 60 | vae_priority_type='vae_prob', 61 | priority_function_kwargs=dict( 62 | sampling_method='importance_sampling', 63 | decoder_distribution='gaussian_identity_variance', 64 | num_latents_to_sample=10, 65 | ), 66 | power=-1, 67 | relabeling_goal_sampling_mode='vae_prior', 68 | ), 69 | exploration_goal_sampling_mode='vae_prior', 70 | evaluation_goal_sampling_mode='reset_of_env', 71 | normalize=False, 72 | render=False, 73 | exploration_noise=0.0, 74 | exploration_type='ou', 75 | training_mode='train', 76 | testing_mode='test', 77 | reward_params=dict( 78 | type='latent_distance', 79 | ), 80 | observation_key='latent_observation', 81 | desired_goal_key='latent_desired_goal', 82 | vae_wrapped_env_kwargs=dict( 83 | sample_from_true_prior=True, 84 | ), 85 | ), 86 | train_vae_variant=dict( 87 | representation_size=4, 88 | beta=20, 89 | num_epochs=0, 90 | dump_skew_debug_plots=False, 91 | decoder_activation='gaussian', 92 | generate_vae_dataset_kwargs=dict( 93 | N=40, 94 | test_p=.9, 95 | use_cached=False, 96 | show=False, 97 | oracle_dataset=True, 98 | oracle_dataset_using_set_to_goal=True, 99 | n_random_steps=100, 100 | non_presampled_goal_img_is_garbage=True, 101 | ), 102 | vae_kwargs=dict( 103 | input_channels=3, 104 | architecture=imsize48_default_architecture, 105 | decoder_distribution='gaussian_identity_variance', 106 | ), 107 | # TODO: why the redundancy? 
108 | algo_kwargs=dict( 109 | start_skew_epoch=5000, 110 | is_auto_encoder=False, 111 | batch_size=64, 112 | lr=1e-3, 113 | skew_config=dict( 114 | method='vae_prob', 115 | power=-1, 116 | ), 117 | skew_dataset=True, 118 | priority_function_kwargs=dict( 119 | decoder_distribution='gaussian_identity_variance', 120 | sampling_method='importance_sampling', 121 | num_latents_to_sample=10, 122 | ), 123 | use_parallel_dataloading=False, 124 | ), 125 | 126 | save_period=25, 127 | ), 128 | ) 129 | search_space = {} 130 | sweeper = hyp.DeterministicHyperparameterSweeper( 131 | search_space, default_parameters=variant, 132 | ) 133 | 134 | n_seeds = 1 135 | mode = 'local' 136 | exp_prefix = 'dev-{}'.format( 137 | __file__.replace('/', '-').replace('_', '-').split('.')[0] 138 | ) 139 | 140 | n_seeds = 3 141 | mode = 'ec2' 142 | exp_prefix = 'rlkit-skew-fit-pusher-reference-sample-from-true-prior-take2' 143 | 144 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): 145 | for _ in range(n_seeds): 146 | run_experiment( 147 | skewfit_full_experiment, 148 | exp_prefix=exp_prefix, 149 | mode=mode, 150 | variant=variant, 151 | use_gpu=True, 152 | num_exps_per_instance=3, 153 | gcp_kwargs=dict( 154 | terminate=True, 155 | zone='us-east1-c', 156 | gpu_kwargs=dict( 157 | gpu_model='nvidia-tesla-k80', 158 | num_gpu=1, 159 | ) 160 | ) 161 | ) 162 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import deque, OrderedDict 2 | 3 | from rlkit.core.eval_util import create_stats_ordered_dict 4 | from rlkit.samplers.rollout_functions import rollout, multitask_rollout 5 | from rlkit.samplers.data_collector.base import PathCollector 6 | 7 | 8 | class MdpPathCollector(PathCollector): 9 | def __init__( 10 | self, 11 | env, 12 | policy, 13 | max_num_epoch_paths_saved=None, 14 | render=False, 15 | render_kwargs=None, 16 | ): 17 | if render_kwargs is None: 18 | render_kwargs = {} 19 | self._env = env 20 | self._policy = policy 21 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 22 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 23 | self._render = render 24 | self._render_kwargs = render_kwargs 25 | 26 | self._num_steps_total = 0 27 | self._num_paths_total = 0 28 | 29 | def collect_new_paths( 30 | self, 31 | max_path_length, 32 | num_steps, 33 | discard_incomplete_paths, 34 | ): 35 | paths = [] 36 | num_steps_collected = 0 37 | while num_steps_collected < num_steps: 38 | max_path_length_this_loop = min( # Do not go over num_steps 39 | max_path_length, 40 | num_steps - num_steps_collected, 41 | ) 42 | path = rollout( 43 | self._env, 44 | self._policy, 45 | max_path_length=max_path_length_this_loop, 46 | ) 47 | path_len = len(path['actions']) 48 | if ( 49 | path_len != max_path_length 50 | and not path['terminals'][-1] 51 | and discard_incomplete_paths 52 | ): 53 | break 54 | num_steps_collected += path_len 55 | paths.append(path) 56 | self._num_paths_total += len(paths) 57 | self._num_steps_total += num_steps_collected 58 | self._epoch_paths.extend(paths) 59 | return paths 60 | 61 | def get_epoch_paths(self): 62 | return self._epoch_paths 63 | 64 | def end_epoch(self, epoch): 65 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 66 | 67 | def get_diagnostics(self): 68 | path_lens = [len(path['actions']) for path in self._epoch_paths] 69 | stats = OrderedDict([ 70 | ('num steps 
total', self._num_steps_total), 71 | ('num paths total', self._num_paths_total), 72 | ]) 73 | stats.update(create_stats_ordered_dict( 74 | "path length", 75 | path_lens, 76 | always_show_all_stats=True, 77 | )) 78 | return stats 79 | 80 | def get_snapshot(self): 81 | return dict( 82 | env=self._env, 83 | policy=self._policy, 84 | ) 85 | 86 | 87 | class GoalConditionedPathCollector(PathCollector): 88 | def __init__( 89 | self, 90 | env, 91 | policy, 92 | max_num_epoch_paths_saved=None, 93 | render=False, 94 | render_kwargs=None, 95 | observation_key='observation', 96 | desired_goal_key='desired_goal', 97 | ): 98 | if render_kwargs is None: 99 | render_kwargs = {} 100 | self._env = env 101 | self._policy = policy 102 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 103 | self._render = render 104 | self._render_kwargs = render_kwargs 105 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 106 | self._observation_key = observation_key 107 | self._desired_goal_key = desired_goal_key 108 | 109 | self._num_steps_total = 0 110 | self._num_paths_total = 0 111 | 112 | def collect_new_paths( 113 | self, 114 | max_path_length, 115 | num_steps, 116 | discard_incomplete_paths, 117 | ): 118 | paths = [] 119 | num_steps_collected = 0 120 | while num_steps_collected < num_steps: 121 | max_path_length_this_loop = min( # Do not go over num_steps 122 | max_path_length, 123 | num_steps - num_steps_collected, 124 | ) 125 | path = multitask_rollout( 126 | self._env, 127 | self._policy, 128 | max_path_length=max_path_length_this_loop, 129 | render=self._render, 130 | render_kwargs=self._render_kwargs, 131 | observation_key=self._observation_key, 132 | desired_goal_key=self._desired_goal_key, 133 | return_dict_obs=True, 134 | ) 135 | path_len = len(path['actions']) 136 | if ( 137 | path_len != max_path_length 138 | and not path['terminals'][-1] 139 | and discard_incomplete_paths 140 | ): 141 | break 142 | num_steps_collected += path_len 143 | paths.append(path) 144 | self._num_paths_total += len(paths) 145 | self._num_steps_total += num_steps_collected 146 | self._epoch_paths.extend(paths) 147 | return paths 148 | 149 | def get_epoch_paths(self): 150 | return self._epoch_paths 151 | 152 | def end_epoch(self, epoch): 153 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 154 | 155 | def get_diagnostics(self): 156 | path_lens = [len(path['actions']) for path in self._epoch_paths] 157 | stats = OrderedDict([ 158 | ('num steps total', self._num_steps_total), 159 | ('num paths total', self._num_paths_total), 160 | ]) 161 | stats.update(create_stats_ordered_dict( 162 | "path length", 163 | path_lens, 164 | always_show_all_stats=True, 165 | )) 166 | return stats 167 | 168 | def get_snapshot(self): 169 | return dict( 170 | env=self._env, 171 | policy=self._policy, 172 | observation_key=self._observation_key, 173 | desired_goal_key=self._desired_goal_key, 174 | ) 175 | --------------------------------------------------------------------------------
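For reference, a minimal sketch of driving MdpPathCollector outside the full TorchBatchRLAlgorithm loop; the env id 'Pendulum-v0', the NormalizedBoxEnv wrapper, and the small randomly initialized TanhMlpPolicy are illustrative assumptions rather than settings used by the examples above.

    import gym

    from rlkit.envs.wrappers import NormalizedBoxEnv
    from rlkit.samplers.data_collector.path_collector import MdpPathCollector
    from rlkit.torch.networks import TanhMlpPolicy

    # Wrap the env so actions live in [-1, 1], matching the tanh policy output.
    env = NormalizedBoxEnv(gym.make('Pendulum-v0'))
    policy = TanhMlpPolicy(
        input_size=env.observation_space.low.size,
        output_size=env.action_space.low.size,
        hidden_sizes=[64, 64],
    )
    collector = MdpPathCollector(env, policy, max_num_epoch_paths_saved=10)
    paths = collector.collect_new_paths(
        max_path_length=200,
        num_steps=1000,
        discard_incomplete_paths=True,
    )
    print(collector.get_diagnostics())  # num steps/paths total + path-length stats
    collector.end_epoch(0)              # clears the per-epoch path deque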