├── rlkit
│   ├── __init__.py
│   ├── torch
│   │   ├── __init__.py
│   │   ├── ddpg
│   │   │   └── __init__.py
│   │   ├── dqn
│   │   │   ├── __init__.py
│   │   │   ├── double_dqn.py
│   │   │   └── dqn.py
│   │   ├── her
│   │   │   ├── __init__.py
│   │   │   └── her.py
│   │   ├── sac
│   │   │   ├── __init__.py
│   │   │   └── policies.py
│   │   ├── td3
│   │   │   └── __init__.py
│   │   ├── data_management
│   │   │   ├── __init__.py
│   │   │   └── normalizer.py
│   │   ├── vae
│   │   │   ├── vae_schedules.py
│   │   │   └── vae_base.py
│   │   ├── skewfit
│   │   │   └── video_gen.py
│   │   ├── modules.py
│   │   ├── torch_rl_algorithm.py
│   │   ├── core.py
│   │   ├── data.py
│   │   ├── distributions.py
│   │   ├── pytorch_util.py
│   │   └── networks.py
│   ├── policies
│   │   ├── __init__.py
│   │   ├── simple.py
│   │   ├── base.py
│   │   └── argmax.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── data_collector
│   │   │   ├── __init__.py
│   │   │   ├── vae_env.py
│   │   │   ├── base.py
│   │   │   └── path_collector.py
│   │   ├── util.py
│   │   └── rollout_functions.py
│   ├── data_management
│   │   ├── __init__.py
│   │   ├── path_builder.py
│   │   ├── env_replay_buffer.py
│   │   ├── replay_buffer.py
│   │   ├── simple_replay_buffer.py
│   │   ├── normalizer.py
│   │   └── shared_obs_dict_replay_buffer.py
│   ├── exploration_strategies
│   │   ├── __init__.py
│   │   ├── epsilon_greedy.py
│   │   ├── gaussian_strategy.py
│   │   ├── base.py
│   │   ├── gaussian_and_epsilon_strategy.py
│   │   └── ou_strategy.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   ├── serializable.py
│   │   ├── batch_rl_algorithm.py
│   │   ├── online_rl_algorithm.py
│   │   ├── eval_util.py
│   │   └── rl_algorithm.py
│   ├── launchers
│   │   ├── __init__.py
│   │   └── conf.py
│   ├── envs
│   │   ├── env_utils.py
│   │   ├── mujoco_env.py
│   │   ├── ant.py
│   │   ├── goal_generation
│   │   │   └── pickup_goal_dataset.py
│   │   ├── assets
│   │   │   ├── reacher_7dof.xml
│   │   │   └── low_gear_ratio_ant.xml
│   │   ├── mujoco_image_env.py
│   │   └── wrappers.py
│   └── util
│       ├── ml_util.py
│       ├── io.py
│       └── video.py
├── docs
│   ├── images
│   │   ├── her_dqn.png
│   │   ├── skewfit_door.png
│   │   ├── skewfit_pickup.png
│   │   ├── skewfit_pusher.png
│   │   ├── FetchReach-v1_HER-TD3.png
│   │   ├── her_td3_sawyer_reacher.png
│   │   └── SawyerReachXYZEnv-v0_HER-TD3.png
│   ├── RIG.md
│   ├── SkewFit.md
│   ├── goal_based_envs.md
│   ├── TDMs.md
│   └── HER.md
├── environment
│   ├── docker
│   │   ├── vendor
│   │   │   ├── 10_nvidia.json
│   │   │   └── Xdummy-entrypoint
│   │   └── Dockerfile
│   ├── mac-env.yml
│   ├── linux-gpu-env.yml
│   └── linux-cpu-env.yml
├── .gitignore
├── setup.py
├── LICENSE
├── scripts
│   ├── run_policy.py
│   ├── run_goal_conditioned_policy.py
│   └── run_experiment_from_doodad.py
└── examples
    ├── doodad
    │   ├── ec2_example.py
    │   └── gcp_example.py
    ├── dqn_and_double_dqn.py
    ├── ddpg.py
    ├── sac.py
    ├── her
    │   ├── her_dqn_gridworld.py
    │   ├── her_sac_gym_fetch_reach.py
    │   └── her_td3_multiworld_sawyer_reach.py
    ├── td3.py
    └── skewfit
        ├── sawyer_door.py
        └── sawyer_push.py
/rlkit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/policies/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/samplers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/dqn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/her/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/sac/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/td3/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/data_management/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/data_management/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/images/her_dqn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/her_dqn.png
--------------------------------------------------------------------------------
/docs/images/skewfit_door.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/skewfit_door.png
--------------------------------------------------------------------------------
/docs/images/skewfit_pickup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/skewfit_pickup.png
--------------------------------------------------------------------------------
/docs/images/skewfit_pusher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/skewfit_pusher.png
--------------------------------------------------------------------------------
/docs/images/FetchReach-v1_HER-TD3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/FetchReach-v1_HER-TD3.png
--------------------------------------------------------------------------------
/docs/images/her_td3_sawyer_reacher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/her_td3_sawyer_reacher.png
--------------------------------------------------------------------------------
/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mimoralea/rlkit/master/docs/images/SawyerReachXYZEnv-v0_HER-TD3.png
--------------------------------------------------------------------------------
/environment/docker/vendor/10_nvidia.json:
--------------------------------------------------------------------------------
1 | {
2 | "file_format_version" : "1.0.0",
3 | "ICD" : {
4 | "library_path" : "libEGL_nvidia.so.0"
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | */*/mjkey.txt
2 | **/.DS_STORE
3 | **/*.pyc
4 | **/*.swp
5 | rlkit/launchers/config.py
6 | rlkit/launchers/conf_private.py
7 | MANIFEST
8 | *.egg-info
9 | \.idea/
10 |
--------------------------------------------------------------------------------
/rlkit/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General classes, functions, utilities that are used throughout rlkit.
3 | """
4 | from rlkit.core.logging import logger
5 |
6 | __all__ = ['logger']
7 |
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from setuptools import find_packages
3 |
4 | setup(
5 | name='rlkit',
6 | version='0.2.1dev',
7 | packages=find_packages(),
8 | license='MIT License',
9 | long_description=open('README.md').read(),
10 | )
--------------------------------------------------------------------------------
/rlkit/core/trainer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Trainer(object, metaclass=abc.ABCMeta):
5 | @abc.abstractmethod
6 | def train(self, data):
7 | pass
8 |
9 | def end_epoch(self, epoch):
10 | pass
11 |
12 | def get_snapshot(self):
13 | return {}
14 |
15 | def get_diagnostics(self):
16 | return {}
17 |
--------------------------------------------------------------------------------
/rlkit/launchers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains 'launchers', which are self-contained functions that take
3 | one dictionary and run a full experiment. The dictionary configures the
4 | experiment.
5 |
6 | It is important that the functions are completely self-contained (i.e. they
7 | import their own modules) so that they can be serialized.
8 | """
9 |
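10 | # Illustrative sketch (not used anywhere in rlkit): a launcher is a plain
11 | # function that takes a single `variant` dict and runs the whole experiment,
12 | # doing its imports inside the function body so that it can be serialized and
13 | # shipped to a remote machine, e.g.:
14 | #
15 | #     def example_launcher(variant):
16 | #         from rlkit.core import logger
17 | #         logger.log("variant: {}".format(variant))
18 | #         ...  # build the env, trainer, and algorithm from `variant`, then train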
--------------------------------------------------------------------------------
/rlkit/policies/simple.py:
--------------------------------------------------------------------------------
1 | from rlkit.policies.base import Policy
2 |
3 |
4 | class RandomPolicy(Policy):
5 | """
6 |     Policy that samples actions uniformly at random from the action space.
7 | """
8 |
9 | def __init__(self, action_space):
10 | self.action_space = action_space
11 |
12 | def get_action(self, obs):
13 | return self.action_space.sample(), {}
14 |
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/__init__.py:
--------------------------------------------------------------------------------
1 | from rlkit.samplers.data_collector.base import (
2 | DataCollector,
3 | PathCollector,
4 | StepCollector,
5 | )
6 | from rlkit.samplers.data_collector.path_collector import (
7 | MdpPathCollector,
8 | GoalConditionedPathCollector,
9 | )
10 | from rlkit.samplers.data_collector.step_collector import (
11 | GoalConditionedStepCollector
12 | )
13 |
--------------------------------------------------------------------------------
/rlkit/policies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Policy(object, metaclass=abc.ABCMeta):
5 | """
6 | General policy interface.
7 | """
8 | @abc.abstractmethod
9 | def get_action(self, observation):
10 | """
11 |
12 | :param observation:
13 | :return: action, debug_dictionary
14 | """
15 | pass
16 |
17 | def reset(self):
18 | pass
19 |
20 |
21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta):
22 | def set_num_steps_total(self, t):
23 | pass
24 |
--------------------------------------------------------------------------------
/rlkit/policies/argmax.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch argmax policy
3 | """
4 | import numpy as np
5 | from torch import nn
6 |
7 | import rlkit.torch.pytorch_util as ptu
8 | from rlkit.policies.base import Policy
9 |
10 |
11 | class ArgmaxDiscretePolicy(nn.Module, Policy):
12 | def __init__(self, qf):
13 | super().__init__()
14 | self.qf = qf
15 |
16 | def get_action(self, obs):
17 | obs = np.expand_dims(obs, axis=0)
18 | obs = ptu.from_numpy(obs).float()
19 | q_values = self.qf(obs).squeeze(0)
20 | q_values_np = ptu.get_numpy(q_values)
21 | return q_values_np.argmax(), {}
22 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/epsilon_greedy.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from rlkit.exploration_strategies.base import RawExplorationStrategy
4 |
5 |
6 | class EpsilonGreedy(RawExplorationStrategy):
7 | """
8 | Take a random discrete action with some probability.
9 | """
10 | def __init__(self, action_space, prob_random_action=0.1):
11 | self.prob_random_action = prob_random_action
12 | self.action_space = action_space
13 |
14 | def get_action_from_raw_action(self, action, **kwargs):
15 | if random.random() <= self.prob_random_action:
16 | return self.action_space.sample()
17 | return action
18 |
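19 | # Typical usage (illustrative): pair with a greedy Q-function policy for
20 | # discrete actions, e.g.
21 | #   expl_policy = PolicyWrappedWithExplorationStrategy(
22 | #       EpsilonGreedy(env.action_space),
23 | #       ArgmaxDiscretePolicy(qf),
24 | #   )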
--------------------------------------------------------------------------------
/docs/RIG.md:
--------------------------------------------------------------------------------
1 | # Reinforcement Learning with Imagined Goals
2 | Implementation of reinforcement learning with imagined goals (RIG).
3 | To find out
4 | more, see any of the following links:
5 | * arXiv: https://arxiv.org/abs/1807.04742
6 | * Website: https://sites.google.com/site/visualrlwithimaginedgoals/
7 | * Blog Post: https://bair.berkeley.edu/blog/2018/09/06/rig/
8 |
9 | To see the original implementation, check out version `v0.1.2` of this repo.
10 |
11 | In versions `0.2+`, RIG is a special case of [Skew-Fit](SkewFit.md) with the
12 | power set to `0`.
13 |
14 | ## Goal-based environments and `ObsDictRelabelingBuffer`
15 | [See here.](goal_based_envs.md)
16 |
--------------------------------------------------------------------------------
/environment/docker/vendor/Xdummy-entrypoint:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import argparse
3 | import os
4 | import sys
5 | import subprocess
6 |
7 | parser = argparse.ArgumentParser()
8 | args, extra_args = parser.parse_known_args()
9 | subprocess.Popen(["nohup", "Xdummy"], stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
10 | os.environ['DISPLAY'] = ':0'
11 | if not extra_args:
12 | sys.argv = ['/bin/bash']
13 | else:
14 | sys.argv = extra_args
15 | # Explicitly flush right before the exec since otherwise things might get
16 | # lost in Python's buffers around stdout/stderr (!).
17 | sys.stdout.flush()
18 | sys.stderr.flush()
19 | os.execvpe(sys.argv[0], sys.argv, os.environ)
20 |
21 |
--------------------------------------------------------------------------------
/environment/mac-env.yml:
--------------------------------------------------------------------------------
1 | name: rlkit
2 | channels:
3 | - kne # for pybox2d
4 | - pytorch
5 | - anaconda # for mkl
6 | dependencies:
7 | - cython
8 | - ipython # technically unnecessary
9 | - joblib=0.9.4
10 | - lockfile
11 | - mako=1.0.6=py35_0
12 | - matplotlib=2.0.2=np111py35_0
13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch
14 | - numba=0.35.0=np111py35_0
15 | - numpy=1.11.3
16 | - path.py=10.3.1=py35_0
17 | - pybox2d=2.3.1post2=py35_0
18 | - python=3.5.2
19 | - python-dateutil=2.6.1=py35_0
20 | - pytorch=0.4.1
21 | - scipy=1.0.1
22 | - pip:
23 | - cloudpickle==0.5.2
24 | - gym[all]==0.10.5
25 | - gitpython==2.1.7
26 | - gtimer==1.0.0b5
27 | - pygame==1.9.2
28 | - ipdb # technically unnecessary
29 |
--------------------------------------------------------------------------------
/environment/linux-gpu-env.yml:
--------------------------------------------------------------------------------
1 | name: rlkit
2 | channels:
3 | - kne # for pybox2d
4 | - pytorch
5 | - anaconda # for mkl
6 | dependencies:
7 | - cython
8 | - ipython # technically unnecessary
9 | - joblib=0.9.4
10 | - lockfile
11 | - mako=1.0.6=py35_0
12 | - matplotlib=2.0.2=np111py35_0
13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch
14 | - numba=0.35.0=np111py35_0
15 | - numpy=1.11.3
16 | - path.py=10.3.1=py35_0
17 | - pybox2d=2.3.1post2=py35_0
18 | - python=3.5.2
19 | - python-dateutil=2.6.1=py35_0
20 | - pytorch=0.4.1
21 | - scipy=1.0.1
22 | - patchelf
23 | - pip:
24 | - cloudpickle==0.5.2
25 | - gym[all]==0.10.5
26 | - gitpython==2.1.7
27 | - gtimer==1.0.0b5
28 | - pygame==1.9.2
29 | - ipdb # technically unnecessary
30 |
--------------------------------------------------------------------------------
/environment/linux-cpu-env.yml:
--------------------------------------------------------------------------------
1 | name: rlkit
2 | channels:
3 | - kne # for pybox2d
4 | - pytorch
5 | - anaconda # for mkl
6 | dependencies:
7 | - cython
8 | - ipython # technically unnecessary
9 | - joblib=0.9.4
10 | - lockfile
11 | - mako=1.0.6=py35_0
12 | - matplotlib=2.0.2=np111py35_0
13 | - mkl=2018.0.2 # Need to add explicit dependence for pytorch
14 | - numba=0.35.0=np111py35_0
15 | - numpy=1.11.3
16 | - path.py=10.3.1=py35_0
17 | - pybox2d=2.3.1post2=py35_0
18 | - python=3.5.2
19 | - python-dateutil=2.6.1=py35_0
20 | - pytorch-cpu=0.4.1
21 | - scipy=1.0.1
22 | - patchelf
23 | - pip:
24 | - cloudpickle==0.5.2
25 | - gym[all]==0.10.5
26 | - gitpython==2.1.7
27 | - gtimer==1.0.0b5
28 | - pygame==1.9.2
29 | - ipdb # technically unnecessary
30 |
--------------------------------------------------------------------------------
/rlkit/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym.spaces import Box, Discrete, Tuple
4 |
5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
6 |
7 |
8 | def get_asset_full_path(file_name):
9 | return os.path.join(ENV_ASSET_DIR, file_name)
10 |
11 |
12 | def get_dim(space):
13 | if isinstance(space, Box):
14 | return space.low.size
15 | elif isinstance(space, Discrete):
16 | return space.n
17 | elif isinstance(space, Tuple):
18 | return sum(get_dim(subspace) for subspace in space.spaces)
19 | elif hasattr(space, 'flat_dim'):
20 | return space.flat_dim
21 | else:
22 | raise TypeError("Unknown space: {}".format(space))
23 |
24 |
25 | def mode(env, mode_type):
26 | try:
27 | getattr(env, mode_type)()
28 | except AttributeError:
29 | pass
30 |
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/vae_env.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.vae_wrapper import VAEWrappedEnv
2 | from rlkit.samplers.data_collector import GoalConditionedPathCollector
3 |
4 |
5 | class VAEWrappedEnvPathCollector(GoalConditionedPathCollector):
6 | def __init__(
7 | self,
8 | goal_sampling_mode,
9 | env: VAEWrappedEnv,
10 | policy,
11 | decode_goals=False,
12 | **kwargs
13 | ):
14 | super().__init__(env, policy, **kwargs)
15 | self._goal_sampling_mode = goal_sampling_mode
16 | self._decode_goals = decode_goals
17 |
18 | def collect_new_paths(self, *args, **kwargs):
19 | self._env.goal_sampling_mode = self._goal_sampling_mode
20 | self._env.decode_goals = self._decode_goals
21 | return super().collect_new_paths(*args, **kwargs)
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class DataCollector(object, metaclass=abc.ABCMeta):
5 | def end_epoch(self, epoch):
6 | pass
7 |
8 | def get_diagnostics(self):
9 | return {}
10 |
11 | def get_snapshot(self):
12 | return {}
13 |
14 | @abc.abstractmethod
15 | def get_epoch_paths(self):
16 | pass
17 |
18 |
19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta):
20 | @abc.abstractmethod
21 | def collect_new_paths(
22 | self,
23 | max_path_length,
24 | num_steps,
25 | discard_incomplete_paths,
26 | ):
27 | pass
28 |
29 |
30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta):
31 | @abc.abstractmethod
32 | def collect_new_steps(
33 | self,
34 | max_path_length,
35 | num_steps,
36 | discard_incomplete_paths,
37 | ):
38 | pass
39 |
--------------------------------------------------------------------------------
/rlkit/torch/her/her.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from rlkit.torch.torch_rl_algorithm import TorchTrainer
4 |
5 |
6 | class HERTrainer(TorchTrainer):
7 | def __init__(self, base_trainer: TorchTrainer):
8 | super().__init__()
9 | self._base_trainer = base_trainer
10 |
11 | def train_from_torch(self, data):
12 | obs = data['observations']
13 | next_obs = data['next_observations']
14 | goals = data['resampled_goals']
15 | data['observations'] = torch.cat((obs, goals), dim=1)
16 | data['next_observations'] = torch.cat((next_obs, goals), dim=1)
17 | self._base_trainer.train_from_torch(data)
18 |
19 | def get_diagnostics(self):
20 | return self._base_trainer.get_diagnostics()
21 |
22 | def end_epoch(self, epoch):
23 | self._base_trainer.end_epoch(epoch)
24 |
25 | @property
26 | def networks(self):
27 | return self._base_trainer.networks
28 |
29 | def get_snapshot(self):
30 | return self._base_trainer.get_snapshot()
31 |
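32 | # Note: the wrapped base trainer only ever sees flat observations of size
33 | # obs_dim + goal_dim, because the resampled goals are concatenated onto the
34 | # observations and next_observations above.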
--------------------------------------------------------------------------------
/docs/SkewFit.md:
--------------------------------------------------------------------------------
1 | # Skew-Fit
2 | Requires [multiworld](https://github.com/vitchyr/multiworld) to be installed:
3 | ```
4 | pip install git+https://github.com/vitchyr/multiworld.git@28ee206f60a45690d484737466b558abdef191ea
5 | ```
6 |
7 | Implementation of Skew-Fit. For more information:
8 | - [Videos](https://sites.google.com/view/skew-fit)
9 | - [arXiv](https://arxiv.org/abs/1903.03698)
10 |
11 | Here are the results you should expect from each script.
12 | These plots are generated with [viskit](https://github.com/vitchyr/viskit)
13 | with smoothing on.
14 |
15 | Note that [RIG](RIG.md) is a special case of Skew-Fit with `power=0`.
16 |
17 |
18 | [examples/skewfit/sawyer_door.py](../examples/skewfit/sawyer_door.py). 1 Seed:
19 | 
20 |
21 | [examples/skewfit/sawyer_pickup.py](../examples/skewfit/sawyer_pickup.py). 3 Seeds:
22 | 
23 |
24 | [examples/skewfit/sawyer_push.py](../examples/skewfit/sawyer_push.py). 9 Seeds:
25 | 
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Vitchyr Pong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/gaussian_strategy.py:
--------------------------------------------------------------------------------
1 | from rlkit.exploration_strategies.base import RawExplorationStrategy
2 | import numpy as np
3 |
4 |
5 | class GaussianStrategy(RawExplorationStrategy):
6 | """
7 | This strategy adds Gaussian noise to the action taken by the deterministic policy.
8 |
9 | Based on the rllab implementation.
10 | """
11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None,
12 | decay_period=1000000):
13 | assert len(action_space.shape) == 1
14 | self._max_sigma = max_sigma
15 | if min_sigma is None:
16 | min_sigma = max_sigma
17 | self._min_sigma = min_sigma
18 | self._decay_period = decay_period
19 | self._action_space = action_space
20 |
21 | def get_action_from_raw_action(self, action, t=None, **kwargs):
22 | sigma = (
23 | self._max_sigma - (self._max_sigma - self._min_sigma) *
24 | min(1.0, t * 1.0 / self._decay_period)
25 | )
26 | return np.clip(
27 | action + np.random.normal(size=len(action)) * sigma,
28 | self._action_space.low,
29 | self._action_space.high,
30 | )
31 |
--------------------------------------------------------------------------------
/rlkit/torch/vae/vae_schedules.py:
--------------------------------------------------------------------------------
1 | def always_train(epoch):
2 | return True, 300
3 |
4 |
5 | def custom_schedule(epoch):
6 | if epoch < 10:
7 | return True, 1000
8 | elif epoch < 300:
9 | return True, 200
10 | else:
11 | return epoch % 3 == 0, 200
12 |
13 |
14 | def custom_schedule_2(epoch):
15 | if epoch < 10:
16 | return True, 1000
17 | elif epoch < 100:
18 | return True, 200
19 | else:
20 | return epoch % 2 == 0, 200
21 |
22 |
23 | def every_other(epoch):
24 | return epoch % 2 == 0, 400
25 |
26 |
27 | def every_three(epoch):
28 | return epoch % 3 == 0, 600
29 |
30 |
31 | def every_three_a_lot(epoch):
32 | return epoch % 3 == 0, 1200
33 |
34 |
35 | def every_six(epoch):
36 | return epoch % 6 == 0, 1200
37 |
38 |
39 | def every_six_less(epoch):
40 | return epoch % 6 == 0, 600
41 |
42 |
43 | def every_six_much_less(epoch):
44 | return epoch % 6 == 0, 300
45 |
46 |
47 | def every_ten(epoch):
48 | return epoch % 10 == 0 or epoch == 5, 1000
49 |
50 |
51 | def every_twenty(epoch):
52 | return epoch % 10 == 0 or epoch == 5 or epoch == 10, 1000
53 |
54 |
55 | def never_train(epoch):
56 | return False, 0
57 |
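58 | # Each schedule maps an epoch index to a tuple:
59 | # (whether to train the VAE this epoch, how many batches to train on).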
--------------------------------------------------------------------------------
/scripts/run_policy.py:
--------------------------------------------------------------------------------
1 | from rlkit.samplers.rollout_functions import rollout
2 | from rlkit.torch.pytorch_util import set_gpu_mode
3 | import argparse
4 | import torch
5 | import uuid
6 | from rlkit.core import logger
7 |
8 | filename = str(uuid.uuid4())
9 |
10 |
11 | def simulate_policy(args):
12 | data = torch.load(args.file)
13 | policy = data['evaluation/policy']
14 | env = data['evaluation/env']
15 | print("Policy loaded")
16 | if args.gpu:
17 | set_gpu_mode(True)
18 | policy.cuda()
19 | while True:
20 | path = rollout(
21 | env,
22 | policy,
23 | max_path_length=args.H,
24 | render=True,
25 | )
26 | if hasattr(env, "log_diagnostics"):
27 | env.log_diagnostics([path])
28 | logger.dump_tabular()
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('file', type=str,
34 | help='path to the snapshot file')
35 | parser.add_argument('--H', type=int, default=300,
36 | help='Max length of rollout')
37 | parser.add_argument('--gpu', action='store_true')
38 | args = parser.parse_args()
39 |
40 | simulate_policy(args)
41 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from rlkit.policies.base import ExplorationPolicy
4 |
5 |
6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta):
7 | @abc.abstractmethod
8 | def get_action(self, t, observation, policy, **kwargs):
9 | pass
10 |
11 | def reset(self):
12 | pass
13 |
14 |
15 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta):
16 | @abc.abstractmethod
17 | def get_action_from_raw_action(self, action, **kwargs):
18 | pass
19 |
20 | def get_action(self, t, policy, *args, **kwargs):
21 | action, agent_info = policy.get_action(*args, **kwargs)
22 | return self.get_action_from_raw_action(action, t=t), agent_info
23 |
24 | def reset(self):
25 | pass
26 |
27 |
28 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy):
29 | def __init__(
30 | self,
31 | exploration_strategy: ExplorationStrategy,
32 | policy,
33 | ):
34 | self.es = exploration_strategy
35 | self.policy = policy
36 | self.t = 0
37 |
38 | def set_num_steps_total(self, t):
39 | self.t = t
40 |
41 | def get_action(self, *args, **kwargs):
42 | return self.es.get_action(self.t, self.policy, *args, **kwargs)
43 |
44 | def reset(self):
45 | self.es.reset()
46 | self.policy.reset()
47 |
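48 | # Typical usage (illustrative): wrap a deterministic policy with OU noise
49 | # for exploration.
50 | #   es = OUStrategy(action_space=env.action_space)
51 | #   expl_policy = PolicyWrappedWithExplorationStrategy(es, policy)
52 | #   action, agent_info = expl_policy.get_action(obs)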
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/gaussian_and_epsilon_strategy.py:
--------------------------------------------------------------------------------
1 | import random
2 | from rlkit.exploration_strategies.base import RawExplorationStrategy
3 | import numpy as np
4 |
5 |
6 | class GaussianAndEpislonStrategy(RawExplorationStrategy):
7 | """
8 | With probability epsilon, take a completely random action.
9 |     With probability 1-epsilon, add Gaussian noise to the action taken by a
10 | deterministic policy.
11 | """
12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None,
13 | decay_period=1000000):
14 | assert len(action_space.shape) == 1
15 | if min_sigma is None:
16 | min_sigma = max_sigma
17 | self._max_sigma = max_sigma
18 | self._epsilon = epsilon
19 | self._min_sigma = min_sigma
20 | self._decay_period = decay_period
21 | self._action_space = action_space
22 |
23 | def get_action_from_raw_action(self, action, t=None, **kwargs):
24 | if random.random() < self._epsilon:
25 | return self._action_space.sample()
26 | else:
27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period)
28 | return np.clip(
29 | action + np.random.normal(size=len(action)) * sigma,
30 | self._action_space.low,
31 | self._action_space.high,
32 | )
33 |
--------------------------------------------------------------------------------
/rlkit/torch/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains some self-contained modules.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class HuberLoss(nn.Module):
9 | def __init__(self, delta=1):
10 | super().__init__()
11 | self.huber_loss_delta1 = nn.SmoothL1Loss()
12 | self.delta = delta
13 |
14 | def forward(self, x, x_hat):
15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta)
16 | return loss * self.delta * self.delta
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | """
21 | Simple 1D LayerNorm.
22 | """
23 |
24 | def __init__(self, features, center=True, scale=False, eps=1e-6):
25 | super().__init__()
26 | self.center = center
27 | self.scale = scale
28 | self.eps = eps
29 | if self.scale:
30 | self.scale_param = nn.Parameter(torch.ones(features))
31 | else:
32 | self.scale_param = None
33 | if self.center:
34 | self.center_param = nn.Parameter(torch.zeros(features))
35 | else:
36 | self.center_param = None
37 |
38 | def forward(self, x):
39 | mean = x.mean(-1, keepdim=True)
40 | std = x.std(-1, keepdim=True)
41 | output = (x - mean) / (std + self.eps)
42 | if self.scale:
43 | output = output * self.scale_param
44 | if self.center:
45 | output = output + self.center_param
46 | return output
47 |
--------------------------------------------------------------------------------
/docs/goal_based_envs.md:
--------------------------------------------------------------------------------
1 | # Goal-based environments and `ObsDictRelabelingBuffer`
2 | Some algorithms, like HER, are for goal-conditioned environments, like
3 | the [OpenAI Gym GoalEnv](https://blog.openai.com/ingredients-for-robotics-research/)
4 | or the [multiworld MultitaskEnv](https://github.com/vitchyr/multiworld/)
5 | environments.
6 |
7 | These environments are different from normal gym environments in that they
8 | return dictionaries for observations. The environments work like this:
9 |
10 | ```
11 | env = CarEnv()
12 | obs = env.reset()
13 | next_obs, reward, done, info = env.step(action)
14 | print(obs)
15 |
16 | # Output:
17 | # {
18 | # 'observation': ...,
19 | # 'desired_goal': ...,
20 | # 'achieved_goal': ...,
21 | # }
22 | ```
23 | The `GoalEnv` environments also have a function with signature
24 | ```
25 | def compute_rewards(achieved_goal, desired_goal):
26 | # achieved_goal and desired_goal are vectors
27 | ```
28 | while the `MultitaskEnv` has a signature like
29 | ```
30 | def compute_rewards(observation, action, next_observation):
31 | # observation and next_observations are dictionaries
32 | ```
33 | To learn more about these environments, check out the URLs above.
34 | This means that normal RL algorithms won't even "type check" with these
35 | environments.
36 |
37 | `ObsDictRelabelingBuffer` performs hindsight experience replay with
38 | either type of environment and works by saving specific values in the
39 | observation dictionary.
40 |
41 |
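42 | To make the relabeling idea concrete, here is a rough sketch of the "future"
43 | relabeling strategy for a `GoalEnv`-style environment. This is illustrative
44 | only and does not match `ObsDictRelabelingBuffer`'s internals, which operate
45 | on flat numpy arrays rather than per-path dictionaries:
46 | ```
47 | import random
48 |
49 | def relabel_with_future_goal(path, t, env):
50 |     """Relabel the transition at timestep t with a goal achieved later on."""
51 |     obs = dict(path['observations'][t])            # copy, keep the original
52 |     next_obs = dict(path['next_observations'][t])
53 |     action = path['actions'][t]
54 |     # "Future" strategy: reuse a goal that was actually achieved later in
55 |     # the same path.
56 |     future_t = random.randint(t, len(path['observations']) - 1)
57 |     new_goal = path['next_observations'][future_t]['achieved_goal']
58 |     obs['desired_goal'] = new_goal
59 |     next_obs['desired_goal'] = new_goal
60 |     # Recompute the reward for the substituted goal (GoalEnv-style signature).
61 |     reward = env.compute_rewards(next_obs['achieved_goal'], new_goal)
62 |     return obs, action, reward, next_obs
63 | ```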
--------------------------------------------------------------------------------
/rlkit/data_management/path_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class PathBuilder(dict):
5 | """
6 | Usage:
7 | ```
8 | path_builder = PathBuilder()
9 |     path_builder.add_all(
10 | observations=1,
11 | actions=2,
12 | next_observations=3,
13 | ...
14 | )
15 |     path_builder.add_all(
16 | observations=4,
17 | actions=5,
18 | next_observations=6,
19 | ...
20 | )
21 |
22 | path = path_builder.get_all_stacked()
23 |
24 | path['observations']
25 | # output: [1, 4]
26 | path['actions']
27 | # output: [2, 5]
28 | ```
29 |
30 | Note that the key should be "actions" and not "action" since the
31 | resulting dictionary will have those keys.
32 | """
33 |
34 | def __init__(self):
35 | super().__init__()
36 | self._path_length = 0
37 |
38 | def add_all(self, **key_to_value):
39 | for k, v in key_to_value.items():
40 | if k not in self:
41 | self[k] = [v]
42 | else:
43 | self[k].append(v)
44 | self._path_length += 1
45 |
46 | def get_all_stacked(self):
47 | output_dict = dict()
48 | for k, v in self.items():
49 | output_dict[k] = stack_list(v)
50 | return output_dict
51 |
52 | def __len__(self):
53 | return self._path_length
54 |
55 |
56 | def stack_list(lst):
57 | if isinstance(lst[0], dict):
58 | return lst
59 | else:
60 | return np.array(lst)
61 |
--------------------------------------------------------------------------------
/rlkit/torch/torch_rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from collections import OrderedDict
3 |
4 | from typing import Iterable
5 | from torch import nn as nn
6 |
7 | from rlkit.core.batch_rl_algorithm import BatchRLAlgorithm
8 | from rlkit.core.online_rl_algorithm import OnlineRLAlgorithm
9 | from rlkit.core.trainer import Trainer
10 | from rlkit.torch.core import np_to_pytorch_batch
11 |
12 |
13 | class TorchOnlineRLAlgorithm(OnlineRLAlgorithm):
14 | def to(self, device):
15 | for net in self.trainer.networks:
16 | net.to(device)
17 |
18 | def training_mode(self, mode):
19 | for net in self.trainer.networks:
20 | net.train(mode)
21 |
22 |
23 | class TorchBatchRLAlgorithm(BatchRLAlgorithm):
24 | def to(self, device):
25 | for net in self.trainer.networks:
26 | net.to(device)
27 |
28 | def training_mode(self, mode):
29 | for net in self.trainer.networks:
30 | net.train(mode)
31 |
32 |
33 | class TorchTrainer(Trainer, metaclass=abc.ABCMeta):
34 | def __init__(self):
35 | self._num_train_steps = 0
36 |
37 | def train(self, np_batch):
38 | self._num_train_steps += 1
39 | batch = np_to_pytorch_batch(np_batch)
40 | self.train_from_torch(batch)
41 |
42 | def get_diagnostics(self):
43 | return OrderedDict([
44 | ('num train calls', self._num_train_steps),
45 | ])
46 |
47 | @abc.abstractmethod
48 | def train_from_torch(self, batch):
49 | pass
50 |
51 | @property
52 | @abc.abstractmethod
53 | def networks(self) -> Iterable[nn.Module]:
54 | pass
55 |
--------------------------------------------------------------------------------
/rlkit/data_management/env_replay_buffer.py:
--------------------------------------------------------------------------------
1 | from gym.spaces import Discrete
2 |
3 | from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer
4 | from rlkit.envs.env_utils import get_dim
5 | import numpy as np
6 |
7 |
8 | class EnvReplayBuffer(SimpleReplayBuffer):
9 | def __init__(
10 | self,
11 | max_replay_buffer_size,
12 | env,
13 | env_info_sizes=None
14 | ):
15 | """
16 | :param max_replay_buffer_size:
17 | :param env:
18 | """
19 | self.env = env
20 | self._ob_space = env.observation_space
21 | self._action_space = env.action_space
22 |
23 | if env_info_sizes is None:
24 | if hasattr(env, 'info_sizes'):
25 | env_info_sizes = env.info_sizes
26 | else:
27 | env_info_sizes = dict()
28 |
29 | super().__init__(
30 | max_replay_buffer_size=max_replay_buffer_size,
31 | observation_dim=get_dim(self._ob_space),
32 | action_dim=get_dim(self._action_space),
33 | env_info_sizes=env_info_sizes
34 | )
35 |
36 | def add_sample(self, observation, action, reward, terminal,
37 | next_observation, **kwargs):
38 | if isinstance(self._action_space, Discrete):
39 | new_action = np.zeros(self._action_dim)
40 | new_action[action] = 1
41 | else:
42 | new_action = action
43 | return super().add_sample(
44 | observation=observation,
45 | action=new_action,
46 | reward=reward,
47 | next_observation=next_observation,
48 | terminal=terminal,
49 | **kwargs
50 | )
51 |
--------------------------------------------------------------------------------
/examples/doodad/ec2_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Example of running stuff on EC2
3 | """
4 | import time
5 |
6 | from rlkit.core import logger
7 | from rlkit.launchers.launcher_util import run_experiment
8 | from datetime import datetime
9 | from pytz import timezone
10 | import pytz
11 |
12 |
13 | def example(variant):
14 |     import torch; import rlkit.torch.pytorch_util as ptu  # ptu.device is used below
15 | logger.log(torch.__version__)
16 | date_format = '%m/%d/%Y %H:%M:%S %Z'
17 | date = datetime.now(tz=pytz.utc)
18 | logger.log("start")
19 | logger.log('Current date & time is: {}'.format(date.strftime(date_format)))
20 | if torch.cuda.is_available():
21 | x = torch.randn(3)
22 | logger.log(str(x.to(ptu.device)))
23 |
24 | date = date.astimezone(timezone('US/Pacific'))
25 | logger.log('Local date & time is: {}'.format(date.strftime(date_format)))
26 | for i in range(variant['num_seconds']):
27 | logger.log("Tick, {}".format(i))
28 | time.sleep(1)
29 | logger.log("end")
30 | logger.log('Local date & time is: {}'.format(date.strftime(date_format)))
31 |
32 | logger.log("start mujoco")
33 | from gym.envs.mujoco import HalfCheetahEnv
34 | e = HalfCheetahEnv()
35 | img = e.sim.render(32, 32)
36 | logger.log(str(sum(img)))
37 | logger.log("end mujocoy")
38 |
39 |
40 | if __name__ == "__main__":
41 | # noinspection PyTypeChecker
42 | date_format = '%m/%d/%Y %H:%M:%S %Z'
43 | date = datetime.now(tz=pytz.utc)
44 | logger.log("start")
45 | variant = dict(
46 | num_seconds=10,
47 | launch_time=str(date.strftime(date_format)),
48 | )
49 | run_experiment(
50 | example,
51 | exp_prefix="ec2-test",
52 | mode='ec2',
53 | variant=variant,
54 | # use_gpu=True, # GPUs are much more expensive!
55 | )
56 |
--------------------------------------------------------------------------------
/rlkit/torch/core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from rlkit.torch import pytorch_util as ptu
5 |
6 |
7 | def eval_np(module, *args, **kwargs):
8 | """
9 | Eval this module with a numpy interface
10 |
11 | Same as a call to __call__ except all Variable input/outputs are
12 | replaced with numpy equivalents.
13 |
14 | Assumes the output is either a single object or a tuple of objects.
15 | """
16 | torch_args = tuple(torch_ify(x) for x in args)
17 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
18 | outputs = module(*torch_args, **torch_kwargs)
19 | if isinstance(outputs, tuple):
20 | return tuple(np_ify(x) for x in outputs)
21 | else:
22 | return np_ify(outputs)
23 |
24 |
25 | def torch_ify(np_array_or_other):
26 | if isinstance(np_array_or_other, np.ndarray):
27 | return ptu.from_numpy(np_array_or_other)
28 | else:
29 | return np_array_or_other
30 |
31 |
32 | def np_ify(tensor_or_other):
33 | if isinstance(tensor_or_other, torch.autograd.Variable):
34 | return ptu.get_numpy(tensor_or_other)
35 | else:
36 | return tensor_or_other
37 |
38 |
39 | def _elem_or_tuple_to_variable(elem_or_tuple):
40 | if isinstance(elem_or_tuple, tuple):
41 | return tuple(
42 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple
43 | )
44 | return ptu.from_numpy(elem_or_tuple).float()
45 |
46 |
47 | def _filter_batch(np_batch):
48 | for k, v in np_batch.items():
49 | if v.dtype == np.bool:
50 | yield k, v.astype(int)
51 | else:
52 | yield k, v
53 |
54 |
55 | def np_to_pytorch_batch(np_batch):
56 | return {
57 | k: _elem_or_tuple_to_variable(x)
58 | for k, x in _filter_batch(np_batch)
59 | if x.dtype != np.dtype('O') # ignore object (e.g. dictionaries)
60 | }
61 |
62 |
--------------------------------------------------------------------------------
/docs/TDMs.md:
--------------------------------------------------------------------------------
1 | # Temporal Difference Models (TDMs)
2 | The TDM implementation is a bit different from the other algorithms. One reason for this is that the goals and rewards are retroactively relabelled. Some notable implementation details:
3 | - The networks (policy and QF) take in the goal and tau (the number of time steps left).
4 | - The algorithm relabels the terminal and rewards, so the terminal/reward from the environment are ignored completely.
5 | - TdmNormalizer is used to normalize the observations/states. If you want, you can totally ignore it and set `num_pretrain_path=0`.
6 | - The environments need to be [MultitaskEnv](../rlkit/torch/tdm/envs/multitask_env), meaning standard gym environments won't work out of the box. See below for details.
7 |
8 | The example scripts have tuned hyperparameters. Specifically, the following hyperparameters are tuned, as they seem to be the most important ones to tune:
9 | - `num_updates_per_env_step`
10 | - `reward_scale`
11 | - `max_tau`
12 |
13 |
14 | ## Creating your own environment for TDM
15 | A [MultitaskEnv](../envs/multitask_env.py) instance needs to implement 3 functions:
16 |
17 | ```python
18 | def goal_dim(self) -> int:
19 | """
20 | :return: int, dimension of goal vector
21 | """
22 | pass
23 |
24 | @abc.abstractmethod
25 | def sample_goals(self, batch_size):
26 | pass
27 |
28 | @abc.abstractmethod
29 | def convert_obs_to_goals(self, obs):
30 | pass
31 | ```
32 |
33 | If you want to see how to make an existing environment multitask, see [GoalXVelHalfCheetah](../envs/half_cheetah_env.py), which builds off of Gym's HalfCheetah environments.
34 |
35 | Another useful example might be [GoalXYPosAnt](../envs/ant_env.py), which builds off a custom environment.
36 |
37 | One important thing is that the environment should *not* include the goal as part of the state, since the goal will be separately given to the networks.
38 |
--------------------------------------------------------------------------------
/rlkit/torch/dqn/double_dqn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | import rlkit.torch.pytorch_util as ptu
5 | from rlkit.core.eval_util import create_stats_ordered_dict
6 | from rlkit.torch.dqn.dqn import DQNTrainer
7 |
8 |
9 | class DoubleDQNTrainer(DQNTrainer):
10 | def train_from_torch(self, batch):
11 | rewards = batch['rewards']
12 | terminals = batch['terminals']
13 | obs = batch['observations']
14 | actions = batch['actions']
15 | next_obs = batch['next_observations']
16 |
17 | """
18 | Compute loss
19 | """
20 |
21 | best_action_idxs = self.qf(next_obs).max(
22 | 1, keepdim=True
23 | )[1]
24 | target_q_values = self.target_qf(next_obs).gather(
25 | 1, best_action_idxs
26 | ).detach()
27 | y_target = rewards + (1. - terminals) * self.discount * target_q_values
28 | y_target = y_target.detach()
29 | # actions is a one-hot vector
30 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
31 | qf_loss = self.qf_criterion(y_pred, y_target)
32 |
33 | """
34 | Update networks
35 | """
36 | self.qf_optimizer.zero_grad()
37 | qf_loss.backward()
38 | self.qf_optimizer.step()
39 |
40 | """
41 | Soft target network updates
42 | """
43 | if self._n_train_steps_total % self.target_update_period == 0:
44 | ptu.soft_update_from_to(
45 | self.qf, self.target_qf, self.soft_target_tau
46 | )
47 |
48 | """
49 | Save some statistics for eval using just one batch.
50 | """
51 | if self._need_to_update_eval_statistics:
52 | self._need_to_update_eval_statistics = False
53 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
54 | self.eval_statistics.update(create_stats_ordered_dict(
55 | 'Y Predictions',
56 | ptu.get_numpy(y_pred),
57 | ))
58 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/ou_strategy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.random as nr
3 |
4 | from rlkit.exploration_strategies.base import RawExplorationStrategy
5 |
6 |
7 | class OUStrategy(RawExplorationStrategy):
8 | """
9 | This strategy implements the Ornstein-Uhlenbeck process, which adds
10 | time-correlated noise to the actions taken by the deterministic policy.
11 | The OU process satisfies the following stochastic differential equation:
12 | dxt = theta*(mu - xt)*dt + sigma*dWt
13 | where Wt denotes the Wiener process
14 |
15 | Based on the rllab implementation.
16 | """
17 |
18 | def __init__(
19 | self,
20 | action_space,
21 | mu=0,
22 | theta=0.15,
23 | max_sigma=0.3,
24 | min_sigma=None,
25 | decay_period=100000,
26 | ):
27 | if min_sigma is None:
28 | min_sigma = max_sigma
29 | self.mu = mu
30 | self.theta = theta
31 | self.sigma = max_sigma
32 | self._max_sigma = max_sigma
33 | if min_sigma is None:
34 | min_sigma = max_sigma
35 | self._min_sigma = min_sigma
36 | self._decay_period = decay_period
37 | self.dim = np.prod(action_space.low.shape)
38 | self.low = action_space.low
39 | self.high = action_space.high
40 | self.state = np.ones(self.dim) * self.mu
41 | self.reset()
42 |
43 | def reset(self):
44 | self.state = np.ones(self.dim) * self.mu
45 |
46 | def evolve_state(self):
47 | x = self.state
48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x))
49 | self.state = x + dx
50 | return self.state
51 |
52 | def get_action_from_raw_action(self, action, t=0, **kwargs):
53 | ou_state = self.evolve_state()
54 | self.sigma = (
55 | self._max_sigma
56 | - (self._max_sigma - self._min_sigma)
57 | * min(1.0, t * 1.0 / self._decay_period)
58 | )
59 | return np.clip(action + ou_state, self.low, self.high)
60 |
--------------------------------------------------------------------------------
/rlkit/envs/mujoco_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import mujoco_py
5 | import numpy as np
6 | from gym.envs.mujoco import mujoco_env
7 |
8 | from rlkit.core.serializable import Serializable
9 |
10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
11 |
12 |
13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable):
14 | """
15 | My own wrapper around MujocoEnv.
16 |
17 | The caller needs to declare
18 | """
19 | def __init__(
20 | self,
21 | model_path,
22 | frame_skip=1,
23 | model_path_is_local=True,
24 | automatically_set_obs_and_action_space=False,
25 | ):
26 | if model_path_is_local:
27 | model_path = get_asset_xml(model_path)
28 | if automatically_set_obs_and_action_space:
29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip)
30 | else:
31 | """
32 | Code below is copy/pasted from MujocoEnv's __init__ function.
33 | """
34 | if model_path.startswith("/"):
35 | fullpath = model_path
36 | else:
37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
38 | if not path.exists(fullpath):
39 | raise IOError("File %s does not exist" % fullpath)
40 | self.frame_skip = frame_skip
41 | self.model = mujoco_py.MjModel(fullpath)
42 | self.data = self.model.data
43 | self.viewer = None
44 |
45 | self.metadata = {
46 | 'render.modes': ['human', 'rgb_array'],
47 | 'video.frames_per_second': int(np.round(1.0 / self.dt))
48 | }
49 |
50 | self.init_qpos = self.model.data.qpos.ravel().copy()
51 | self.init_qvel = self.model.data.qvel.ravel().copy()
52 | self._seed()
53 |
54 | def init_serialization(self, locals):
55 | Serializable.quick_init(self, locals)
56 |
57 | def log_diagnostics(self, paths):
58 | pass
59 |
60 |
61 | def get_asset_xml(xml_name):
62 | return os.path.join(ENV_ASSET_DIR, xml_name)
63 |
--------------------------------------------------------------------------------
/examples/doodad/gcp_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Example of running stuff on GCP
3 | """
4 | import time
5 |
6 | from rlkit.core import logger
7 | from rlkit.launchers.launcher_util import run_experiment
8 | from datetime import datetime
9 | from pytz import timezone
10 | import pytz
11 |
12 |
13 | def example(variant):
14 | import torch
15 | import rlkit.torch.pytorch_util as ptu
16 | print("Starting")
17 | logger.log(torch.__version__)
18 | date_format = '%m/%d/%Y %H:%M:%S %Z'
19 | date = datetime.now(tz=pytz.utc)
20 | logger.log("start")
21 | logger.log('Current date & time is: {}'.format(date.strftime(date_format)))
22 | logger.log("Cuda available: {}".format(torch.cuda.is_available()))
23 | if torch.cuda.is_available():
24 | x = torch.randn(3)
25 | logger.log(str(x.to(ptu.device)))
26 |
27 | date = date.astimezone(timezone('US/Pacific'))
28 | logger.log('Local date & time is: {}'.format(date.strftime(date_format)))
29 | for i in range(variant['num_seconds']):
30 | logger.log("Tick, {}".format(i))
31 | time.sleep(1)
32 | logger.log("end")
33 | logger.log('Local date & time is: {}'.format(date.strftime(date_format)))
34 |
35 | logger.log("start mujoco")
36 | from gym.envs.mujoco import HalfCheetahEnv
37 | e = HalfCheetahEnv()
38 | img = e.sim.render(32, 32)
39 | logger.log(str(sum(img)))
40 | logger.log("end mujoco")
41 |
42 | logger.record_tabular('Epoch', 1)
43 | logger.dump_tabular()
44 | logger.record_tabular('Epoch', 2)
45 | logger.dump_tabular()
46 | logger.record_tabular('Epoch', 3)
47 | logger.dump_tabular()
48 | print("Done")
49 |
50 |
51 | if __name__ == "__main__":
52 | # noinspection PyTypeChecker
53 | date_format = '%m/%d/%Y %H:%M:%S %Z'
54 | date = datetime.now(tz=pytz.utc)
55 | logger.log("start")
56 | variant = dict(
57 | num_seconds=10,
58 | launch_time=str(date.strftime(date_format)),
59 | )
60 | run_experiment(
61 | example,
62 | exp_prefix="gcp-test",
63 | mode='gcp',
64 | variant=variant,
65 | # use_gpu=True, # GPUs are much more expensive!
66 | )
67 |
--------------------------------------------------------------------------------
/rlkit/envs/ant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.envs.mujoco_env import MujocoEnv
4 |
5 |
6 | class AntEnv(MujocoEnv):
7 | def __init__(self, use_low_gear_ratio=True):
8 | self.init_serialization(locals())
9 | if use_low_gear_ratio:
10 | xml_path = 'low_gear_ratio_ant.xml'
11 | else:
12 | xml_path = 'normal_gear_ratio_ant.xml'
13 | super().__init__(
14 | xml_path,
15 | frame_skip=5,
16 | automatically_set_obs_and_action_space=True,
17 | )
18 |
19 | def step(self, a):
20 | torso_xyz_before = self.get_body_com("torso")
21 | self.do_simulation(a, self.frame_skip)
22 | torso_xyz_after = self.get_body_com("torso")
23 | torso_velocity = torso_xyz_after - torso_xyz_before
24 | forward_reward = torso_velocity[0]/self.dt
25 | ctrl_cost = .5 * np.square(a).sum()
26 | contact_cost = 0.5 * 1e-3 * np.sum(
27 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
28 | survive_reward = 1.0
29 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
30 | state = self.state_vector()
31 | notdone = np.isfinite(state).all() \
32 | and state[2] >= 0.2 and state[2] <= 1.0
33 | done = not notdone
34 | ob = self._get_obs()
35 | return ob, reward, done, dict(
36 | reward_forward=forward_reward,
37 | reward_ctrl=-ctrl_cost,
38 | reward_contact=-contact_cost,
39 | reward_survive=survive_reward,
40 | torso_velocity=torso_velocity,
41 | )
42 |
43 | def _get_obs(self):
44 | return np.concatenate([
45 | self.sim.data.qpos.flat[2:],
46 | self.sim.data.qvel.flat,
47 | ])
48 |
49 | def reset_model(self):
50 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
51 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
52 | self.set_state(qpos, qvel)
53 | return self._get_obs()
54 |
55 | def viewer_setup(self):
56 | self.viewer.cam.distance = self.model.stat.extent * 0.5
57 |
--------------------------------------------------------------------------------
/rlkit/util/ml_util.py:
--------------------------------------------------------------------------------
1 | """
2 | General utility functions for machine learning.
3 | """
4 | import abc
5 | import math
6 | import numpy as np
7 |
8 |
9 | class ScalarSchedule(object, metaclass=abc.ABCMeta):
10 | @abc.abstractmethod
11 | def get_value(self, t):
12 | pass
13 |
14 |
15 | class ConstantSchedule(ScalarSchedule):
16 | def __init__(self, value):
17 | self._value = value
18 |
19 | def get_value(self, t):
20 | return self._value
21 |
22 |
23 | class LinearSchedule(ScalarSchedule):
24 | """
25 | Linearly interpolate and then stop at a final value.
26 | """
27 | def __init__(
28 | self,
29 | init_value,
30 | final_value,
31 | ramp_duration,
32 | ):
33 | self._init_value = init_value
34 | self._final_value = final_value
35 | self._ramp_duration = ramp_duration
36 |
37 | def get_value(self, t):
38 | return (
39 | self._init_value
40 | + (self._final_value - self._init_value)
41 | * min(1.0, t * 1.0 / self._ramp_duration)
42 | )
43 |
44 |
45 | class IntLinearSchedule(LinearSchedule):
46 | """
47 |     Same as LinearSchedule, but rounds the output to an int.
48 | """
49 | def get_value(self, t):
50 | return int(super().get_value(t))
51 |
52 |
53 | class PiecewiseLinearSchedule(ScalarSchedule):
54 | """
55 |     Given lists of x values (times) and y values, return the value at time t
56 |     by linearly interpolating between the given points.
57 | """
58 | def __init__(
59 | self,
60 | x_values,
61 | y_values,
62 | ):
63 | self._x_values = x_values
64 | self._y_values = y_values
65 |
66 | def get_value(self, t):
67 | return np.interp(t, self._x_values, self._y_values)
68 |
69 |
70 | class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule):
71 | def get_value(self, t):
72 | return int(super().get_value(t))
73 |
74 |
75 | def none_to_infty(bounds):
76 | if bounds is None:
77 | bounds = -math.inf, math.inf
78 | lb, ub = bounds
79 | if lb is None:
80 | lb = -math.inf
81 | if ub is None:
82 | ub = math.inf
83 | return lb, ub
84 |
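85 | # Example usage (illustrative):
86 | #   schedule = LinearSchedule(init_value=1.0, final_value=0.1, ramp_duration=1000)
87 | #   schedule.get_value(0)     # -> 1.0
88 | #   schedule.get_value(500)   # -> 0.55 (halfway through the ramp)
89 | #   schedule.get_value(2000)  # -> 0.1 (clamped at the final value)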
--------------------------------------------------------------------------------
/scripts/run_goal_conditioned_policy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | from rlkit.core import logger
5 | from rlkit.samplers.rollout_functions import multitask_rollout
6 | from rlkit.torch import pytorch_util as ptu
7 | from rlkit.envs.vae_wrapper import VAEWrappedEnv
8 |
9 |
10 | def simulate_policy(args):
11 | data = pickle.load(open(args.file, "rb"))
12 | policy = data['evaluation/policy']
13 | env = data['evaluation/env']
14 | print("Policy and environment loaded")
15 | if args.gpu:
16 | ptu.set_gpu_mode(True)
17 | policy.to(ptu.device)
18 | if isinstance(env, VAEWrappedEnv) and hasattr(env, 'mode'):
19 | env.mode(args.mode)
20 | if args.enable_render or hasattr(env, 'enable_render'):
21 | # some environments need to be reconfigured for visualization
22 | env.enable_render()
23 | paths = []
24 | while True:
25 | paths.append(multitask_rollout(
26 | env,
27 | policy,
28 | max_path_length=args.H,
29 | render=not args.hide,
30 | observation_key='observation',
31 | desired_goal_key='desired_goal',
32 | ))
33 | if hasattr(env, "log_diagnostics"):
34 | env.log_diagnostics(paths)
35 | if hasattr(env, "get_diagnostics"):
36 | for k, v in env.get_diagnostics(paths).items():
37 | logger.record_tabular(k, v)
38 | logger.dump_tabular()
39 |
40 |
41 | if __name__ == "__main__":
42 |
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('file', type=str,
45 | help='path to the snapshot file')
46 | parser.add_argument('--H', type=int, default=300,
47 | help='Max length of rollout')
48 | parser.add_argument('--speedup', type=float, default=10,
49 | help='Speedup')
50 | parser.add_argument('--mode', default='video_env', type=str,
51 | help='env mode')
52 | parser.add_argument('--gpu', action='store_true')
53 | parser.add_argument('--enable_render', action='store_true')
54 | parser.add_argument('--hide', action='store_true')
55 | args = parser.parse_args()
56 |
57 | simulate_policy(args)
58 |
--------------------------------------------------------------------------------
/rlkit/core/serializable.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on rllab's serializable.py file
3 |
4 | https://github.com/rll/rllab
5 | """
6 |
7 | import inspect
8 | import sys
9 |
10 |
11 | class Serializable(object):
12 |
13 | def __init__(self, *args, **kwargs):
14 | self.__args = args
15 | self.__kwargs = kwargs
16 |
17 | def quick_init(self, locals_):
18 | if getattr(self, "_serializable_initialized", False):
19 | return
20 | if sys.version_info >= (3, 0):
21 | spec = inspect.getfullargspec(self.__init__)
22 | # Exclude the first "self" parameter
23 | if spec.varkw:
24 | kwargs = locals_[spec.varkw].copy()
25 | else:
26 | kwargs = dict()
27 | if spec.kwonlyargs:
28 | for key in spec.kwonlyargs:
29 | kwargs[key] = locals_[key]
30 | else:
31 | spec = inspect.getargspec(self.__init__)
32 | if spec.keywords:
33 | kwargs = locals_[spec.keywords]
34 | else:
35 | kwargs = dict()
36 | if spec.varargs:
37 | varargs = locals_[spec.varargs]
38 | else:
39 | varargs = tuple()
40 | in_order_args = [locals_[arg] for arg in spec.args][1:]
41 | self.__args = tuple(in_order_args) + varargs
42 | self.__kwargs = kwargs
43 | setattr(self, "_serializable_initialized", True)
44 |
45 | def __getstate__(self):
46 | return {"__args": self.__args, "__kwargs": self.__kwargs}
47 |
48 | def __setstate__(self, d):
49 | # convert all __args to keyword-based arguments
50 | if sys.version_info >= (3, 0):
51 | spec = inspect.getfullargspec(self.__init__)
52 | else:
53 | spec = inspect.getargspec(self.__init__)
54 | in_order_args = spec.args[1:]
55 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"]))
56 | self.__dict__.update(out.__dict__)
57 |
58 | @classmethod
59 | def clone(cls, obj, **kwargs):
60 | assert isinstance(obj, Serializable)
61 | d = obj.__getstate__()
62 | d["__kwargs"] = dict(d["__kwargs"], **kwargs)
63 | out = type(obj).__new__(type(obj))
64 | out.__setstate__(d)
65 | return out
66 |
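67 |
68 | # Illustrative usage (added for clarity; not part of the original file, and the
69 | # class below is hypothetical). Subclasses call Serializable.quick_init(self,
70 | # locals()) at the top of __init__ so the constructor arguments are recorded and
71 | # the object can be rebuilt when unpickled or cloned.
72 | #
73 | #     class GaussianNoise(Serializable):
74 | #         def __init__(self, mu=0.0, sigma=1.0):
75 | #             Serializable.quick_init(self, locals())
76 | #             self.mu = mu
77 | #             self.sigma = sigma
78 | #
79 | #     noise = GaussianNoise(sigma=0.3)
80 | #     tweaked = Serializable.clone(noise, sigma=0.5)  # rebuilt with sigma overridden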
--------------------------------------------------------------------------------
/docs/HER.md:
--------------------------------------------------------------------------------
1 | # Hindsight Experience Replay
2 | Some notes on the implementation of
3 | [Hindsight Experience Replay](https://arxiv.org/abs/1707.01495).
4 | ## Expected Results
5 | If you run the [Fetch example](../examples/her/her_sac_gym_fetch_reach.py), then
6 | you should get results like this:
7 | 
8 |
9 | If you run the [GridWorld example](../examples/her/her_dqn_gridworld.py), then
10 | you should get results like this:
11 | 
12 |
13 | Note that these examples use HER combined with DQN and SAC, and not DDPG.
14 |
15 | These plots are generated using [viskit](https://github.com/vitchyr/viskit).
16 |
17 | ## Goal-based environments and `ObsDictRelabelingBuffer`
18 | [See here.](goal_based_envs.md)
19 |
20 | ## Implementation Difference
21 | This HER implementation is slightly different from the one presented in the paper.
22 | Rather than relabeling goals when saving data to the replay buffer, the goals
23 | are relabeled when sampling from the replay buffer.
24 |
25 |
26 | In other words, HER in the paper does this:
27 |
28 | Data collection
29 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$.
30 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$.
31 | 3. For $i = 1, \dots, K$:
32 |     1. Sample $g_i$ using the future strategy.
33 |     2. Recompute the reward $r_i = f(s', g_i)$.
34 |     3. Save $(s, a, r_i, s', g_i)$ into replay buffer $\mathcal B$.
35 |
36 | Train time
37 | 1. Sample $(s, a, r, s', g)$ from the replay buffer.
38 | 2. Train the Q function on $(s, a, r, s', g)$.
39 |
40 | The implementation here does:
41 |
42 | Data collection
43 | 1. Sample $(s, a, r, s', g) \sim \text{ENV}$.
44 | 2. Save $(s, a, r, s', g)$ into replay buffer $\mathcal B$.
45 |
46 | Train time
47 | 1. Sample $(s, a, r, s', g)$ from the replay buffer.
48 | 2. With probability $1/(K+1)$: train the Q function on $(s, a, r, s', g)$.
49 | 3. With probability $1 - 1/(K+1)$:
50 |     1. Sample $g'$ using the future strategy.
51 |     2. Recompute the reward $r' = f(s', g')$.
52 |     3. Train the Q function on $(s, a, r', s', g')$.
53 |
54 | Both implementations effectively do the same thing: with probability $1/(K+1)$,
55 | you train the policy on the goal used during the rollout; otherwise, you train
56 | the policy on a resampled goal.
57 |
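58 |
59 | As a concrete (and deliberately simplified) sketch of the relabel-on-sample idea,
60 | the snippet below assumes a buffer that stores whole trajectories with
61 | `next_achieved_goals` aligned to `next_observations`, plus a goal-conditioned
62 | reward function `compute_reward(achieved_goal, goal)`. These names are
63 | illustrative only; the actual implementation lives in `ObsDictRelabelingBuffer`.
64 |
65 | ```python
66 | import numpy as np
67 |
68 | def sample_relabeled_batch(trajectories, batch_size, k, compute_reward):
69 |     """Sample transitions, relabeling goals at sampling time ('future' strategy)."""
70 |     batch = []
71 |     for _ in range(batch_size):
72 |         traj = trajectories[np.random.randint(len(trajectories))]
73 |         t = np.random.randint(len(traj['observations']))
74 |         s = traj['observations'][t]
75 |         a = traj['actions'][t]
76 |         s_next = traj['next_observations'][t]
77 |         g = traj['goals'][t]
78 |         r = traj['rewards'][t]
79 |         # With probability 1 - 1/(K+1), replace the rollout goal with a goal
80 |         # achieved later in the same trajectory and recompute the reward.
81 |         if np.random.rand() > 1.0 / (k + 1):
82 |             future_t = np.random.randint(t, len(traj['observations']))
83 |             g = traj['next_achieved_goals'][future_t]
84 |             r = compute_reward(traj['next_achieved_goals'][t], g)
85 |         batch.append((s, a, r, s_next, g))
86 |     return batch
87 | ```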
--------------------------------------------------------------------------------
/rlkit/torch/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset, Sampler
4 |
5 | # TODO: move this to a more reasonable place
6 | from rlkit.data_management.obs_dict_replay_buffer import normalize_image
7 |
8 |
9 | class ImageDataset(Dataset):
10 |
11 | def __init__(self, images, should_normalize=True):
12 | super().__init__()
13 | self.dataset = images
14 | self.dataset_len = len(self.dataset)
15 | assert should_normalize == (images.dtype == np.uint8)
16 | self.should_normalize = should_normalize
17 |
18 | def __len__(self):
19 | return self.dataset_len
20 |
21 | def __getitem__(self, idxs):
22 | samples = self.dataset[idxs, :]
23 | if self.should_normalize:
24 | samples = normalize_image(samples)
25 | return np.float32(samples)
26 |
27 |
28 | class InfiniteRandomSampler(Sampler):
29 |
30 | def __init__(self, data_source):
31 | self.data_source = data_source
32 | self.iter = iter(torch.randperm(len(self.data_source)).tolist())
33 |
34 | def __iter__(self):
35 | return self
36 |
37 | def __next__(self):
38 | try:
39 | idx = next(self.iter)
40 | except StopIteration:
41 | self.iter = iter(torch.randperm(len(self.data_source)).tolist())
42 | idx = next(self.iter)
43 | return idx
44 |
45 | def __len__(self):
46 | return 2 ** 62
47 |
48 |
49 | class InfiniteWeightedRandomSampler(Sampler):
50 |
51 | def __init__(self, data_source, weights):
52 | assert len(data_source) == len(weights)
53 | assert len(weights.shape) == 1
54 | self.data_source = data_source
55 | # Always use CPU
56 | self._weights = torch.from_numpy(weights)
57 | self.iter = self._create_iterator()
58 |
59 | def update_weights(self, weights):
60 | self._weights = weights
61 | self.iter = self._create_iterator()
62 |
63 | def _create_iterator(self):
64 | return iter(
65 | torch.multinomial(
66 | self._weights, len(self._weights), replacement=True
67 | ).tolist()
68 | )
69 |
70 | def __iter__(self):
71 | return self
72 |
73 | def __next__(self):
74 | try:
75 | idx = next(self.iter)
76 | except StopIteration:
77 | self.iter = self._create_iterator()
78 | idx = next(self.iter)
79 | return idx
80 |
81 | def __len__(self):
82 | return 2 ** 62
83 |
--------------------------------------------------------------------------------
/rlkit/torch/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.distributions import Distribution, Normal
3 | import rlkit.torch.pytorch_util as ptu
4 |
5 |
6 | class TanhNormal(Distribution):
7 | """
8 | Represent distribution of X where
9 | X ~ tanh(Z)
10 | Z ~ N(mean, std)
11 |
12 | Note: this is not very numerically stable.
13 | """
14 | def __init__(self, normal_mean, normal_std, epsilon=1e-6):
15 | """
16 | :param normal_mean: Mean of the normal distribution
17 | :param normal_std: Std of the normal distribution
18 | :param epsilon: Numerical stability epsilon when computing log-prob.
19 | """
20 | self.normal_mean = normal_mean
21 | self.normal_std = normal_std
22 | self.normal = Normal(normal_mean, normal_std)
23 | self.epsilon = epsilon
24 |
25 | def sample_n(self, n, return_pre_tanh_value=False):
26 | z = self.normal.sample_n(n)
27 | if return_pre_tanh_value:
28 | return torch.tanh(z), z
29 | else:
30 | return torch.tanh(z)
31 |
32 | def log_prob(self, value, pre_tanh_value=None):
33 | """
34 | Compute the log probability of `value` under this distribution.
35 | :param value: some value, x
36 | :param pre_tanh_value: arctanh(x), if already computed
37 | :return: log probability of `value`
38 | """
39 | if pre_tanh_value is None:
40 | pre_tanh_value = torch.log(
41 | (1+value) / (1-value)
42 | ) / 2
43 | return self.normal.log_prob(pre_tanh_value) - torch.log(
44 | 1 - value * value + self.epsilon
45 | )
46 |
47 | def sample(self, return_pretanh_value=False):
48 | """
49 | Gradients will and should *not* pass through this operation.
50 |
51 | See https://github.com/pytorch/pytorch/issues/4620 for discussion.
52 | """
53 | z = self.normal.sample().detach()
54 |
55 | if return_pretanh_value:
56 | return torch.tanh(z), z
57 | else:
58 | return torch.tanh(z)
59 |
60 | def rsample(self, return_pretanh_value=False):
61 | """
62 | Sampling in the reparameterization case.
63 | """
64 | z = (
65 | self.normal_mean +
66 | self.normal_std *
67 | Normal(
68 | ptu.zeros(self.normal_mean.size()),
69 | ptu.ones(self.normal_std.size())
70 | ).sample()
71 | )
72 | z.requires_grad_()
73 |
74 | if return_pretanh_value:
75 | return torch.tanh(z), z
76 | else:
77 | return torch.tanh(z)
78 |
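79 |
80 | # Note (added for clarity; not part of the original file): log_prob above comes
81 | # from the change of variables x = tanh(z), which gives
82 | #     log p_X(x) = log p_Z(z) - log|d tanh(z)/dz| = log p_Z(z) - log(1 - tanh(z)^2).
83 | # `epsilon` is added inside the final log to avoid log(0) when |x| is close to 1,
84 | # which is also why the class docstring warns about numerical stability.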
--------------------------------------------------------------------------------
/rlkit/torch/data_management/normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import rlkit.torch.pytorch_util as ptu
3 | import numpy as np
4 |
5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer
6 |
7 |
8 | class TorchNormalizer(Normalizer):
9 | """
10 | Update with np array, but de/normalize pytorch Tensors.
11 | """
12 | def normalize(self, v, clip_range=None):
13 | if not self.synchronized:
14 | self.synchronize()
15 | if clip_range is None:
16 | clip_range = self.default_clip_range
17 | mean = ptu.from_numpy(self.mean)
18 | std = ptu.from_numpy(self.std)
19 | if v.dim() == 2:
20 | # Unsqueeze along the batch dimension and use automatic broadcasting
21 | mean = mean.unsqueeze(0)
22 | std = std.unsqueeze(0)
23 | return torch.clamp((v - mean) / std, -clip_range, clip_range)
24 |
25 | def denormalize(self, v):
26 | if not self.synchronized:
27 | self.synchronize()
28 | mean = ptu.from_numpy(self.mean)
29 | std = ptu.from_numpy(self.std)
30 | if v.dim() == 2:
31 | mean = mean.unsqueeze(0)
32 | std = std.unsqueeze(0)
33 | return mean + v * std
34 |
35 |
36 | class TorchFixedNormalizer(FixedNormalizer):
37 | def normalize(self, v, clip_range=None):
38 | if clip_range is None:
39 | clip_range = self.default_clip_range
40 | mean = ptu.from_numpy(self.mean)
41 | std = ptu.from_numpy(self.std)
42 | if v.dim() == 2:
43 | # Unsqueeze along the batch dimension and use automatic broadcasting
44 | mean = mean.unsqueeze(0)
45 | std = std.unsqueeze(0)
46 | return torch.clamp((v - mean) / std, -clip_range, clip_range)
47 |
48 | def normalize_scale(self, v):
49 | """
50 | Only normalize the scale. Do not subtract the mean.
51 | """
52 | std = ptu.from_numpy(self.std)
53 | if v.dim() == 2:
54 | std = std.unsqueeze(0)
55 | return v / std
56 |
57 | def denormalize(self, v):
58 | mean = ptu.from_numpy(self.mean)
59 | std = ptu.from_numpy(self.std)
60 | if v.dim() == 2:
61 | mean = mean.unsqueeze(0)
62 | std = std.unsqueeze(0)
63 | return mean + v * std
64 |
65 | def denormalize_scale(self, v):
66 | """
67 | Only denormalize the scale. Do not add the mean.
68 | """
69 | std = ptu.from_numpy(self.std)
70 | if v.dim() == 2:
71 | std = std.unsqueeze(0)
72 | return v * std
73 |
--------------------------------------------------------------------------------
/scripts/run_experiment_from_doodad.py:
--------------------------------------------------------------------------------
1 | import doodad as dd
2 | import torch.multiprocessing as mp
3 |
4 | from rlkit.launchers.launcher_util import run_experiment_here
5 |
6 | if __name__ == "__main__":
7 | import matplotlib
8 | matplotlib.use('agg')
9 |
10 | mp.set_start_method('forkserver')
11 | args_dict = dd.get_args()
12 | method_call = args_dict['method_call']
13 | run_experiment_kwargs = args_dict['run_experiment_kwargs']
14 | output_dir = args_dict['output_dir']
15 | run_mode = args_dict.get('mode', None)
16 | if run_mode and run_mode in ['slurm_singularity', 'sss']:
17 | import os
18 | run_experiment_kwargs['variant']['slurm-job-id'] = os.environ.get(
19 | 'SLURM_JOB_ID', None
20 | )
21 | if run_mode and (run_mode == 'ec2' or run_mode == 'gcp'):
22 | if run_mode == 'ec2':
23 | try:
24 | import urllib.request
25 | instance_id = urllib.request.urlopen(
26 | 'http://169.254.169.254/latest/meta-data/instance-id'
27 | ).read().decode()
28 | run_experiment_kwargs['variant']['EC2_instance_id'] = instance_id
29 | except Exception as e:
30 | print("Could not get AWS instance ID. Error was...")
31 | print(e)
32 | if run_mode == 'gcp':
33 | try:
34 | import urllib.request
35 | request = urllib.request.Request(
36 | "http://metadata/computeMetadata/v1/instance/name",
37 | )
38 | # See this URL for why we need this header:
39 | # https://cloud.google.com/compute/docs/storing-retrieving-metadata
40 | request.add_header("Metadata-Flavor", "Google")
41 | instance_name = urllib.request.urlopen(request).read().decode()
42 | run_experiment_kwargs['variant']['GCP_instance_name'] = (
43 | instance_name
44 | )
45 | except Exception as e:
46 | print("Could not get GCP instance name. Error was...")
47 | print(e)
48 | # Do this in case base_log_dir was already set
49 | run_experiment_kwargs['base_log_dir'] = output_dir
50 | run_experiment_here(
51 | method_call,
52 | include_exp_prefix_sub_dir=False,
53 | **run_experiment_kwargs
54 | )
55 | else:
56 | run_experiment_here(
57 | method_call,
58 | log_dir=output_dir,
59 | **run_experiment_kwargs
60 | )
61 |
--------------------------------------------------------------------------------
/rlkit/data_management/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class ReplayBuffer(object, metaclass=abc.ABCMeta):
5 | """
6 | A class used to save and replay data.
7 | """
8 |
9 | @abc.abstractmethod
10 | def add_sample(self, observation, action, reward, next_observation,
11 | terminal, **kwargs):
12 | """
13 | Add a transition tuple.
14 | """
15 | pass
16 |
17 | @abc.abstractmethod
18 | def terminate_episode(self):
19 | """
20 | Let the replay buffer know that the episode has terminated in case some
21 | special book-keeping has to happen.
22 | :return:
23 | """
24 | pass
25 |
26 | @abc.abstractmethod
27 | def num_steps_can_sample(self, **kwargs):
28 | """
29 | :return: # of unique items that can be sampled.
30 | """
31 | pass
32 |
33 | def add_path(self, path):
34 | """
35 | Add a path to the replay buffer.
36 |
37 | This default implementation naively goes through every step, but you
38 | may want to optimize this.
39 |
40 | NOTE: You should NOT call "terminate_episode" after calling add_path.
41 | It's assumed that this function handles the episode termination.
42 |
43 | :param path: Dict like the one returned by rlkit.samplers.util.rollout
44 | """
45 | for i, (
46 | obs,
47 | action,
48 | reward,
49 | next_obs,
50 | terminal,
51 | agent_info,
52 | env_info
53 | ) in enumerate(zip(
54 | path["observations"],
55 | path["actions"],
56 | path["rewards"],
57 | path["next_observations"],
58 | path["terminals"],
59 | path["agent_infos"],
60 | path["env_infos"],
61 | )):
62 | self.add_sample(
63 | observation=obs,
64 | action=action,
65 | reward=reward,
66 | next_observation=next_obs,
67 | terminal=terminal,
68 | agent_info=agent_info,
69 | env_info=env_info,
70 | )
71 | self.terminate_episode()
72 |
73 | def add_paths(self, paths):
74 | for path in paths:
75 | self.add_path(path)
76 |
77 | @abc.abstractmethod
78 | def random_batch(self, batch_size):
79 | """
80 | Return a batch of size `batch_size`.
81 | :param batch_size:
82 | :return:
83 | """
84 | pass
85 |
86 | def get_diagnostics(self):
87 | return {}
88 |
89 | def get_snapshot(self):
90 | return {}
91 |
92 | def end_epoch(self, epoch):
93 | return
94 |
95 |
--------------------------------------------------------------------------------
/rlkit/util/io.py:
--------------------------------------------------------------------------------
1 | import joblib
2 | import numpy as np
3 | import pickle
4 |
5 | import boto3
6 |
7 | from rlkit.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH
8 | import os
9 |
10 | PICKLE = 'pickle'
11 | NUMPY = 'numpy'
12 | JOBLIB = 'joblib'
13 |
14 |
15 | def local_path_from_s3_or_local_path(filename):
16 | relative_filename = os.path.join(LOCAL_LOG_DIR, filename)
17 | if os.path.isfile(filename):
18 | return filename
19 | elif os.path.isfile(relative_filename):
20 | return relative_filename
21 | else:
22 | return sync_down(filename)
23 |
24 |
25 | def sync_down(path, check_exists=True):
26 | is_docker = os.path.isfile("/.dockerenv")
27 | if is_docker:
28 | local_path = "/tmp/%s" % (path)
29 | else:
30 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path)
31 |
32 | if check_exists and os.path.isfile(local_path):
33 | return local_path
34 |
35 | local_dir = os.path.dirname(local_path)
36 | os.makedirs(local_dir, exist_ok=True)
37 |
38 | if is_docker:
39 | from doodad.ec2.autoconfig import AUTOCONFIG
40 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key()
41 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret()
42 |
43 | full_s3_path = os.path.join(AWS_S3_PATH, path)
44 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path)
45 | try:
46 | bucket = boto3.resource('s3').Bucket(bucket_name)
47 | bucket.download_file(bucket_relative_path, local_path)
48 | except Exception as e:
49 | local_path = None
50 | print("Failed to sync! path: ", path)
51 | print("Exception: ", e)
52 | return local_path
53 |
54 |
55 | def split_s3_full_path(s3_path):
56 | """
57 | Split "s3://foo/bar/baz" into "foo" and "bar/baz"
58 | """
59 | bucket_name_and_directories = s3_path.split('//')[1]
60 | bucket_name, *directories = bucket_name_and_directories.split('/')
61 | directory_path = '/'.join(directories)
62 | return bucket_name, directory_path
63 |
64 |
65 | def load_local_or_remote_file(filepath, file_type=None):
66 | local_path = local_path_from_s3_or_local_path(filepath)
67 | if file_type is None:
68 | extension = local_path.split('.')[-1]
69 | if extension == 'npy':
70 | file_type = NUMPY
71 | else:
72 | file_type = PICKLE
73 | # If file_type was passed explicitly (e.g. JOBLIB), respect it rather than
74 | # overriding it with PICKLE.
75 | if file_type == NUMPY:
76 | object = np.load(open(local_path, "rb"))
77 | elif file_type == JOBLIB:
78 | object = joblib.load(local_path)
79 | else:
80 | object = pickle.load(open(local_path, "rb"))
81 | print("loaded", local_path)
82 | return object
83 |
84 |
85 | if __name__ == "__main__":
86 | p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl")
87 | print("got", p)
--------------------------------------------------------------------------------
/examples/dqn_and_double_dqn.py:
--------------------------------------------------------------------------------
1 | """
2 | Run DQN on CartPole-v0.
3 | """
4 |
5 | import gym
6 | from torch import nn as nn
7 |
8 | from rlkit.exploration_strategies.base import \
9 | PolicyWrappedWithExplorationStrategy
10 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
11 | from rlkit.policies.argmax import ArgmaxDiscretePolicy
12 | from rlkit.torch.dqn.dqn import DQNTrainer
13 | from rlkit.torch.networks import Mlp
14 | import rlkit.torch.pytorch_util as ptu
15 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
16 | from rlkit.launchers.launcher_util import setup_logger
17 | from rlkit.samplers.data_collector import MdpPathCollector
18 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
19 |
20 |
21 | def experiment(variant):
22 | expl_env = gym.make('CartPole-v0')
23 | eval_env = gym.make('CartPole-v0')
24 | obs_dim = expl_env.observation_space.low.size
25 | action_dim = eval_env.action_space.n
26 |
27 | qf = Mlp(
28 | hidden_sizes=[32, 32],
29 | input_size=obs_dim,
30 | output_size=action_dim,
31 | )
32 | target_qf = Mlp(
33 | hidden_sizes=[32, 32],
34 | input_size=obs_dim,
35 | output_size=action_dim,
36 | )
37 | qf_criterion = nn.MSELoss()
38 | eval_policy = ArgmaxDiscretePolicy(qf)
39 | expl_policy = PolicyWrappedWithExplorationStrategy(
40 | EpsilonGreedy(expl_env.action_space),
41 | eval_policy,
42 | )
43 | eval_path_collector = MdpPathCollector(
44 | eval_env,
45 | eval_policy,
46 | )
47 | expl_path_collector = MdpPathCollector(
48 | expl_env,
49 | expl_policy,
50 | )
51 | trainer = DQNTrainer(
52 | qf=qf,
53 | target_qf=target_qf,
54 | qf_criterion=qf_criterion,
55 | **variant['trainer_kwargs']
56 | )
57 | replay_buffer = EnvReplayBuffer(
58 | variant['replay_buffer_size'],
59 | expl_env,
60 | )
61 | algorithm = TorchBatchRLAlgorithm(
62 | trainer=trainer,
63 | exploration_env=expl_env,
64 | evaluation_env=eval_env,
65 | exploration_data_collector=expl_path_collector,
66 | evaluation_data_collector=eval_path_collector,
67 | replay_buffer=replay_buffer,
68 | **variant['algorithm_kwargs']
69 | )
70 | algorithm.to(ptu.device)
71 | algorithm.train()
72 |
73 |
74 |
75 |
76 | if __name__ == "__main__":
77 | # noinspection PyTypeChecker
78 | variant = dict(
79 | algorithm="DQN",
80 | version="normal",
81 | layer_size=256,
82 | replay_buffer_size=int(1E6),
83 | algorithm_kwargs=dict(
84 | num_epochs=3000,
85 | num_eval_steps_per_epoch=5000,
86 | num_trains_per_train_loop=1000,
87 | num_expl_steps_per_train_loop=1000,
88 | min_num_steps_before_training=1000,
89 | max_path_length=1000,
90 | batch_size=256,
91 | ),
92 | trainer_kwargs=dict(
93 | discount=0.99,
94 | learning_rate=3E-4,
95 | ),
96 | )
97 | setup_logger('name-of-experiment', variant=variant)
98 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False)
99 | experiment(variant)
100 |
--------------------------------------------------------------------------------
/rlkit/core/batch_rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import gtimer as gt
4 | from rlkit.core.rl_algorithm import BaseRLAlgorithm
5 | from rlkit.data_management.replay_buffer import ReplayBuffer
6 | from rlkit.samplers.data_collector import PathCollector
7 |
8 |
9 | class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
10 | def __init__(
11 | self,
12 | trainer,
13 | exploration_env,
14 | evaluation_env,
15 | exploration_data_collector: PathCollector,
16 | evaluation_data_collector: PathCollector,
17 | replay_buffer: ReplayBuffer,
18 | batch_size,
19 | max_path_length,
20 | num_epochs,
21 | num_eval_steps_per_epoch,
22 | num_expl_steps_per_train_loop,
23 | num_trains_per_train_loop,
24 | num_train_loops_per_epoch=1,
25 | min_num_steps_before_training=0,
26 | ):
27 | super().__init__(
28 | trainer,
29 | exploration_env,
30 | evaluation_env,
31 | exploration_data_collector,
32 | evaluation_data_collector,
33 | replay_buffer,
34 | )
35 | self.batch_size = batch_size
36 | self.max_path_length = max_path_length
37 | self.num_epochs = num_epochs
38 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
39 | self.num_trains_per_train_loop = num_trains_per_train_loop
40 | self.num_train_loops_per_epoch = num_train_loops_per_epoch
41 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
42 | self.min_num_steps_before_training = min_num_steps_before_training
43 |
44 | def _train(self):
45 | if self.min_num_steps_before_training > 0:
46 | init_expl_paths = self.expl_data_collector.collect_new_paths(
47 | self.max_path_length,
48 | self.min_num_steps_before_training,
49 | discard_incomplete_paths=False,
50 | )
51 | self.replay_buffer.add_paths(init_expl_paths)
52 | self.expl_data_collector.end_epoch(-1)
53 |
54 | for epoch in gt.timed_for(
55 | range(self._start_epoch, self.num_epochs),
56 | save_itrs=True,
57 | ):
58 | self.eval_data_collector.collect_new_paths(
59 | self.max_path_length,
60 | self.num_eval_steps_per_epoch,
61 | discard_incomplete_paths=True,
62 | )
63 | gt.stamp('evaluation sampling')
64 |
65 | for _ in range(self.num_train_loops_per_epoch):
66 | new_expl_paths = self.expl_data_collector.collect_new_paths(
67 | self.max_path_length,
68 | self.num_expl_steps_per_train_loop,
69 | discard_incomplete_paths=False,
70 | )
71 | gt.stamp('exploration sampling', unique=False)
72 |
73 | self.replay_buffer.add_paths(new_expl_paths)
74 | gt.stamp('data storing', unique=False)
75 |
76 | self.training_mode(True)
77 | for _ in range(self.num_trains_per_train_loop):
78 | train_data = self.replay_buffer.random_batch(
79 | self.batch_size)
80 | self.trainer.train(train_data)
81 | gt.stamp('training', unique=False)
82 | self.training_mode(False)
83 |
84 | self._end_epoch(epoch)
85 |
--------------------------------------------------------------------------------
/rlkit/torch/dqn/dqn.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | import torch
5 | import torch.optim as optim
6 | from torch import nn as nn
7 |
8 | import rlkit.torch.pytorch_util as ptu
9 | from rlkit.core.eval_util import create_stats_ordered_dict
10 | from rlkit.torch.torch_rl_algorithm import TorchTrainer
11 |
12 |
13 | class DQNTrainer(TorchTrainer):
14 | def __init__(
15 | self,
16 | qf,
17 | target_qf,
18 | learning_rate=1e-3,
19 | soft_target_tau=1e-3,
20 | target_update_period=1,
21 | qf_criterion=None,
22 |
23 | discount=0.99,
24 | reward_scale=1.0,
25 | ):
26 | super().__init__()
27 | self.qf = qf
28 | self.target_qf = target_qf
29 | self.learning_rate = learning_rate
30 | self.soft_target_tau = soft_target_tau
31 | self.target_update_period = target_update_period
32 | self.qf_optimizer = optim.Adam(
33 | self.qf.parameters(),
34 | lr=self.learning_rate,
35 | )
36 | self.discount = discount
37 | self.reward_scale = reward_scale
38 | self.qf_criterion = qf_criterion or nn.MSELoss()
39 | self.eval_statistics = OrderedDict()
40 | self._n_train_steps_total = 0
41 | self._need_to_update_eval_statistics = True
42 |
43 | def train_from_torch(self, batch):
44 | rewards = batch['rewards'] * self.reward_scale
45 | terminals = batch['terminals']
46 | obs = batch['observations']
47 | actions = batch['actions']
48 | next_obs = batch['next_observations']
49 |
50 | """
51 | Compute loss
52 | """
53 |
54 | target_q_values = self.target_qf(next_obs).detach().max(
55 | 1, keepdim=True
56 | )[0]
57 | y_target = rewards + (1. - terminals) * self.discount * target_q_values
58 | y_target = y_target.detach()
59 | # actions is a one-hot vector
60 | y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
61 | qf_loss = self.qf_criterion(y_pred, y_target)
62 |
63 | """
64 | Update networks
65 | """
66 | self.qf_optimizer.zero_grad()
67 | qf_loss.backward()
68 | self.qf_optimizer.step()
69 |
70 | """
71 | Soft target network updates
72 | """
73 | if self._n_train_steps_total % self.target_update_period == 0:
74 | ptu.soft_update_from_to(
75 | self.qf, self.target_qf, self.soft_target_tau
76 | )
77 |
78 | """
79 | Save some statistics for eval using just one batch.
80 | """
81 | if self._need_to_update_eval_statistics:
82 | self._need_to_update_eval_statistics = False
83 | self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
84 | self.eval_statistics.update(create_stats_ordered_dict(
85 | 'Y Predictions',
86 | ptu.get_numpy(y_pred),
87 | ))
88 |
89 | def get_diagnostics(self):
90 | return self.eval_statistics
91 |
92 | def end_epoch(self, epoch):
93 | self._need_to_update_eval_statistics = True
94 |
95 | @property
96 | def networks(self):
97 | return [
98 | self.qf,
99 | self.target_qf,
100 | ]
101 |
102 | def get_snapshot(self):
103 | return dict(
104 | qf=self.qf,
105 | target_qf=self.target_qf,
106 | )
107 |
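108 |
109 | # Note (added for clarity; not part of the original file): train_from_torch
110 | # implements the standard DQN target
111 | #     y = r + (1 - terminal) * discount * max_a' Q_target(s', a'),
112 | # and the target network tracks the online Q network via the soft update
113 | #     theta_target <- (1 - tau) * theta_target + tau * theta
114 | # every `target_update_period` training steps.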
--------------------------------------------------------------------------------
/examples/ddpg.py:
--------------------------------------------------------------------------------
1 | """
2 | Example of running PyTorch implementation of DDPG on HalfCheetah.
3 | """
4 | import copy
5 |
6 | from gym.envs.mujoco import HalfCheetahEnv
7 |
8 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
9 | from rlkit.envs.wrappers import NormalizedBoxEnv
10 | from rlkit.exploration_strategies.base import (
11 | PolicyWrappedWithExplorationStrategy
12 | )
13 | from rlkit.exploration_strategies.ou_strategy import OUStrategy
14 | from rlkit.launchers.launcher_util import setup_logger
15 | from rlkit.samplers.data_collector import MdpPathCollector
16 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy
17 | from rlkit.torch.ddpg.ddpg import DDPGTrainer
18 | import rlkit.torch.pytorch_util as ptu
19 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
20 |
21 |
22 | def experiment(variant):
23 | eval_env = NormalizedBoxEnv(HalfCheetahEnv())
24 | expl_env = NormalizedBoxEnv(HalfCheetahEnv())
25 | # Or for a specific version:
26 | # import gym
27 | # env = NormalizedBoxEnv(gym.make('HalfCheetah-v1'))
28 | obs_dim = eval_env.observation_space.low.size
29 | action_dim = eval_env.action_space.low.size
30 | qf = FlattenMlp(
31 | input_size=obs_dim + action_dim,
32 | output_size=1,
33 | **variant['qf_kwargs']
34 | )
35 | policy = TanhMlpPolicy(
36 | input_size=obs_dim,
37 | output_size=action_dim,
38 | **variant['policy_kwargs']
39 | )
40 | target_qf = copy.deepcopy(qf)
41 | target_policy = copy.deepcopy(policy)
42 | eval_path_collector = MdpPathCollector(eval_env, policy)
43 | exploration_policy = PolicyWrappedWithExplorationStrategy(
44 | exploration_strategy=OUStrategy(action_space=expl_env.action_space),
45 | policy=policy,
46 | )
47 | expl_path_collector = MdpPathCollector(expl_env, exploration_policy)
48 | replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
49 | trainer = DDPGTrainer(
50 | qf=qf,
51 | target_qf=target_qf,
52 | policy=policy,
53 | target_policy=target_policy,
54 | **variant['trainer_kwargs']
55 | )
56 | algorithm = TorchBatchRLAlgorithm(
57 | trainer=trainer,
58 | exploration_env=expl_env,
59 | evaluation_env=eval_env,
60 | exploration_data_collector=expl_path_collector,
61 | evaluation_data_collector=eval_path_collector,
62 | replay_buffer=replay_buffer,
63 | **variant['algorithm_kwargs']
64 | )
65 | algorithm.to(ptu.device)
66 | algorithm.train()
67 |
68 |
69 | if __name__ == "__main__":
70 | # noinspection PyTypeChecker
71 | variant = dict(
72 | algorithm_kwargs=dict(
73 | num_epochs=1000,
74 | num_eval_steps_per_epoch=1000,
75 | num_trains_per_train_loop=1000,
76 | num_expl_steps_per_train_loop=1000,
77 | min_num_steps_before_training=10000,
78 | max_path_length=1000,
79 | batch_size=128,
80 | ),
81 | trainer_kwargs=dict(
82 | use_soft_update=True,
83 | tau=1e-2,
84 | discount=0.99,
85 | qf_learning_rate=1e-3,
86 | policy_learning_rate=1e-4,
87 | ),
88 | qf_kwargs=dict(
89 | hidden_sizes=[400, 300],
90 | ),
91 | policy_kwargs=dict(
92 | hidden_sizes=[400, 300],
93 | ),
94 | replay_buffer_size=int(1E6),
95 | )
96 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False)
97 | setup_logger('name-of-experiment', variant=variant)
98 | experiment(variant)
99 |
--------------------------------------------------------------------------------
/rlkit/torch/pytorch_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def soft_update_from_to(source, target, tau):
6 | for target_param, param in zip(target.parameters(), source.parameters()):
7 | target_param.data.copy_(
8 | target_param.data * (1.0 - tau) + param.data * tau
9 | )
10 |
11 |
12 | def copy_model_params_from_to(source, target):
13 | for target_param, param in zip(target.parameters(), source.parameters()):
14 | target_param.data.copy_(param.data)
15 |
16 |
17 | def fanin_init(tensor):
18 | size = tensor.size()
19 | if len(size) == 2:
20 | fan_in = size[0]
21 | elif len(size) > 2:
22 | fan_in = np.prod(size[1:])
23 | else:
24 | raise Exception("Shape must have at least 2 dimensions.")
25 | bound = 1. / np.sqrt(fan_in)
26 | return tensor.data.uniform_(-bound, bound)
27 |
28 |
29 | def fanin_init_weights_like(tensor):
30 | size = tensor.size()
31 | if len(size) == 2:
32 | fan_in = size[0]
33 | elif len(size) > 2:
34 | fan_in = np.prod(size[1:])
35 | else:
36 | raise Exception("Shape must have at least 2 dimensions.")
37 | bound = 1. / np.sqrt(fan_in)
38 | new_tensor = FloatTensor(tensor.size())
39 | new_tensor.uniform_(-bound, bound)
40 | return new_tensor
41 |
42 |
43 | """
44 | GPU wrappers
45 | """
46 |
47 | _use_gpu = False
48 | device = None
49 | _gpu_id = 0
50 |
51 |
52 | def set_gpu_mode(mode, gpu_id=0):
53 | global _use_gpu
54 | global device
55 | global _gpu_id
56 | _gpu_id = gpu_id
57 | _use_gpu = mode
58 | device = torch.device("cuda:" + str(gpu_id) if _use_gpu else "cpu")
59 |
60 |
61 | def gpu_enabled():
62 | return _use_gpu
63 |
64 |
65 | def set_device(gpu_id):
66 | torch.cuda.set_device(gpu_id)
67 |
68 |
69 | # noinspection PyPep8Naming
70 | def FloatTensor(*args, torch_device=None, **kwargs):
71 | if torch_device is None:
72 | torch_device = device
73 | return torch.FloatTensor(*args, **kwargs, device=torch_device)
74 |
75 |
76 | def from_numpy(*args, **kwargs):
77 | return torch.from_numpy(*args, **kwargs).float().to(device)
78 |
79 |
80 | def get_numpy(tensor):
81 | return tensor.to('cpu').detach().numpy()
82 |
83 |
84 | def zeros(*sizes, torch_device=None, **kwargs):
85 | if torch_device is None:
86 | torch_device = device
87 | return torch.zeros(*sizes, **kwargs, device=torch_device)
88 |
89 |
90 | def ones(*sizes, torch_device=None, **kwargs):
91 | if torch_device is None:
92 | torch_device = device
93 | return torch.ones(*sizes, **kwargs, device=torch_device)
94 |
95 |
96 | def ones_like(*args, torch_device=None, **kwargs):
97 | if torch_device is None:
98 | torch_device = device
99 | return torch.ones_like(*args, **kwargs, device=torch_device)
100 |
101 |
102 | def randn(*args, torch_device=None, **kwargs):
103 | if torch_device is None:
104 | torch_device = device
105 | return torch.randn(*args, **kwargs, device=torch_device)
106 |
107 |
108 | def zeros_like(*args, torch_device=None, **kwargs):
109 | if torch_device is None:
110 | torch_device = device
111 | return torch.zeros_like(*args, **kwargs, device=torch_device)
112 |
113 |
114 | def tensor(*args, torch_device=None, **kwargs):
115 | if torch_device is None:
116 | torch_device = device
117 | return torch.tensor(*args, **kwargs, device=torch_device)
118 |
119 |
120 | def normal(*args, **kwargs):
121 | return torch.normal(*args, **kwargs).to(device)
122 |
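123 |
124 | # Illustrative usage of the GPU wrappers (added for clarity; not part of the
125 | # original file): call set_gpu_mode once at startup, then create tensors through
126 | # these helpers so that everything lands on the globally selected device.
127 | #
128 | #     import numpy as np
129 | #     import rlkit.torch.pytorch_util as ptu
130 | #     ptu.set_gpu_mode(True, gpu_id=0)        # or set_gpu_mode(False) for CPU
131 | #     x = ptu.from_numpy(np.zeros((4, 3)))    # float32 tensor on ptu.device
132 | #     noise = ptu.randn(4, 3)                 # sampled directly on ptu.device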
--------------------------------------------------------------------------------
/rlkit/data_management/simple_replay_buffer.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 |
5 | from rlkit.data_management.replay_buffer import ReplayBuffer
6 |
7 |
8 | class SimpleReplayBuffer(ReplayBuffer):
9 |
10 | def __init__(
11 | self,
12 | max_replay_buffer_size,
13 | observation_dim,
14 | action_dim,
15 | env_info_sizes,
16 | ):
17 | self._observation_dim = observation_dim
18 | self._action_dim = action_dim
19 | self._max_replay_buffer_size = max_replay_buffer_size
20 | self._observations = np.zeros((max_replay_buffer_size, observation_dim))
21 | # It's a bit memory inefficient to save the observations twice,
22 | # but it makes the code *much* easier since you no longer have to
23 | # worry about termination conditions.
24 | self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
25 | self._actions = np.zeros((max_replay_buffer_size, action_dim))
26 | # Make everything a 2D np array to make it easier for other code to
27 | # reason about the shape of the data
28 | self._rewards = np.zeros((max_replay_buffer_size, 1))
29 | # self._terminals[i] = a terminal was received at time i
30 | self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
31 | # Define self._env_infos[key][i] to be the return value of env_info[key]
32 | # at time i
33 | self._env_infos = {}
34 | for key, size in env_info_sizes.items():
35 | self._env_infos[key] = np.zeros((max_replay_buffer_size, size))
36 | self._env_info_keys = env_info_sizes.keys()
37 |
38 | self._top = 0
39 | self._size = 0
40 |
41 | def add_sample(self, observation, action, reward, next_observation,
42 | terminal, env_info, **kwargs):
43 | self._observations[self._top] = observation
44 | self._actions[self._top] = action
45 | self._rewards[self._top] = reward
46 | self._terminals[self._top] = terminal
47 | self._next_obs[self._top] = next_observation
48 |
49 | for key in self._env_info_keys:
50 | self._env_infos[key][self._top] = env_info[key]
51 | self._advance()
52 |
53 | def terminate_episode(self):
54 | pass
55 |
56 | def _advance(self):
57 | self._top = (self._top + 1) % self._max_replay_buffer_size
58 | if self._size < self._max_replay_buffer_size:
59 | self._size += 1
60 |
61 | def random_batch(self, batch_size):
62 | indices = np.random.randint(0, self._size, batch_size)
63 | batch = dict(
64 | observations=self._observations[indices],
65 | actions=self._actions[indices],
66 | rewards=self._rewards[indices],
67 | terminals=self._terminals[indices],
68 | next_observations=self._next_obs[indices],
69 | )
70 | for key in self._env_info_keys:
71 | assert key not in batch.keys()
72 | batch[key] = self._env_infos[key][indices]
73 | return batch
74 |
75 | def rebuild_env_info_dict(self, idx):
76 | return {
77 | key: self._env_infos[key][idx]
78 | for key in self._env_info_keys
79 | }
80 |
81 | def batch_env_info_dict(self, indices):
82 | return {
83 | key: self._env_infos[key][indices]
84 | for key in self._env_info_keys
85 | }
86 |
87 | def num_steps_can_sample(self):
88 | return self._size
89 |
90 | def get_diagnostics(self):
91 | return OrderedDict([
92 | ('size', self._size)
93 | ])
94 |
--------------------------------------------------------------------------------
/examples/sac.py:
--------------------------------------------------------------------------------
1 | from gym.envs.mujoco import HalfCheetahEnv
2 |
3 | import rlkit.torch.pytorch_util as ptu
4 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
5 | from rlkit.envs.wrappers import NormalizedBoxEnv
6 | from rlkit.launchers.launcher_util import setup_logger
7 | from rlkit.samplers.data_collector import MdpPathCollector
8 | from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
9 | from rlkit.torch.sac.sac import SACTrainer
10 | from rlkit.torch.networks import FlattenMlp
11 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
12 |
13 |
14 | def experiment(variant):
15 | expl_env = NormalizedBoxEnv(HalfCheetahEnv())
16 | eval_env = NormalizedBoxEnv(HalfCheetahEnv())
17 | obs_dim = expl_env.observation_space.low.size
18 | action_dim = eval_env.action_space.low.size
19 |
20 | M = variant['layer_size']
21 | qf1 = FlattenMlp(
22 | input_size=obs_dim + action_dim,
23 | output_size=1,
24 | hidden_sizes=[M, M],
25 | )
26 | qf2 = FlattenMlp(
27 | input_size=obs_dim + action_dim,
28 | output_size=1,
29 | hidden_sizes=[M, M],
30 | )
31 | target_qf1 = FlattenMlp(
32 | input_size=obs_dim + action_dim,
33 | output_size=1,
34 | hidden_sizes=[M, M],
35 | )
36 | target_qf2 = FlattenMlp(
37 | input_size=obs_dim + action_dim,
38 | output_size=1,
39 | hidden_sizes=[M, M],
40 | )
41 | policy = TanhGaussianPolicy(
42 | obs_dim=obs_dim,
43 | action_dim=action_dim,
44 | hidden_sizes=[M, M],
45 | )
46 | eval_policy = MakeDeterministic(policy)
47 | eval_path_collector = MdpPathCollector(
48 | eval_env,
49 | eval_policy,
50 | )
51 | expl_path_collector = MdpPathCollector(
52 | expl_env,
53 | policy,
54 | )
55 | replay_buffer = EnvReplayBuffer(
56 | variant['replay_buffer_size'],
57 | expl_env,
58 | )
59 | trainer = SACTrainer(
60 | env=eval_env,
61 | policy=policy,
62 | qf1=qf1,
63 | qf2=qf2,
64 | target_qf1=target_qf1,
65 | target_qf2=target_qf2,
66 | **variant['trainer_kwargs']
67 | )
68 | algorithm = TorchBatchRLAlgorithm(
69 | trainer=trainer,
70 | exploration_env=expl_env,
71 | evaluation_env=eval_env,
72 | exploration_data_collector=expl_path_collector,
73 | evaluation_data_collector=eval_path_collector,
74 | replay_buffer=replay_buffer,
75 | **variant['algorithm_kwargs']
76 | )
77 | algorithm.to(ptu.device)
78 | algorithm.train()
79 |
80 |
81 |
82 |
83 | if __name__ == "__main__":
84 | # noinspection PyTypeChecker
85 | variant = dict(
86 | algorithm="SAC",
87 | version="normal",
88 | layer_size=256,
89 | replay_buffer_size=int(1E6),
90 | algorithm_kwargs=dict(
91 | num_epochs=3000,
92 | num_eval_steps_per_epoch=5000,
93 | num_trains_per_train_loop=1000,
94 | num_expl_steps_per_train_loop=1000,
95 | min_num_steps_before_training=1000,
96 | max_path_length=1000,
97 | batch_size=256,
98 | ),
99 | trainer_kwargs=dict(
100 | discount=0.99,
101 | soft_target_tau=5e-3,
102 | target_update_period=1,
103 | policy_lr=3E-4,
104 | qf_lr=3E-4,
105 | reward_scale=1,
106 | use_automatic_entropy_tuning=True,
107 | ),
108 | )
109 | setup_logger('name-of-experiment', variant=variant)
110 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False)
111 | experiment(variant)
112 |
--------------------------------------------------------------------------------
/environment/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # We need the CUDA base dockerfile to enable GPU rendering
2 | # on hosts with GPUs.
3 | # The image below is a pinned version of nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 (from Jan 2018)
4 | # If updating the base image, be sure to test on GPU since it has broken in the past.
5 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
6 |
7 |
8 | RUN apt-get update -q \
9 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \
10 | curl \
11 | git \
12 | libgl1-mesa-dev \
13 | libgl1-mesa-glx \
14 | libglew-dev \
15 | libosmesa6-dev \
16 | software-properties-common \
17 | net-tools \
18 | unzip \
19 | vim \
20 | virtualenv \
21 | wget \
22 | xpra \
23 | xserver-xorg-dev \
24 | && apt-get clean \
25 | && rm -rf /var/lib/apt/lists/*
26 |
27 | RUN DEBIAN_FRONTEND=noninteractive add-apt-repository --yes ppa:deadsnakes/ppa && apt-get update
28 | RUN DEBIAN_FRONTEND=noninteractive apt-get install --yes python3.5-dev python3.5 python3-pip
29 | RUN virtualenv --python=python3.5 env
30 |
31 | RUN rm /usr/bin/python
32 | RUN ln -s /env/bin/python3.5 /usr/bin/python
33 | RUN ln -s /env/bin/pip3.5 /usr/bin/pip
34 | RUN ln -s /env/bin/pytest /usr/bin/pytest
35 |
36 | RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \
37 | && chmod +x /usr/local/bin/patchelf
38 |
39 | ENV LANG C.UTF-8
40 |
41 | RUN mkdir -p /root/.mujoco \
42 | && wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \
43 | && unzip mujoco.zip -d /root/.mujoco \
44 | && rm mujoco.zip
45 | COPY ./mjkey.txt /root/.mujoco/
46 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH}
47 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
48 |
49 | COPY vendor/Xdummy /usr/local/bin/Xdummy
50 | RUN chmod +x /usr/local/bin/Xdummy
51 |
52 | # Workaround for https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-375/+bug/1674677
53 | COPY ./vendor/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
54 |
55 | RUN apt-get update && apt-get install -y libav-tools
56 |
57 | # For some reason this works despite an error showing up...
58 | RUN DEBIAN_FRONTEND=noninteractive apt-get -qy install nvidia-384; exit 0
59 | ENV LD_LIBRARY_PATH ${LD_LIBRARY_PATH}:/usr/lib/nvidia-384
60 |
61 | RUN mkdir /root/code
62 | WORKDIR /root/code
63 |
64 | WORKDIR /mujoco_py
65 |
66 | # For atari-py
67 | RUN apt-get install -y zlib1g-dev swig cmake
68 |
69 | # Previous versions installed from a requirements.txt, but direct pip
70 | # install seems cleaner
71 | RUN pip install glfw>=1.4.0
72 | RUN pip install numpy>=1.11
73 | RUN pip install Cython>=0.27.2
74 | RUN pip install imageio>=2.1.2
75 | RUN pip install cffi>=1.10
76 | RUN pip install imagehash>=3.4
77 | RUN pip install ipdb
78 | RUN pip install Pillow>=4.0.0
79 | RUN pip install pycparser>=2.17.0
80 | RUN pip install pytest>=3.0.5
81 | RUN pip install pytest-instafail==0.3.0
82 | RUN pip install scipy>=0.18.0
83 | RUN pip install sphinx
84 | RUN pip install sphinx_rtd_theme
85 | RUN pip install numpydoc
86 | RUN pip install cloudpickle==0.5.2
87 | RUN pip install cached-property==1.3.1
88 | RUN pip install gym[all]==0.10.5
89 | RUN pip install gitpython==2.1.7
90 | RUN pip install gtimer==1.0.0b5
91 | RUN pip install awscli==1.11.179
92 | RUN pip install boto3==1.4.8
93 | RUN pip install ray==0.2.2
94 | RUN pip install path.py==10.3.1
95 | RUN pip install http://download.pytorch.org/whl/cu90/torch-0.4.1-cp35-cp35m-linux_x86_64.whl
96 | RUN pip install joblib==0.9.4
97 | RUN pip install opencv-python==3.4.0.12
98 | RUN pip install torchvision==0.2.0
99 | RUN pip install sk-video==1.1.10
100 | RUN pip install git+https://github.com/vitchyr/multiworld.git
101 |
--------------------------------------------------------------------------------
/rlkit/data_management/normalizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on code from Marcin Andrychowicz
3 | """
4 | import numpy as np
5 |
6 |
7 | class Normalizer(object):
8 | def __init__(
9 | self,
10 | size,
11 | eps=1e-8,
12 | default_clip_range=np.inf,
13 | mean=0,
14 | std=1,
15 | ):
16 | self.size = size
17 | self.eps = eps
18 | self.default_clip_range = default_clip_range
19 | self.sum = np.zeros(self.size, np.float32)
20 | self.sumsq = np.zeros(self.size, np.float32)
21 | self.count = np.ones(1, np.float32)
22 | self.mean = mean + np.zeros(self.size, np.float32)
23 | self.std = std * np.ones(self.size, np.float32)
24 | self.synchronized = True
25 |
26 | def update(self, v):
27 | if v.ndim == 1:
28 | v = np.expand_dims(v, 0)
29 | assert v.ndim == 2
30 | assert v.shape[1] == self.size
31 | self.sum += v.sum(axis=0)
32 | self.sumsq += (np.square(v)).sum(axis=0)
33 | self.count[0] += v.shape[0]
34 | self.synchronized = False
35 |
36 | def normalize(self, v, clip_range=None):
37 | if not self.synchronized:
38 | self.synchronize()
39 | if clip_range is None:
40 | clip_range = self.default_clip_range
41 | mean, std = self.mean, self.std
42 | if v.ndim == 2:
43 | mean = mean.reshape(1, -1)
44 | std = std.reshape(1, -1)
45 | return np.clip((v - mean) / std, -clip_range, clip_range)
46 |
47 | def denormalize(self, v):
48 | if not self.synchronized:
49 | self.synchronize()
50 | mean, std = self.mean, self.std
51 | if v.ndim == 2:
52 | mean = mean.reshape(1, -1)
53 | std = std.reshape(1, -1)
54 | return mean + v * std
55 |
56 | def synchronize(self):
57 | self.mean[...] = self.sum / self.count[0]
58 | self.std[...] = np.sqrt(
59 | np.maximum(
60 | np.square(self.eps),
61 | self.sumsq / self.count[0] - np.square(self.mean)
62 | )
63 | )
64 | self.synchronized = True
65 |
66 |
67 | class IdentityNormalizer(object):
68 | def __init__(self, *args, **kwargs):
69 | pass
70 |
71 | def update(self, v):
72 | pass
73 |
74 | def normalize(self, v, clip_range=None):
75 | return v
76 |
77 | def denormalize(self, v):
78 | return v
79 |
80 |
81 | class FixedNormalizer(object):
82 | def __init__(
83 | self,
84 | size,
85 | default_clip_range=np.inf,
86 | mean=0,
87 | std=1,
88 | eps=1e-8,
89 | ):
90 | assert std > 0
91 | std = std + eps
92 | self.size = size
93 | self.default_clip_range = default_clip_range
94 | self.mean = mean + np.zeros(self.size, np.float32)
95 | self.std = std + np.zeros(self.size, np.float32)
96 | self.eps = eps
97 |
98 | def set_mean(self, mean):
99 | self.mean = mean + np.zeros(self.size, np.float32)
100 |
101 | def set_std(self, std):
102 | std = std + self.eps
103 | self.std = std + np.zeros(self.size, np.float32)
104 |
105 | def normalize(self, v, clip_range=None):
106 | if clip_range is None:
107 | clip_range = self.default_clip_range
108 | mean, std = self.mean, self.std
109 | if v.ndim == 2:
110 | mean = mean.reshape(1, -1)
111 | std = std.reshape(1, -1)
112 | return np.clip((v - mean) / std, -clip_range, clip_range)
113 |
114 | def denormalize(self, v):
115 | mean, std = self.mean, self.std
116 | if v.ndim == 2:
117 | mean = mean.reshape(1, -1)
118 | std = std.reshape(1, -1)
119 | return mean + v * std
120 |
121 | def copy_stats(self, other):
122 | self.set_mean(other.mean)
123 | self.set_std(other.std)
124 |
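125 |
126 | # Note (added for clarity; not part of the original file): synchronize() recovers
127 | # the running statistics from the accumulated sums:
128 | #     mean = sum / count
129 | #     std  = sqrt(max(eps^2, sumsq / count - mean^2))
130 | # i.e. the usual E[x^2] - E[x]^2 estimate, clipped below by eps^2 so the std
131 | # never collapses to zero.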
--------------------------------------------------------------------------------
/rlkit/torch/networks.py:
--------------------------------------------------------------------------------
1 | """
2 | General networks for pytorch.
3 |
4 | Algorithm-specific networks should go else-where.
5 | """
6 | import torch
7 | from torch import nn as nn
8 | from torch.nn import functional as F
9 |
10 | from rlkit.policies.base import Policy
11 | from rlkit.torch import pytorch_util as ptu
12 | from rlkit.torch.core import eval_np
13 | from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
14 | from rlkit.torch.modules import LayerNorm
15 |
16 |
17 | def identity(x):
18 | return x
19 |
20 |
21 | class Mlp(nn.Module):
22 | def __init__(
23 | self,
24 | hidden_sizes,
25 | output_size,
26 | input_size,
27 | init_w=3e-3,
28 | hidden_activation=F.relu,
29 | output_activation=identity,
30 | hidden_init=ptu.fanin_init,
31 | b_init_value=0.1,
32 | layer_norm=False,
33 | layer_norm_kwargs=None,
34 | ):
35 | super().__init__()
36 |
37 | if layer_norm_kwargs is None:
38 | layer_norm_kwargs = dict()
39 |
40 | self.input_size = input_size
41 | self.output_size = output_size
42 | self.hidden_activation = hidden_activation
43 | self.output_activation = output_activation
44 | self.layer_norm = layer_norm
45 | self.fcs = []
46 | self.layer_norms = []
47 | in_size = input_size
48 |
49 | for i, next_size in enumerate(hidden_sizes):
50 | fc = nn.Linear(in_size, next_size)
51 | in_size = next_size
52 | hidden_init(fc.weight)
53 | fc.bias.data.fill_(b_init_value)
54 | self.__setattr__("fc{}".format(i), fc)
55 | self.fcs.append(fc)
56 |
57 | if self.layer_norm:
58 | ln = LayerNorm(next_size)
59 | self.__setattr__("layer_norm{}".format(i), ln)
60 | self.layer_norms.append(ln)
61 |
62 | self.last_fc = nn.Linear(in_size, output_size)
63 | self.last_fc.weight.data.uniform_(-init_w, init_w)
64 | self.last_fc.bias.data.uniform_(-init_w, init_w)
65 |
66 | def forward(self, input, return_preactivations=False):
67 | h = input
68 | for i, fc in enumerate(self.fcs):
69 | h = fc(h)
70 | if self.layer_norm and i < len(self.fcs) - 1:
71 | h = self.layer_norms[i](h)
72 | h = self.hidden_activation(h)
73 | preactivation = self.last_fc(h)
74 | output = self.output_activation(preactivation)
75 | if return_preactivations:
76 | return output, preactivation
77 | else:
78 | return output
79 |
80 |
81 | class FlattenMlp(Mlp):
82 | """
83 | Flatten inputs along dimension 1 and then pass through MLP.
84 | """
85 |
86 | def forward(self, *inputs, **kwargs):
87 | flat_inputs = torch.cat(inputs, dim=1)
88 | return super().forward(flat_inputs, **kwargs)
89 |
90 |
91 | class MlpPolicy(Mlp, Policy):
92 | """
93 | A simpler interface for creating policies.
94 | """
95 |
96 | def __init__(
97 | self,
98 | *args,
99 | obs_normalizer: TorchFixedNormalizer = None,
100 | **kwargs
101 | ):
102 | super().__init__(*args, **kwargs)
103 | self.obs_normalizer = obs_normalizer
104 |
105 | def forward(self, obs, **kwargs):
106 | if self.obs_normalizer:
107 | obs = self.obs_normalizer.normalize(obs)
108 | return super().forward(obs, **kwargs)
109 |
110 | def get_action(self, obs_np):
111 | actions = self.get_actions(obs_np[None])
112 | return actions[0, :], {}
113 |
114 | def get_actions(self, obs):
115 | return eval_np(self, obs)
116 |
117 |
118 | class TanhMlpPolicy(MlpPolicy):
119 | """
120 | A helper class since most policies have a tanh output activation.
121 | """
122 | def __init__(self, *args, **kwargs):
123 | super().__init__(*args, output_activation=torch.tanh, **kwargs)
124 |
--------------------------------------------------------------------------------
/rlkit/core/online_rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import gtimer as gt
4 | from rlkit.core.rl_algorithm import BaseRLAlgorithm
5 | from rlkit.data_management.replay_buffer import ReplayBuffer
6 | from rlkit.samplers.data_collector import (
7 | PathCollector,
8 | StepCollector,
9 | )
10 |
11 |
12 | class OnlineRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
13 | def __init__(
14 | self,
15 | trainer,
16 | exploration_env,
17 | evaluation_env,
18 | exploration_data_collector: StepCollector,
19 | evaluation_data_collector: PathCollector,
20 | replay_buffer: ReplayBuffer,
21 | batch_size,
22 | max_path_length,
23 | num_epochs,
24 | num_eval_steps_per_epoch,
25 | num_expl_steps_per_train_loop,
26 | num_trains_per_train_loop,
27 | num_train_loops_per_epoch=1,
28 | min_num_steps_before_training=0,
29 | ):
30 | super().__init__(
31 | trainer,
32 | exploration_env,
33 | evaluation_env,
34 | exploration_data_collector,
35 | evaluation_data_collector,
36 | replay_buffer,
37 | )
38 | self.batch_size = batch_size
39 | self.max_path_length = max_path_length
40 | self.num_epochs = num_epochs
41 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
42 | self.num_trains_per_train_loop = num_trains_per_train_loop
43 | self.num_train_loops_per_epoch = num_train_loops_per_epoch
44 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
45 | self.min_num_steps_before_training = min_num_steps_before_training
46 |
47 | assert self.num_trains_per_train_loop >= self.num_expl_steps_per_train_loop, \
48 | 'Online training presumes num_trains_per_train_loop >= num_expl_steps_per_train_loop'
49 |
50 | def _train(self):
51 | self.training_mode(False)
52 | if self.min_num_steps_before_training > 0:
53 | self.expl_data_collector.collect_new_steps(
54 | self.max_path_length,
55 | self.min_num_steps_before_training,
56 | discard_incomplete_paths=False,
57 | )
58 | init_expl_paths = self.expl_data_collector.get_epoch_paths()
59 | self.replay_buffer.add_paths(init_expl_paths)
60 | self.expl_data_collector.end_epoch(-1)
61 |
62 | gt.stamp('initial exploration', unique=True)
63 |
64 | num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop
65 | for epoch in gt.timed_for(
66 | range(self._start_epoch, self.num_epochs),
67 | save_itrs=True,
68 | ):
69 | self.eval_data_collector.collect_new_paths(
70 | self.max_path_length,
71 | self.num_eval_steps_per_epoch,
72 | discard_incomplete_paths=True,
73 | )
74 | gt.stamp('evaluation sampling')
75 |
76 | for _ in range(self.num_train_loops_per_epoch):
77 | for _ in range(self.num_expl_steps_per_train_loop):
78 | self.expl_data_collector.collect_new_steps(
79 | self.max_path_length,
80 | 1, # num steps
81 | discard_incomplete_paths=False,
82 | )
83 | gt.stamp('exploration sampling', unique=False)
84 |
85 | self.training_mode(True)
86 | for _ in range(num_trains_per_expl_step):
87 | train_data = self.replay_buffer.random_batch(
88 | self.batch_size)
89 | self.trainer.train(train_data)
90 | gt.stamp('training', unique=False)
91 | self.training_mode(False)
92 |
93 | new_expl_paths = self.expl_data_collector.get_epoch_paths()
94 | self.replay_buffer.add_paths(new_expl_paths)
95 | gt.stamp('data storing', unique=False)
96 |
97 | self._end_epoch(epoch)
98 |
--------------------------------------------------------------------------------
/rlkit/samplers/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def rollout(env, agent, max_path_length=np.inf, render=False):
5 | """
6 | The values for the following keys will be 2D arrays, with the first
7 | dimension corresponding to time:
8 | - observations
9 | - actions
10 | - rewards
11 | - next_observations
12 | - terminals
13 |
14 | The next two entries will be lists of dictionaries, with the list index
15 | corresponding to the time step:
16 | - agent_infos
17 | - env_infos
18 |
19 | :param env:
20 | :param agent:
21 | :param max_path_length:
22 | :param render:
23 | :return:
24 | """
25 | observations = []
26 | actions = []
27 | rewards = []
28 | terminals = []
29 | agent_infos = []
30 | env_infos = []
31 | o = env.reset()
32 | next_o = None
33 | path_length = 0
34 | if render:
35 | env.render()
36 | while path_length < max_path_length:
37 | a, agent_info = agent.get_action(o)
38 | next_o, r, d, env_info = env.step(a)
39 | observations.append(o)
40 | rewards.append(r)
41 | terminals.append(d)
42 | actions.append(a)
43 | agent_infos.append(agent_info)
44 | env_infos.append(env_info)
45 | path_length += 1
46 | if d:
47 | break
48 | o = next_o
49 | if render:
50 | env.render()
51 |
52 | actions = np.array(actions)
53 | if len(actions.shape) == 1:
54 | actions = np.expand_dims(actions, 1)
55 | observations = np.array(observations)
56 | if len(observations.shape) == 1:
57 | observations = np.expand_dims(observations, 1)
58 |         next_o = np.array([next_o])
59 | next_observations = np.vstack(
60 | (
61 | observations[1:, :],
62 | np.expand_dims(next_o, 0)
63 | )
64 | )
65 | return dict(
66 | observations=observations,
67 | actions=actions,
68 | rewards=np.array(rewards).reshape(-1, 1),
69 | next_observations=next_observations,
70 | terminals=np.array(terminals).reshape(-1, 1),
71 | agent_infos=agent_infos,
72 | env_infos=env_infos,
73 | )
74 |
75 |
76 | def split_paths(paths):
77 | """
78 |     Stack multiple obs/actions/etc. from different paths.
79 |     :param paths: List of paths, where one path is something returned from
80 |     the rollout function above.
81 | :return: Tuple. Every element will have shape batch_size X DIM, including
82 | the rewards and terminal flags.
83 | """
84 | rewards = [path["rewards"].reshape(-1, 1) for path in paths]
85 | terminals = [path["terminals"].reshape(-1, 1) for path in paths]
86 | actions = [path["actions"] for path in paths]
87 | obs = [path["observations"] for path in paths]
88 | next_obs = [path["next_observations"] for path in paths]
89 | rewards = np.vstack(rewards)
90 | terminals = np.vstack(terminals)
91 | obs = np.vstack(obs)
92 | actions = np.vstack(actions)
93 | next_obs = np.vstack(next_obs)
94 | assert len(rewards.shape) == 2
95 | assert len(terminals.shape) == 2
96 | assert len(obs.shape) == 2
97 | assert len(actions.shape) == 2
98 | assert len(next_obs.shape) == 2
99 | return rewards, terminals, obs, actions, next_obs
100 |
101 |
102 | def split_paths_to_dict(paths):
103 | rewards, terminals, obs, actions, next_obs = split_paths(paths)
104 | return dict(
105 | rewards=rewards,
106 | terminals=terminals,
107 | observations=obs,
108 | actions=actions,
109 | next_observations=next_obs,
110 | )
111 |
112 |
113 | def get_stat_in_paths(paths, dict_name, scalar_name):
114 | if len(paths) == 0:
115 | return np.array([[]])
116 |
117 | if type(paths[0][dict_name]) == dict:
118 | # Support rllab interface
119 | return [path[dict_name][scalar_name] for path in paths]
120 |
121 | return [
122 | [info[scalar_name] for info in path[dict_name]]
123 | for path in paths
124 | ]
--------------------------------------------------------------------------------
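
A minimal sketch of driving `rollout` and `split_paths` from this module. The `RandomWalkEnv` and `RandomAgent` stubs are hypothetical stand-ins that only implement the interface the functions above actually touch (`reset`, `step`, `get_action`); any gym-style env and rlkit policy would do:

```
import numpy as np

from rlkit.samplers.util import rollout, split_paths


class RandomWalkEnv:
    """Hypothetical stub env with the gym-style interface rollout() uses."""
    def reset(self):
        self.state = np.zeros(2)
        return self.state.copy()

    def step(self, action):
        self.state = self.state + action
        reward = -float(np.linalg.norm(self.state))
        done = False
        return self.state.copy(), reward, done, {}


class RandomAgent:
    """Hypothetical stub agent returning (action, agent_info)."""
    def get_action(self, obs):
        return np.random.uniform(-1, 1, size=2), {}


paths = [rollout(RandomWalkEnv(), RandomAgent(), max_path_length=10)
         for _ in range(3)]
rewards, terminals, obs, actions, next_obs = split_paths(paths)
print(obs.shape, actions.shape, rewards.shape)  # (30, 2) (30, 2) (30, 1)
```
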
/rlkit/core/eval_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Common evaluation utilities.
3 | """
4 |
5 | from collections import OrderedDict
6 | from numbers import Number
7 |
8 | import numpy as np
9 |
10 | import rlkit.pythonplusplus as ppp
11 |
12 |
13 | def get_generic_path_information(paths, stat_prefix=''):
14 | """
15 | Get an OrderedDict with a bunch of statistic names and values.
16 | """
17 | statistics = OrderedDict()
18 | returns = [sum(path["rewards"]) for path in paths]
19 |
20 | rewards = np.vstack([path["rewards"] for path in paths])
21 | statistics.update(create_stats_ordered_dict('Rewards', rewards,
22 | stat_prefix=stat_prefix))
23 | statistics.update(create_stats_ordered_dict('Returns', returns,
24 | stat_prefix=stat_prefix))
25 | actions = [path["actions"] for path in paths]
26 | if len(actions[0].shape) == 1:
27 | actions = np.hstack([path["actions"] for path in paths])
28 | else:
29 | actions = np.vstack([path["actions"] for path in paths])
30 | statistics.update(create_stats_ordered_dict(
31 | 'Actions', actions, stat_prefix=stat_prefix
32 | ))
33 | statistics['Num Paths'] = len(paths)
34 | statistics[stat_prefix + 'Average Returns'] = get_average_returns(paths)
35 |
36 | for info_key in ['env_infos', 'agent_infos']:
37 | if info_key in paths[0]:
38 | all_env_infos = [
39 | ppp.list_of_dicts__to__dict_of_lists(p[info_key])
40 | for p in paths
41 | ]
42 | for k in all_env_infos[0].keys():
43 | final_ks = np.array([info[k][-1] for info in all_env_infos])
44 | first_ks = np.array([info[k][0] for info in all_env_infos])
45 | all_ks = np.concatenate([info[k] for info in all_env_infos])
46 | statistics.update(create_stats_ordered_dict(
47 | stat_prefix + k,
48 | final_ks,
49 | stat_prefix='{}/final/'.format(info_key),
50 | ))
51 | statistics.update(create_stats_ordered_dict(
52 | stat_prefix + k,
53 | first_ks,
54 | stat_prefix='{}/initial/'.format(info_key),
55 | ))
56 | statistics.update(create_stats_ordered_dict(
57 | stat_prefix + k,
58 | all_ks,
59 | stat_prefix='{}/'.format(info_key),
60 | ))
61 |
62 | return statistics
63 |
64 |
65 | def get_average_returns(paths):
66 | returns = [sum(path["rewards"]) for path in paths]
67 | return np.mean(returns)
68 |
69 |
70 | def create_stats_ordered_dict(
71 | name,
72 | data,
73 | stat_prefix=None,
74 | always_show_all_stats=True,
75 | exclude_max_min=False,
76 | ):
77 | if stat_prefix is not None:
78 | name = "{}{}".format(stat_prefix, name)
79 | if isinstance(data, Number):
80 | return OrderedDict({name: data})
81 |
82 | if len(data) == 0:
83 | return OrderedDict()
84 |
85 | if isinstance(data, tuple):
86 | ordered_dict = OrderedDict()
87 | for number, d in enumerate(data):
88 | sub_dict = create_stats_ordered_dict(
89 | "{0}_{1}".format(name, number),
90 | d,
91 | )
92 | ordered_dict.update(sub_dict)
93 | return ordered_dict
94 |
95 | if isinstance(data, list):
96 | try:
97 | iter(data[0])
98 | except TypeError:
99 | pass
100 | else:
101 | data = np.concatenate(data)
102 |
103 | if (isinstance(data, np.ndarray) and data.size == 1
104 | and not always_show_all_stats):
105 | return OrderedDict({name: float(data)})
106 |
107 | stats = OrderedDict([
108 | (name + ' Mean', np.mean(data)),
109 | (name + ' Std', np.std(data)),
110 | ])
111 | if not exclude_max_min:
112 | stats[name + ' Max'] = np.max(data)
113 | stats[name + ' Min'] = np.min(data)
114 | return stats
115 |
--------------------------------------------------------------------------------
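
A small sketch of the helpers above on synthetic paths. Only the keys `get_generic_path_information` actually reads (`rewards`, `actions`) are filled in; the numbers are arbitrary:

```
import numpy as np

from rlkit.core.eval_util import (
    create_stats_ordered_dict,
    get_generic_path_information,
)

# Two fake paths with just the keys the helpers above read.
paths = [
    dict(rewards=np.ones((5, 1)), actions=np.random.uniform(-1, 1, (5, 2))),
    dict(rewards=np.zeros((3, 1)), actions=np.random.uniform(-1, 1, (3, 2))),
]

stats = get_generic_path_information(paths, stat_prefix='eval/')
print(stats['eval/Average Returns'])  # (5 + 0) / 2 = 2.5
print(stats['Num Paths'])             # 2

# create_stats_ordered_dict emits Mean/Std/Max/Min entries for a name.
print(create_stats_ordered_dict('Returns', [5.0, 0.0]))
```
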
/examples/her/her_dqn_gridworld.py:
--------------------------------------------------------------------------------
1 | """
2 | This should result in an average return of -20 by the end of training.
3 |
4 | Usually hits -30 around epoch 50.
5 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps.
6 | """
7 | import gym
8 |
9 | import rlkit.torch.pytorch_util as ptu
10 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
11 | from rlkit.exploration_strategies.base import \
12 | PolicyWrappedWithExplorationStrategy
13 | from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
14 | from rlkit.launchers.launcher_util import setup_logger
15 | from rlkit.samplers.data_collector import GoalConditionedPathCollector
16 |
17 | from rlkit.policies.argmax import ArgmaxDiscretePolicy
18 | from rlkit.torch.dqn.dqn import DQNTrainer
19 | from rlkit.torch.her.her import HERTrainer
20 | from rlkit.torch.networks import FlattenMlp
21 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
22 |
23 | try:
24 | import multiworld.envs.gridworlds
25 | except ImportError as e:
26 | print("To run this example, you need to install `multiworld`. See "
27 | "https://github.com/vitchyr/multiworld.")
28 | raise e
29 |
30 |
31 | def experiment(variant):
32 | expl_env = gym.make('GoalGridworld-v0')
33 | eval_env = gym.make('GoalGridworld-v0')
34 |
35 | obs_dim = expl_env.observation_space.spaces['observation'].low.size
36 | goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
37 | action_dim = expl_env.action_space.n
38 | qf = FlattenMlp(
39 | input_size=obs_dim + goal_dim,
40 | output_size=action_dim,
41 | hidden_sizes=[400, 300],
42 | )
43 | target_qf = FlattenMlp(
44 | input_size=obs_dim + goal_dim,
45 | output_size=action_dim,
46 | hidden_sizes=[400, 300],
47 | )
48 | eval_policy = ArgmaxDiscretePolicy(qf)
49 | exploration_strategy = EpsilonGreedy(
50 | action_space=expl_env.action_space,
51 | )
52 | expl_policy = PolicyWrappedWithExplorationStrategy(
53 | exploration_strategy=exploration_strategy,
54 | policy=eval_policy,
55 | )
56 |
57 | replay_buffer = ObsDictRelabelingBuffer(
58 | env=eval_env,
59 | **variant['replay_buffer_kwargs']
60 | )
61 | observation_key = 'observation'
62 | desired_goal_key = 'desired_goal'
63 | eval_path_collector = GoalConditionedPathCollector(
64 | eval_env,
65 | eval_policy,
66 | observation_key=observation_key,
67 | desired_goal_key=desired_goal_key,
68 | )
69 | expl_path_collector = GoalConditionedPathCollector(
70 | expl_env,
71 | expl_policy,
72 | observation_key=observation_key,
73 | desired_goal_key=desired_goal_key,
74 | )
75 | trainer = DQNTrainer(
76 | qf=qf,
77 | target_qf=target_qf,
78 | **variant['trainer_kwargs']
79 | )
80 | trainer = HERTrainer(trainer)
81 | algorithm = TorchBatchRLAlgorithm(
82 | trainer=trainer,
83 | exploration_env=expl_env,
84 | evaluation_env=eval_env,
85 | exploration_data_collector=expl_path_collector,
86 | evaluation_data_collector=eval_path_collector,
87 | replay_buffer=replay_buffer,
88 | **variant['algo_kwargs']
89 | )
90 | algorithm.to(ptu.device)
91 | algorithm.train()
92 |
93 |
94 | if __name__ == "__main__":
95 | variant = dict(
96 | algo_kwargs=dict(
97 | num_epochs=100,
98 | max_path_length=50,
99 | num_eval_steps_per_epoch=1000,
100 | num_expl_steps_per_train_loop=1000,
101 | num_trains_per_train_loop=1000,
102 | min_num_steps_before_training=1000,
103 | batch_size=128,
104 | ),
105 | trainer_kwargs=dict(
106 | discount=0.99,
107 | ),
108 | replay_buffer_kwargs=dict(
109 | max_size=100000,
110 | fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper
111 | fraction_goals_env_goals=0.0,
112 | ),
113 | )
114 | setup_logger('her-dqn-gridworld-experiment', variant=variant)
115 | experiment(variant)
116 |
--------------------------------------------------------------------------------
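
The `fraction_goals_rollout_goals=0.2` comment above refers to the relabeling ratio in the HER paper: keeping one original rollout goal for every k = 4 relabeled goals leaves a fraction 1 / (1 + k) of un-relabeled goals, which is where 0.2 comes from (my reading of the comment, checked below):

```
# k relabeled goals per original goal => 1 / (1 + k) of the goals stay as
# the rollout goals. With k = 4, as cited in the comment above:
k = 4
fraction_goals_rollout_goals = 1.0 / (1 + k)
assert fraction_goals_rollout_goals == 0.2
```
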
/rlkit/util/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import time
4 |
5 | import numpy as np
6 | import scipy.misc
7 | import skvideo.io
8 |
9 | from rlkit.envs.vae_wrapper import VAEWrappedEnv
10 |
11 |
12 | def dump_video(
13 | env,
14 | policy,
15 | filename,
16 | rollout_function,
17 | rows=3,
18 | columns=6,
19 | pad_length=0,
20 | pad_color=255,
21 | do_timer=True,
22 | horizon=100,
23 | dirname_to_save_images=None,
24 | subdirname="rollouts",
25 | imsize=84,
26 | num_channels=3,
27 | ):
28 | frames = []
29 | H = 3 * imsize
30 | W = imsize
31 | N = rows * columns
32 | for i in range(N):
33 | start = time.time()
34 | path = rollout_function(
35 | env,
36 | policy,
37 | max_path_length=horizon,
38 | render=False,
39 | )
40 | is_vae_env = isinstance(env, VAEWrappedEnv)
41 | l = []
42 | for d in path['full_observations']:
43 | if is_vae_env:
44 | recon = np.clip(env._reconstruct_img(d['image_observation']), 0,
45 | 1)
46 | else:
47 | recon = d['image_observation']
48 | l.append(
49 | get_image(
50 | d['image_desired_goal'],
51 | d['image_observation'],
52 | recon,
53 | pad_length=pad_length,
54 | pad_color=pad_color,
55 | imsize=imsize,
56 | )
57 | )
58 | frames += l
59 |
60 | if dirname_to_save_images:
61 | rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i))
62 | os.makedirs(rollout_dir, exist_ok=True)
63 | rollout_frames = frames[-101:]
64 | goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0)
65 | scipy.misc.imsave(rollout_dir + "/goal.png", goal_img)
66 | goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0)
67 | scipy.misc.imsave(rollout_dir + "/z_goal.png", goal_img)
68 | for j in range(0, 101, 1):
69 | img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0)
70 | scipy.misc.imsave(rollout_dir + "/" + str(j) + ".png", img)
71 | if do_timer:
72 | print(i, time.time() - start)
73 |
74 | frames = np.array(frames, dtype=np.uint8)
75 | path_length = frames.size // (
76 | N * (H + 2 * pad_length) * (W + 2 * pad_length) * num_channels
77 | )
78 | frames = np.array(frames, dtype=np.uint8).reshape(
79 | (N, path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels)
80 | )
81 | f1 = []
82 | for k1 in range(columns):
83 | f2 = []
84 | for k2 in range(rows):
85 | k = k1 * rows + k2
86 | f2.append(frames[k:k + 1, :, :, :, :].reshape(
87 | (path_length, H + 2 * pad_length, W + 2 * pad_length,
88 | num_channels)
89 | ))
90 | f1.append(np.concatenate(f2, axis=1))
91 | outputdata = np.concatenate(f1, axis=2)
92 | skvideo.io.vwrite(filename, outputdata)
93 | print("Saved video to ", filename)
94 |
95 |
96 | def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255):
97 | if len(goal.shape) == 1:
98 | goal = goal.reshape(-1, imsize, imsize).transpose()
99 | obs = obs.reshape(-1, imsize, imsize).transpose()
100 | recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose()
101 | img = np.concatenate((goal, obs, recon_obs))
102 | img = np.uint8(255 * img)
103 | if pad_length > 0:
104 | img = add_border(img, pad_length, pad_color)
105 | return img
106 |
107 |
108 | def add_border(img, pad_length, pad_color, imsize=84):
109 | H = 3 * imsize
110 | W = imsize
111 | img = img.reshape((3 * imsize, imsize, -1))
112 | img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]),
113 | dtype=np.uint8) * pad_color
114 | img2[pad_length:-pad_length, pad_length:-pad_length, :] = img
115 | return img2
116 |
--------------------------------------------------------------------------------
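
A quick sketch of the image plumbing in this file: `get_image` reshapes a flattened goal image, observation, and reconstruction into one `(3 * imsize) x imsize` tile, and `add_border` pads it. Random arrays in `[0, 1]` stand in for real camera frames:

```
import numpy as np

from rlkit.util.video import add_border, get_image

imsize = 84
# Flattened RGB images in [0, 1], the layout get_image() reshapes from.
goal = np.random.rand(3 * imsize * imsize)
obs = np.random.rand(3 * imsize * imsize)
recon = np.random.rand(3 * imsize * imsize)

tile = get_image(goal, obs, recon, imsize=imsize, pad_length=1, pad_color=255)
print(tile.shape)  # (3 * 84 + 2, 84 + 2, 3) = (254, 86, 3)

padded = add_border(np.uint8(255 * np.random.rand(3 * imsize, imsize, 3)),
                    pad_length=2, pad_color=0)
print(padded.shape)  # (256, 88, 3)
```
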
/rlkit/torch/skewfit/video_gen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import uuid
5 | from rlkit.envs.vae_wrapper import VAEWrappedEnv
6 |
7 | filename = str(uuid.uuid4())
8 |
9 | import skvideo.io
10 | import numpy as np
11 | import time
12 |
13 | import scipy.misc
14 |
15 | def add_border(img, pad_length, pad_color, imsize=84):
16 | H = 3*imsize
17 | W = imsize
18 | img = img.reshape((3*imsize, imsize, -1))
19 | img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]), dtype=np.uint8) * pad_color
20 | img2[pad_length:-pad_length, pad_length:-pad_length, :] = img
21 | return img2
22 |
23 |
24 | def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255):
25 | if len(goal.shape) == 1:
26 | goal = goal.reshape(-1, imsize, imsize).transpose(2, 1, 0)
27 | obs = obs.reshape(-1, imsize, imsize).transpose(2,1,0)
28 | recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose(2,1,0)
29 | img = np.concatenate((goal, obs, recon_obs))
30 | img = np.uint8(255 * img)
31 | if pad_length > 0:
32 | img = add_border(img, pad_length, pad_color)
33 | return img
34 |
35 |
36 | def dump_video(
37 | env,
38 | policy,
39 | filename,
40 | rollout_function,
41 | rows=3,
42 | columns=6,
43 | pad_length=0,
44 | pad_color=255,
45 | do_timer=True,
46 | horizon=100,
47 | dirname_to_save_images=None,
48 | subdirname="rollouts",
49 | imsize=84,
50 | ):
51 | # num_channels = env.vae.input_channels
52 | num_channels = 1 if env.grayscale else 3
53 | frames = []
54 | H = 3*imsize
55 |     W = imsize
56 | N = rows * columns
57 | for i in range(N):
58 | start = time.time()
59 | path = rollout_function(
60 | env,
61 | policy,
62 | max_path_length=horizon,
63 | render=False,
64 | )
65 | is_vae_env = isinstance(env, VAEWrappedEnv)
66 | l = []
67 | for d in path['full_observations']:
68 | if is_vae_env:
69 | recon = np.clip(env._reconstruct_img(d['image_observation']), 0, 1)
70 | else:
71 | recon = d['image_observation']
72 | l.append(
73 | get_image(
74 | d['image_desired_goal'],
75 | d['image_observation'],
76 | recon,
77 | pad_length=pad_length,
78 | pad_color=pad_color,
79 | imsize=imsize,
80 | )
81 | )
82 | frames += l
83 |
84 | if dirname_to_save_images:
85 | rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i))
86 | os.makedirs(rollout_dir, exist_ok=True)
87 | rollout_frames = frames[-101:]
88 | goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0)
89 | scipy.misc.imsave(rollout_dir+"/goal.png", goal_img)
90 | goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0)
91 | scipy.misc.imsave(rollout_dir+"/z_goal.png", goal_img)
92 | for j in range(0, 101, 1):
93 | img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0)
94 | scipy.misc.imsave(rollout_dir+"/"+str(j)+".png", img)
95 | if do_timer:
96 | print(i, time.time() - start)
97 |
98 | frames = np.array(frames, dtype=np.uint8)
99 | path_length = frames.size // (
100 | N * (H + 2*pad_length) * (W + 2*pad_length) * num_channels
101 | )
102 | frames = np.array(frames, dtype=np.uint8).reshape(
103 | (N, path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels)
104 | )
105 | f1 = []
106 | for k1 in range(columns):
107 | f2 = []
108 | for k2 in range(rows):
109 | k = k1 * rows + k2
110 | f2.append(frames[k:k+1, :, :, :, :].reshape(
111 | (path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels)
112 | ))
113 | f1.append(np.concatenate(f2, axis=1))
114 | outputdata = np.concatenate(f1, axis=2)
115 | skvideo.io.vwrite(filename, outputdata)
116 | print("Saved video to ", filename)
117 |
--------------------------------------------------------------------------------
/examples/td3.py:
--------------------------------------------------------------------------------
1 | """
2 | This should result in an average return of ~3000 by the end of training.
3 |
4 | Usually hits 3000 around epoch 80-100. Within a seed, the performance will be
5 | a bit noisy from one epoch to the next (occasionally dips down to ~2000).
6 |
7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps.
8 | """
9 | from gym.envs.mujoco import HalfCheetahEnv
10 |
11 | import rlkit.torch.pytorch_util as ptu
12 | from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
13 | from rlkit.envs.wrappers import NormalizedBoxEnv
14 | from rlkit.exploration_strategies.base import \
15 | PolicyWrappedWithExplorationStrategy
16 | from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy
17 | from rlkit.launchers.launcher_util import setup_logger
18 | from rlkit.samplers.data_collector import MdpPathCollector
19 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy
20 | from rlkit.torch.td3.td3 import TD3Trainer
21 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
22 |
23 |
24 | def experiment(variant):
25 | expl_env = NormalizedBoxEnv(HalfCheetahEnv())
26 | eval_env = NormalizedBoxEnv(HalfCheetahEnv())
27 | obs_dim = expl_env.observation_space.low.size
28 | action_dim = expl_env.action_space.low.size
29 | qf1 = FlattenMlp(
30 | input_size=obs_dim + action_dim,
31 | output_size=1,
32 | **variant['qf_kwargs']
33 | )
34 | qf2 = FlattenMlp(
35 | input_size=obs_dim + action_dim,
36 | output_size=1,
37 | **variant['qf_kwargs']
38 | )
39 | target_qf1 = FlattenMlp(
40 | input_size=obs_dim + action_dim,
41 | output_size=1,
42 | **variant['qf_kwargs']
43 | )
44 | target_qf2 = FlattenMlp(
45 | input_size=obs_dim + action_dim,
46 | output_size=1,
47 | **variant['qf_kwargs']
48 | )
49 | policy = TanhMlpPolicy(
50 | input_size=obs_dim,
51 | output_size=action_dim,
52 | **variant['policy_kwargs']
53 | )
54 | target_policy = TanhMlpPolicy(
55 | input_size=obs_dim,
56 | output_size=action_dim,
57 | **variant['policy_kwargs']
58 | )
59 | es = GaussianStrategy(
60 | action_space=expl_env.action_space,
61 | max_sigma=0.1,
62 | min_sigma=0.1, # Constant sigma
63 | )
64 | exploration_policy = PolicyWrappedWithExplorationStrategy(
65 | exploration_strategy=es,
66 | policy=policy,
67 | )
68 | eval_path_collector = MdpPathCollector(
69 | eval_env,
70 | policy,
71 | )
72 | expl_path_collector = MdpPathCollector(
73 | expl_env,
74 | exploration_policy,
75 | )
76 | replay_buffer = EnvReplayBuffer(
77 | variant['replay_buffer_size'],
78 | expl_env,
79 | )
80 | trainer = TD3Trainer(
81 | policy=policy,
82 | qf1=qf1,
83 | qf2=qf2,
84 | target_qf1=target_qf1,
85 | target_qf2=target_qf2,
86 | target_policy=target_policy,
87 | **variant['trainer_kwargs']
88 | )
89 | algorithm = TorchBatchRLAlgorithm(
90 | trainer=trainer,
91 | exploration_env=expl_env,
92 | evaluation_env=eval_env,
93 | exploration_data_collector=expl_path_collector,
94 | evaluation_data_collector=eval_path_collector,
95 | replay_buffer=replay_buffer,
96 | **variant['algorithm_kwargs']
97 | )
98 | algorithm.to(ptu.device)
99 | algorithm.train()
100 |
101 |
102 | if __name__ == "__main__":
103 | variant = dict(
104 | algorithm_kwargs=dict(
105 | num_epochs=3000,
106 | num_eval_steps_per_epoch=5000,
107 | num_trains_per_train_loop=1000,
108 | num_expl_steps_per_train_loop=1000,
109 | min_num_steps_before_training=1000,
110 | max_path_length=1000,
111 | batch_size=256,
112 | ),
113 | trainer_kwargs=dict(
114 | discount=0.99,
115 | ),
116 | qf_kwargs=dict(
117 | hidden_sizes=[400, 300],
118 | ),
119 | policy_kwargs=dict(
120 | hidden_sizes=[400, 300],
121 | ),
122 | replay_buffer_size=int(1E6),
123 | )
124 | # ptu.set_gpu_mode(True) # optionally set the GPU (default=False)
125 | setup_logger('rlkit-post-refactor-td3-half-cheetah', variant=variant)
126 | experiment(variant)
127 |
--------------------------------------------------------------------------------
/rlkit/torch/vae/vae_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import abc
4 | from torch.distributions import Normal
5 | from torch.nn import functional as F
6 | from rlkit.torch import pytorch_util as ptu
7 |
8 |
9 | class VAEBase(torch.nn.Module, metaclass=abc.ABCMeta):
10 | def __init__(
11 | self,
12 | representation_size,
13 | ):
14 | super().__init__()
15 | self.representation_size = representation_size
16 |
17 | @abc.abstractmethod
18 | def encode(self, input):
19 | """
20 | :param input:
21 | :return: latent_distribution_params
22 | """
23 | raise NotImplementedError()
24 |
25 | @abc.abstractmethod
26 | def rsample(self, latent_distribution_params):
27 | """
28 |
29 | :param latent_distribution_params:
30 | :return: latents
31 | """
32 | raise NotImplementedError()
33 |
34 | @abc.abstractmethod
35 | def reparameterize(self, latent_distribution_params):
36 | """
37 |
38 | :param latent_distribution_params:
39 | :return: latents
40 | """
41 | raise NotImplementedError()
42 |
43 | @abc.abstractmethod
44 | def decode(self, latents):
45 | """
46 | :param latents:
47 | :return: reconstruction, obs_distribution_params
48 | """
49 | raise NotImplementedError()
50 |
51 | @abc.abstractmethod
52 | def logprob(self, inputs, obs_distribution_params):
53 | """
54 | :param inputs:
55 | :param obs_distribution_params:
56 | :return: log probability of input under decoder
57 | """
58 | raise NotImplementedError()
59 |
60 | @abc.abstractmethod
61 | def kl_divergence(self, latent_distribution_params):
62 | """
63 | :param latent_distribution_params:
64 | :return: kl div between latent_distribution_params and prior on latent space
65 | """
66 | raise NotImplementedError()
67 |
68 | @abc.abstractmethod
69 | def get_encoding_from_latent_distribution_params(self, latent_distribution_params):
70 | """
71 |
72 | :param latent_distribution_params:
73 | :return: get latents from latent distribution params
74 | """
75 | raise NotImplementedError()
76 |
77 | def forward(self, input):
78 | """
79 | :param input:
80 | :return: reconstructed input, obs_distribution_params, latent_distribution_params
81 | """
82 | latent_distribution_params = self.encode(input)
83 | latents = self.reparameterize(latent_distribution_params)
84 | reconstructions, obs_distribution_params = self.decode(latents)
85 | return reconstructions, obs_distribution_params, latent_distribution_params
86 |
87 |
88 | class GaussianLatentVAE(VAEBase):
89 | def __init__(
90 | self,
91 | representation_size,
92 | ):
93 | super().__init__(representation_size)
94 | self.dist_mu = np.zeros(self.representation_size)
95 | self.dist_std = np.ones(self.representation_size)
96 |
97 | def rsample(self, latent_distribution_params):
98 | mu, logvar = latent_distribution_params
99 | stds = (0.5 * logvar).exp()
100 | epsilon = ptu.randn(*mu.size())
101 | latents = epsilon * stds + mu
102 | return latents
103 |
104 | def reparameterize(self, latent_distribution_params):
105 | if self.training:
106 | return self.rsample(latent_distribution_params)
107 | else:
108 | return latent_distribution_params[0]
109 |
110 | def kl_divergence(self, latent_distribution_params):
111 | mu, logvar = latent_distribution_params
112 | return - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
113 |
114 | def get_encoding_from_latent_distribution_params(self, latent_distribution_params):
115 | return latent_distribution_params[0].cpu()
116 |
117 |
118 | def compute_bernoulli_log_prob(x, reconstruction_of_x):
119 | return -1 * F.binary_cross_entropy(
120 | reconstruction_of_x,
121 | x,
122 | reduction='elementwise_mean'
123 | )
124 |
125 |
126 | def compute_gaussian_log_prob(input, dec_mu, dec_var):
127 | decoder_dist = Normal(dec_mu, dec_var.pow(0.5))
128 | log_probs = decoder_dist.log_prob(input)
129 | vals = log_probs.sum(dim=1, keepdim=True)
130 | return vals.mean()
131 |
132 |
--------------------------------------------------------------------------------
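
To make the contract concrete: a `GaussianLatentVAE` subclass only has to supply `encode`, `decode`, and `logprob`; `rsample`, `reparameterize`, `kl_divergence`, and `forward` come from the base classes above. The `TinyMlpVAE` below is a hypothetical minimal example (a single linear encoder/decoder with a Bernoulli likelihood), not rlkit's actual ConvVAE, and it assumes the older PyTorch API that this file's `reduction='elementwise_mean'` implies:

```
import torch
from torch import nn

from rlkit.torch.vae.vae_base import (
    GaussianLatentVAE,
    compute_bernoulli_log_prob,
)


class TinyMlpVAE(GaussianLatentVAE):
    """Hypothetical minimal VAE over flattened inputs in [0, 1]."""
    def __init__(self, input_size, representation_size=4):
        super().__init__(representation_size)
        self.encoder = nn.Linear(input_size, 2 * representation_size)
        self.decoder = nn.Linear(representation_size, input_size)

    def encode(self, input):
        mu, logvar = torch.chunk(self.encoder(input), 2, dim=1)
        return mu, logvar

    def decode(self, latents):
        recon = torch.sigmoid(self.decoder(latents))
        # Bernoulli decoder: the reconstruction doubles as the distribution params.
        return recon, (recon,)

    def logprob(self, inputs, obs_distribution_params):
        return compute_bernoulli_log_prob(inputs, obs_distribution_params[0])


vae = TinyMlpVAE(input_size=16)
x = torch.rand(8, 16)
recon, obs_params, latent_params = vae(x)
loss = -vae.logprob(x, obs_params) + vae.kl_divergence(latent_params)
loss.backward()
```
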
/examples/her/her_sac_gym_fetch_reach.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import rlkit.torch.pytorch_util as ptu
4 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
5 | from rlkit.launchers.launcher_util import setup_logger
6 | from rlkit.samplers.data_collector import GoalConditionedPathCollector
7 | from rlkit.torch.her.her import HERTrainer
8 | from rlkit.torch.networks import FlattenMlp
9 | from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy
10 | from rlkit.torch.sac.sac import SACTrainer
11 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
12 |
13 |
14 | def experiment(variant):
15 | eval_env = gym.make('FetchReach-v1')
16 | expl_env = gym.make('FetchReach-v1')
17 |
18 | observation_key = 'observation'
19 | desired_goal_key = 'desired_goal'
20 |
21 | achieved_goal_key = desired_goal_key.replace("desired", "achieved")
22 | replay_buffer = ObsDictRelabelingBuffer(
23 | env=eval_env,
24 | observation_key=observation_key,
25 | desired_goal_key=desired_goal_key,
26 | achieved_goal_key=achieved_goal_key,
27 | **variant['replay_buffer_kwargs']
28 | )
29 | obs_dim = eval_env.observation_space.spaces['observation'].low.size
30 | action_dim = eval_env.action_space.low.size
31 | goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size
32 | qf1 = FlattenMlp(
33 | input_size=obs_dim + action_dim + goal_dim,
34 | output_size=1,
35 | **variant['qf_kwargs']
36 | )
37 | qf2 = FlattenMlp(
38 | input_size=obs_dim + action_dim + goal_dim,
39 | output_size=1,
40 | **variant['qf_kwargs']
41 | )
42 | target_qf1 = FlattenMlp(
43 | input_size=obs_dim + action_dim + goal_dim,
44 | output_size=1,
45 | **variant['qf_kwargs']
46 | )
47 | target_qf2 = FlattenMlp(
48 | input_size=obs_dim + action_dim + goal_dim,
49 | output_size=1,
50 | **variant['qf_kwargs']
51 | )
52 | policy = TanhGaussianPolicy(
53 | obs_dim=obs_dim + goal_dim,
54 | action_dim=action_dim,
55 | **variant['policy_kwargs']
56 | )
57 | eval_policy = MakeDeterministic(policy)
58 | trainer = SACTrainer(
59 | env=eval_env,
60 | policy=policy,
61 | qf1=qf1,
62 | qf2=qf2,
63 | target_qf1=target_qf1,
64 | target_qf2=target_qf2,
65 | **variant['sac_trainer_kwargs']
66 | )
67 | trainer = HERTrainer(trainer)
68 | eval_path_collector = GoalConditionedPathCollector(
69 | eval_env,
70 | eval_policy,
71 | observation_key=observation_key,
72 | desired_goal_key=desired_goal_key,
73 | )
74 | expl_path_collector = GoalConditionedPathCollector(
75 | expl_env,
76 | policy,
77 | observation_key=observation_key,
78 | desired_goal_key=desired_goal_key,
79 | )
80 | algorithm = TorchBatchRLAlgorithm(
81 | trainer=trainer,
82 | exploration_env=expl_env,
83 | evaluation_env=eval_env,
84 | exploration_data_collector=expl_path_collector,
85 | evaluation_data_collector=eval_path_collector,
86 | replay_buffer=replay_buffer,
87 | **variant['algo_kwargs']
88 | )
89 | algorithm.to(ptu.device)
90 | algorithm.train()
91 |
92 |
93 |
94 | if __name__ == "__main__":
95 | variant = dict(
96 | algorithm='HER-SAC',
97 | version='normal',
98 | algo_kwargs=dict(
99 | batch_size=128,
100 | num_epochs=100,
101 | num_eval_steps_per_epoch=5000,
102 | num_expl_steps_per_train_loop=1000,
103 | num_trains_per_train_loop=1000,
104 | min_num_steps_before_training=1000,
105 | max_path_length=50,
106 | ),
107 | sac_trainer_kwargs=dict(
108 | discount=0.99,
109 | soft_target_tau=5e-3,
110 | target_update_period=1,
111 | policy_lr=3E-4,
112 | qf_lr=3E-4,
113 | reward_scale=1,
114 | use_automatic_entropy_tuning=True,
115 | ),
116 | replay_buffer_kwargs=dict(
117 | max_size=int(1E6),
118 | fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper
119 | fraction_goals_env_goals=0,
120 | ),
121 | qf_kwargs=dict(
122 | hidden_sizes=[400, 300],
123 | ),
124 | policy_kwargs=dict(
125 | hidden_sizes=[400, 300],
126 | ),
127 | )
128 | setup_logger('her-sac-fetch-experiment', variant=variant)
129 | experiment(variant)
130 |
--------------------------------------------------------------------------------
/rlkit/launchers/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | Copy this file to config.py and modify as needed.
3 | """
4 | import os
5 | from os.path import join
6 | import rlkit
7 |
8 | """
9 | `doodad.mount.MountLocal` by default ignores directories called "data"
10 | If you're going to rename this directory and use EC2, then change
11 | `doodad.mount.MountLocal.filter_dir`
12 | """
13 | # The directory of the project, not source
14 | rlkit_project_dir = join(os.path.dirname(rlkit.__file__), os.pardir)
15 | LOCAL_LOG_DIR = join(rlkit_project_dir, 'data')
16 |
17 | """
18 | ********************************************************************************
19 | ********************************************************************************
20 | ********************************************************************************
21 |
22 | You probably don't need to set all of the configurations below this line,
23 | unless you use AWS, GCP, Slurm, and/or Slurm on a remote server. I recommend
24 | ignoring most of these things and only using them on an as-needed basis.
25 |
26 | ********************************************************************************
27 | ********************************************************************************
28 | ********************************************************************************
29 | """
30 |
31 | """
32 | General doodad settings.
33 | """
34 | CODE_DIRS_TO_MOUNT = [
35 | rlkit_project_dir,
36 | # '/home/user/python/module/one', Add more paths as needed
37 | ]
38 |
39 | HOME = os.getenv('HOME') if os.getenv('HOME') is not None else os.getenv("USERPROFILE")
40 |
41 | DIR_AND_MOUNT_POINT_MAPPINGS = [
42 | dict(
43 | local_dir=join(HOME, '.mujoco/'),
44 | mount_point='/root/.mujoco',
45 | ),
46 | ]
47 | RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = (
48 | join(rlkit_project_dir, 'scripts', 'run_experiment_from_doodad.py')
49 | # '/home/user/path/to/rlkit/scripts/run_experiment_from_doodad.py'
50 | )
51 | """
52 | AWS Settings
53 | """
54 | # If not set, default will be chosen by doodad
55 | # AWS_S3_PATH = 's3://bucket/directory'
56 |
57 | # The docker image is looked up on dockerhub.com.
58 | DOODAD_DOCKER_IMAGE = "TODO"
59 | INSTANCE_TYPE = 'c4.large'
60 | SPOT_PRICE = 0.03
61 |
62 | GPU_DOODAD_DOCKER_IMAGE = 'TODO'
63 | GPU_INSTANCE_TYPE = 'g2.2xlarge'
64 | GPU_SPOT_PRICE = 0.5
65 |
66 | # You can use AMI images with the docker images already installed.
67 | REGION_TO_GPU_AWS_IMAGE_ID = {
68 | 'us-west-1': "TODO",
69 | 'us-east-1': "TODO",
70 | }
71 |
72 | REGION_TO_GPU_AWS_AVAIL_ZONE = {
73 | 'us-east-1': "us-east-1b",
74 | }
75 |
76 | # This really shouldn't matter and in theory could be whatever
77 | OUTPUT_DIR_FOR_DOODAD_TARGET = '/tmp/doodad-output/'
78 |
79 |
80 | """
81 | Slurm Settings
82 | """
83 | SINGULARITY_IMAGE = '/home/PATH/TO/IMAGE.img'
84 | # This assumes you saved mujoco to $HOME/.mujoco
85 | SINGULARITY_PRE_CMDS = [
86 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mjpro150/bin'
87 | ]
88 | SLURM_CPU_CONFIG = dict(
89 | account_name='TODO',
90 | partition='savio',
91 | nodes=1,
92 | n_tasks=1,
93 | n_gpus=1,
94 | )
95 | SLURM_GPU_CONFIG = dict(
96 | account_name='TODO',
97 | partition='savio2_1080ti',
98 | nodes=1,
99 | n_tasks=1,
100 | n_gpus=1,
101 | )
102 |
103 |
104 | """
105 | Slurm Script Settings
106 |
107 | These are basically the same settings as above, but for the remote machine
108 | where you will be running the generated script.
109 | """
110 | SSS_CODE_DIRS_TO_MOUNT = [
111 | ]
112 | SSS_DIR_AND_MOUNT_POINT_MAPPINGS = [
113 | dict(
114 | local_dir='/global/home/users/USERNAME/.mujoco',
115 | mount_point='/root/.mujoco',
116 | ),
117 | ]
118 | SSS_LOG_DIR = '/global/scratch/USERNAME/doodad-log'
119 |
120 | SSS_IMAGE = '/global/scratch/USERNAME/TODO.img'
121 | SSS_RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = (
122 | '/global/home/users/USERNAME/path/to/rlkit/scripts'
123 | '/run_experiment_from_doodad.py'
124 | )
125 | SSS_PRE_CMDS = [
126 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/global/home/users/USERNAME'
127 | '/.mujoco/mjpro150/bin'
128 | ]
129 |
130 | """
131 | GCP Settings
132 | """
133 | GCP_IMAGE_NAME = 'TODO'
134 | GCP_GPU_IMAGE_NAME = 'TODO'
135 | GCP_BUCKET_NAME = 'TODO'
136 |
137 | GCP_DEFAULT_KWARGS = dict(
138 | zone='us-west2-c',
139 | instance_type='n1-standard-4',
140 | image_project='TODO',
141 | terminate=True,
142 | preemptible=True,
143 | gpu_kwargs=dict(
144 | gpu_model='nvidia-tesla-p4',
145 | num_gpu=1,
146 | )
147 | )
148 |
149 | try:
150 | from rlkit.launchers.conf_private import *
151 | except ImportError:
152 | print("No personal conf_private.py found.")
153 |
--------------------------------------------------------------------------------
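
The `from rlkit.launchers.conf_private import *` fallback at the bottom means machine-specific values can live in an untracked `conf_private.py` next to this file and override anything defined above. A minimal sketch of such a file; every value below is a placeholder:

```
# conf_private.py -- hypothetical local overrides, kept out of version control.
import os
from os.path import join

# Write experiment logs somewhere outside the repository.
LOCAL_LOG_DIR = join(os.getenv('HOME', '/tmp'), 'rlkit-data')

# Only needed when launching through doodad on EC2.
DOODAD_DOCKER_IMAGE = 'my-dockerhub-user/rlkit:latest'  # placeholder
AWS_S3_PATH = 's3://my-bucket/rlkit'                    # placeholder
```
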
/rlkit/samplers/rollout_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def multitask_rollout(
5 | env,
6 | agent,
7 | max_path_length=np.inf,
8 | render=False,
9 | render_kwargs=None,
10 | observation_key=None,
11 | desired_goal_key=None,
12 | get_action_kwargs=None,
13 | return_dict_obs=False,
14 | ):
15 | if render_kwargs is None:
16 | render_kwargs = {}
17 | if get_action_kwargs is None:
18 | get_action_kwargs = {}
19 | dict_obs = []
20 | dict_next_obs = []
21 | observations = []
22 | actions = []
23 | rewards = []
24 | terminals = []
25 | agent_infos = []
26 | env_infos = []
27 | next_observations = []
28 | path_length = 0
29 | agent.reset()
30 | o = env.reset()
31 | if render:
32 | env.render(**render_kwargs)
33 | goal = o[desired_goal_key]
34 | while path_length < max_path_length:
35 | dict_obs.append(o)
36 | if observation_key:
37 | o = o[observation_key]
38 | new_obs = np.hstack((o, goal))
39 | a, agent_info = agent.get_action(new_obs, **get_action_kwargs)
40 | next_o, r, d, env_info = env.step(a)
41 | if render:
42 | env.render(**render_kwargs)
43 | observations.append(o)
44 | rewards.append(r)
45 | terminals.append(d)
46 | actions.append(a)
47 | next_observations.append(next_o)
48 | dict_next_obs.append(next_o)
49 | agent_infos.append(agent_info)
50 | env_infos.append(env_info)
51 | path_length += 1
52 | if d:
53 | break
54 | o = next_o
55 | actions = np.array(actions)
56 | if len(actions.shape) == 1:
57 | actions = np.expand_dims(actions, 1)
58 | observations = np.array(observations)
59 | next_observations = np.array(next_observations)
60 | if return_dict_obs:
61 | observations = dict_obs
62 | next_observations = dict_next_obs
63 | return dict(
64 | observations=observations,
65 | actions=actions,
66 | rewards=np.array(rewards).reshape(-1, 1),
67 | next_observations=next_observations,
68 | terminals=np.array(terminals).reshape(-1, 1),
69 | agent_infos=agent_infos,
70 | env_infos=env_infos,
71 | goals=np.repeat(goal[None], path_length, 0),
72 | full_observations=dict_obs,
73 | )
74 |
75 |
76 | def rollout(
77 | env,
78 | agent,
79 | max_path_length=np.inf,
80 | render=False,
81 | render_kwargs=None,
82 | ):
83 | """
84 |     The values for the following keys will be 2D arrays, with the
85 | first dimension corresponding to the time dimension.
86 | - observations
87 | - actions
88 | - rewards
89 | - next_observations
90 | - terminals
91 |
92 | The next two elements will be lists of dictionaries, with the index into
93 |     the list being the time index.
94 | - agent_infos
95 | - env_infos
96 | """
97 | if render_kwargs is None:
98 | render_kwargs = {}
99 | observations = []
100 | actions = []
101 | rewards = []
102 | terminals = []
103 | agent_infos = []
104 | env_infos = []
105 | o = env.reset()
106 | agent.reset()
107 | next_o = None
108 | path_length = 0
109 | if render:
110 | env.render(**render_kwargs)
111 | while path_length < max_path_length:
112 | a, agent_info = agent.get_action(o)
113 | next_o, r, d, env_info = env.step(a)
114 | observations.append(o)
115 | rewards.append(r)
116 | terminals.append(d)
117 | actions.append(a)
118 | agent_infos.append(agent_info)
119 | env_infos.append(env_info)
120 | path_length += 1
121 | if d:
122 | break
123 | o = next_o
124 | if render:
125 | env.render(**render_kwargs)
126 |
127 | actions = np.array(actions)
128 | if len(actions.shape) == 1:
129 | actions = np.expand_dims(actions, 1)
130 | observations = np.array(observations)
131 | if len(observations.shape) == 1:
132 | observations = np.expand_dims(observations, 1)
133 |         next_o = np.array([next_o])
134 | next_observations = np.vstack(
135 | (
136 | observations[1:, :],
137 | np.expand_dims(next_o, 0)
138 | )
139 | )
140 | return dict(
141 | observations=observations,
142 | actions=actions,
143 | rewards=np.array(rewards).reshape(-1, 1),
144 | next_observations=next_observations,
145 | terminals=np.array(terminals).reshape(-1, 1),
146 | agent_infos=agent_infos,
147 | env_infos=env_infos,
148 | )
149 |
--------------------------------------------------------------------------------
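
A sketch of calling `multitask_rollout` with a hypothetical dict-observation stub env, showing how `observation_key` and `desired_goal_key` pick out the pieces that get concatenated before `agent.get_action`. The stub classes only implement the calls used above:

```
import numpy as np

from rlkit.samplers.rollout_functions import multitask_rollout


class StubGoalEnv:
    """Hypothetical goal-conditioned env returning dict observations."""
    def reset(self):
        self.goal = np.array([1.0, 1.0])
        self.pos = np.zeros(2)
        return self._obs()

    def _obs(self):
        return dict(observation=self.pos.copy(), desired_goal=self.goal.copy())

    def step(self, action):
        self.pos = self.pos + action
        reward = -float(np.linalg.norm(self.pos - self.goal))
        return self._obs(), reward, False, {}


class StubGoalAgent:
    """Hypothetical agent; reset() and get_action() are all that gets called."""
    def reset(self):
        pass

    def get_action(self, obs_and_goal):
        return np.zeros(2), {}


path = multitask_rollout(
    StubGoalEnv(),
    StubGoalAgent(),
    max_path_length=5,
    observation_key='observation',
    desired_goal_key='desired_goal',
)
print(path['observations'].shape, path['goals'].shape)  # (5, 2) (5, 2)
```
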
/rlkit/torch/sac/policies.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 |
5 | from rlkit.policies.base import ExplorationPolicy, Policy
6 | from rlkit.torch.core import eval_np
7 | from rlkit.torch.distributions import TanhNormal
8 | from rlkit.torch.networks import Mlp
9 |
10 |
11 | LOG_SIG_MAX = 2
12 | LOG_SIG_MIN = -20
13 |
14 |
15 | class TanhGaussianPolicy(Mlp, ExplorationPolicy):
16 | """
17 | Usage:
18 |
19 | ```
20 | policy = TanhGaussianPolicy(...)
21 | action, mean, log_std, _ = policy(obs)
22 | action, mean, log_std, _ = policy(obs, deterministic=True)
23 | action, mean, log_std, log_prob = policy(obs, return_log_prob=True)
24 | ```
25 |
26 | Here, mean and log_std are the mean and log_std of the Gaussian that is
27 | sampled from.
28 |
29 | If deterministic is True, action = tanh(mean).
30 | If return_log_prob is False (default), log_prob = None
31 | This is done because computing the log_prob can be a bit expensive.
32 | """
33 | def __init__(
34 | self,
35 | hidden_sizes,
36 | obs_dim,
37 | action_dim,
38 | std=None,
39 | init_w=1e-3,
40 | **kwargs
41 | ):
42 | super().__init__(
43 | hidden_sizes,
44 | input_size=obs_dim,
45 | output_size=action_dim,
46 | init_w=init_w,
47 | **kwargs
48 | )
49 | self.log_std = None
50 | self.std = std
51 | if std is None:
52 | last_hidden_size = obs_dim
53 | if len(hidden_sizes) > 0:
54 | last_hidden_size = hidden_sizes[-1]
55 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
56 | self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
57 | self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
58 | else:
59 | self.log_std = np.log(std)
60 | assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX
61 |
62 | def get_action(self, obs_np, deterministic=False):
63 | actions = self.get_actions(obs_np[None], deterministic=deterministic)
64 | return actions[0, :], {}
65 |
66 | def get_actions(self, obs_np, deterministic=False):
67 | return eval_np(self, obs_np, deterministic=deterministic)[0]
68 |
69 | def forward(
70 | self,
71 | obs,
72 | reparameterize=True,
73 | deterministic=False,
74 | return_log_prob=False,
75 | ):
76 | """
77 | :param obs: Observation
78 | :param deterministic: If True, do not sample
79 | :param return_log_prob: If True, return a sample and its log probability
80 | """
81 | h = obs
82 | for i, fc in enumerate(self.fcs):
83 | h = self.hidden_activation(fc(h))
84 | mean = self.last_fc(h)
85 | if self.std is None:
86 | log_std = self.last_fc_log_std(h)
87 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
88 | std = torch.exp(log_std)
89 | else:
90 | std = self.std
91 | log_std = self.log_std
92 |
93 | log_prob = None
94 | entropy = None
95 | mean_action_log_prob = None
96 | pre_tanh_value = None
97 | if deterministic:
98 | action = torch.tanh(mean)
99 | else:
100 | tanh_normal = TanhNormal(mean, std)
101 | if return_log_prob:
102 | if reparameterize is True:
103 | action, pre_tanh_value = tanh_normal.rsample(
104 | return_pretanh_value=True
105 | )
106 | else:
107 | action, pre_tanh_value = tanh_normal.sample(
108 | return_pretanh_value=True
109 | )
110 | log_prob = tanh_normal.log_prob(
111 | action,
112 | pre_tanh_value=pre_tanh_value
113 | )
114 | log_prob = log_prob.sum(dim=1, keepdim=True)
115 | else:
116 | if reparameterize is True:
117 | action = tanh_normal.rsample()
118 | else:
119 | action = tanh_normal.sample()
120 |
121 | return (
122 | action, mean, log_std, log_prob, entropy, std,
123 | mean_action_log_prob, pre_tanh_value,
124 | )
125 |
126 |
127 | class MakeDeterministic(Policy):
128 | def __init__(self, stochastic_policy):
129 | self.stochastic_policy = stochastic_policy
130 |
131 | def get_action(self, observation):
132 | return self.stochastic_policy.get_action(observation,
133 | deterministic=True)
134 |
--------------------------------------------------------------------------------
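
A short sketch of the usage pattern from the docstring above: build a `TanhGaussianPolicy`, wrap it in `MakeDeterministic` for evaluation, and query both with a numpy observation. The sizes are arbitrary:

```
import numpy as np
import torch

from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy

obs_dim, action_dim = 6, 2
policy = TanhGaussianPolicy(
    hidden_sizes=[32, 32],
    obs_dim=obs_dim,
    action_dim=action_dim,
)
eval_policy = MakeDeterministic(policy)

obs = np.zeros(obs_dim, dtype=np.float32)
stochastic_action, _ = policy.get_action(obs)
deterministic_action, _ = eval_policy.get_action(obs)
print(stochastic_action.shape, deterministic_action.shape)  # (2,) (2,)

# forward() exposes the full tuple described in the docstring.
action, mean, log_std, log_prob, *_ = policy(
    torch.zeros(1, obs_dim), return_log_prob=True,
)
print(log_prob.shape)  # torch.Size([1, 1])
```
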
/rlkit/envs/goal_generation/pickup_goal_dataset.py:
--------------------------------------------------------------------------------
1 | from multiworld.envs.mujoco.sawyer_xyz.sawyer_pick_and_place import (
2 | get_image_presampled_goals
3 | )
4 | import numpy as np
5 | import cv2
6 | import os.path as osp
7 | import random
8 |
9 | from rlkit.util.io import local_path_from_s3_or_local_path
10 |
11 |
12 | def setup_pickup_image_env(image_env, num_presampled_goals):
13 | """
14 | Image env and pickup env will have presampled goals. VAE wrapper should
15 | encode whatever presampled goal is sampled.
16 | """
17 | presampled_goals = get_image_presampled_goals(image_env,
18 | num_presampled_goals)
19 | image_env._presampled_goals = presampled_goals
20 | image_env.num_goals_presampled = \
21 | presampled_goals[random.choice(list(presampled_goals))].shape[0]
22 |
23 |
24 | def get_image_presampled_goals_from_vae_env(env, num_presampled_goals,
25 | env_id=None):
26 | image_env = env.wrapped_env
27 | return get_image_presampled_goals(image_env, num_presampled_goals)
28 |
29 |
30 | def get_image_presampled_goals_from_image_env(env, num_presampled_goals,
31 | env_id=None):
32 | return get_image_presampled_goals(env, num_presampled_goals)
33 |
34 |
35 | def generate_vae_dataset(variant):
36 | return generate_vae_dataset_from_params(**variant)
37 |
38 |
39 | def generate_vae_dataset_from_params(
40 | env_class=None,
41 | env_kwargs=None,
42 | env_id=None,
43 | N=10000,
44 | test_p=0.9,
45 | use_cached=True,
46 | imsize=84,
47 | num_channels=1,
48 | show=False,
49 | init_camera=None,
50 | dataset_path=None,
51 | oracle_dataset=False,
52 | n_random_steps=100,
53 | vae_dataset_specific_env_kwargs=None,
54 | save_file_prefix=None,
55 | ):
56 | from multiworld.core.image_env import ImageEnv, unormalize_image
57 | import time
58 |
59 | assert oracle_dataset == True
60 |
61 | if env_kwargs is None:
62 | env_kwargs = {}
63 | if save_file_prefix is None:
64 | save_file_prefix = env_id
65 | if save_file_prefix is None:
66 | save_file_prefix = env_class.__name__
67 | filename = "/tmp/{}_N{}_{}_imsize{}_oracle{}.npy".format(
68 | save_file_prefix,
69 | str(N),
70 | init_camera.__name__ if init_camera else '',
71 | imsize,
72 | oracle_dataset,
73 | )
74 | info = {}
75 | if dataset_path is not None:
76 | filename = local_path_from_s3_or_local_path(dataset_path)
77 | dataset = np.load(filename)
78 | np.random.shuffle(dataset)
79 | N = dataset.shape[0]
80 | elif use_cached and osp.isfile(filename):
81 | dataset = np.load(filename)
82 | np.random.shuffle(dataset)
83 | print("loaded data from saved file", filename)
84 | else:
85 | now = time.time()
86 |
87 | if env_id is not None:
88 | import gym
89 | import multiworld
90 | multiworld.register_all_envs()
91 | env = gym.make(env_id)
92 | else:
93 | if vae_dataset_specific_env_kwargs is None:
94 | vae_dataset_specific_env_kwargs = {}
95 | for key, val in env_kwargs.items():
96 | if key not in vae_dataset_specific_env_kwargs:
97 | vae_dataset_specific_env_kwargs[key] = val
98 | env = env_class(**vae_dataset_specific_env_kwargs)
99 | if not isinstance(env, ImageEnv):
100 | env = ImageEnv(
101 | env,
102 | imsize,
103 | init_camera=init_camera,
104 | transpose=True,
105 | normalize=True,
106 | )
107 | setup_pickup_image_env(env, num_presampled_goals=N)
108 | env.reset()
109 | info['env'] = env
110 |
111 | dataset = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8)
112 | for i in range(N):
113 | img = env._presampled_goals['image_desired_goal'][i]
114 | dataset[i, :] = unormalize_image(img)
115 | if show:
116 | img = img.reshape(3, imsize, imsize).transpose()
117 | img = img[::-1, :, ::-1]
118 | cv2.imshow('img', img)
119 | cv2.waitKey(1)
120 | time.sleep(.2)
121 | # radius = input('waiting...')
122 | print("done making training data", filename, time.time() - now)
123 | np.random.shuffle(dataset)
124 | np.save(filename, dataset)
125 |
126 | n = int(N * test_p)
127 | train_dataset = dataset[:n, :]
128 | test_dataset = dataset[n:, :]
129 | return train_dataset, test_dataset, info
130 |
--------------------------------------------------------------------------------
/rlkit/data_management/shared_obs_dict_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
4 |
5 | import torch.multiprocessing as mp
6 | import ctypes
7 |
8 |
9 | class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer):
10 | """
11 | Same as an ObsDictRelabelingBuffer but the obs and next_obs are backed
12 | by multiprocessing arrays. The replay buffer size is also shared. The
13 |     intended use case is when one wants obs/next_obs to be shared between
14 | processes. Accesses are synchronized internally by locks (mp takes care
15 | of that). Technically, putting such large arrays in shared memory/requiring
16 | synchronized access can be extremely slow, but it seems ok empirically.
17 |
18 | This code also breaks a lot of functionality for the subprocess. For example,
19 | random_batch is incorrect as actions and _idx_to_future_obs_idx are not
20 | shared. If the subprocess needs all of the functionality, a mp.Array
21 | must be used for all numpy arrays in the replay buffer.
22 |
23 | """
24 |
25 | def __init__(
26 | self,
27 | *args,
28 | **kwargs
29 | ):
30 | self._shared_size = mp.Value(ctypes.c_long, 0)
31 | ObsDictRelabelingBuffer.__init__(self, *args, **kwargs)
32 |
33 | self._mp_array_info = {}
34 | self._shared_obs_info = {}
35 | self._shared_next_obs_info = {}
36 |
37 | for obs_key, obs_arr in self._obs.items():
38 | ctype = ctypes.c_double
39 | if obs_arr.dtype == np.uint8:
40 | ctype = ctypes.c_uint8
41 |
42 | self._shared_obs_info[obs_key] = (
43 | mp.Array(ctype, obs_arr.size),
44 | obs_arr.dtype,
45 | obs_arr.shape,
46 | )
47 | self._shared_next_obs_info[obs_key] = (
48 | mp.Array(ctype, obs_arr.size),
49 | obs_arr.dtype,
50 | obs_arr.shape,
51 | )
52 |
53 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
54 | self._next_obs[obs_key] = to_np(
55 | *self._shared_next_obs_info[obs_key])
56 | self._register_mp_array("_actions")
57 | self._register_mp_array("_terminals")
58 |
59 | def _register_mp_array(self, arr_instance_var_name):
60 | """
61 | Use this function to register an array to be shared. This will wipe arr.
62 | """
63 | assert hasattr(self, arr_instance_var_name), arr_instance_var_name
64 | arr = getattr(self, arr_instance_var_name)
65 |
66 | ctype = ctypes.c_double
67 | if arr.dtype == np.uint8:
68 | ctype = ctypes.c_uint8
69 |
70 | self._mp_array_info[arr_instance_var_name] = (
71 | mp.Array(ctype, arr.size), arr.dtype, arr.shape,
72 | )
73 | setattr(
74 | self,
75 | arr_instance_var_name,
76 | to_np(*self._mp_array_info[arr_instance_var_name])
77 | )
78 |
79 | def init_from_mp_info(
80 | self,
81 | mp_info,
82 | ):
83 | """
84 | The intended use is to have a subprocess serialize/copy a
85 |         SharedObsDictRelabelingBuffer instance and call init_from_mp_info on the
86 | instance's shared variables. This can't be done during serialization
87 | since multiprocessing shared objects can't be serialized and must be
88 | passed directly to the subprocess as an argument to the fork call.
89 | """
90 | shared_obs_info, shared_next_obs_info, mp_array_info, shared_size = mp_info
91 |
92 | self._shared_obs_info = shared_obs_info
93 | self._shared_next_obs_info = shared_next_obs_info
94 | self._mp_array_info = mp_array_info
95 | for obs_key in self._shared_obs_info.keys():
96 | self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
97 | self._next_obs[obs_key] = to_np(
98 | *self._shared_next_obs_info[obs_key])
99 |
100 | for arr_instance_var_name in self._mp_array_info.keys():
101 | setattr(
102 | self,
103 | arr_instance_var_name,
104 | to_np(*self._mp_array_info[arr_instance_var_name])
105 | )
106 | self._shared_size = shared_size
107 |
108 | def get_mp_info(self):
109 | return (
110 | self._shared_obs_info,
111 | self._shared_next_obs_info,
112 | self._mp_array_info,
113 | self._shared_size,
114 | )
115 |
116 | @property
117 | def _size(self):
118 | return self._shared_size.value
119 |
120 | @_size.setter
121 | def _size(self, size):
122 | self._shared_size.value = size
123 |
124 |
125 | def to_np(shared_arr, np_dtype, shape):
126 | return np.frombuffer(shared_arr.get_obj(), dtype=np_dtype).reshape(shape)
127 |
--------------------------------------------------------------------------------
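
The `to_np` helper at the bottom is the core trick in this file: an `mp.Array` owns the shared, lock-protected memory and `to_np` hands each process a numpy view onto it. A standalone sketch of that pattern:

```
import ctypes

import numpy as np
import torch.multiprocessing as mp

from rlkit.data_management.shared_obs_dict_replay_buffer import to_np

# Shared backing store for a (4, 3) float64 array.
shared = mp.Array(ctypes.c_double, 4 * 3)
view_a = to_np(shared, np.float64, (4, 3))

view_a[0, :] = [1.0, 2.0, 3.0]

# A second view onto the same mp.Array sees the write immediately.
view_b = to_np(shared, np.float64, (4, 3))
print(view_b[0])  # [1. 2. 3.]
```
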
/rlkit/core/rl_algorithm.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from collections import OrderedDict
3 |
4 | import gtimer as gt
5 |
6 | from rlkit.core import logger, eval_util
7 | from rlkit.data_management.replay_buffer import ReplayBuffer
8 | from rlkit.samplers.data_collector import DataCollector
9 |
10 |
11 | def _get_epoch_timings():
12 | times_itrs = gt.get_times().stamps.itrs
13 | times = OrderedDict()
14 | epoch_time = 0
15 | for key in sorted(times_itrs):
16 | time = times_itrs[key][-1]
17 | epoch_time += time
18 | times['time/{} (s)'.format(key)] = time
19 | times['time/epoch (s)'] = epoch_time
20 | times['time/total (s)'] = gt.get_times().total
21 | return times
22 |
23 |
24 | class BaseRLAlgorithm(object, metaclass=abc.ABCMeta):
25 | def __init__(
26 | self,
27 | trainer,
28 | exploration_env,
29 | evaluation_env,
30 | exploration_data_collector: DataCollector,
31 | evaluation_data_collector: DataCollector,
32 | replay_buffer: ReplayBuffer,
33 | ):
34 | self.trainer = trainer
35 | self.expl_env = exploration_env
36 | self.eval_env = evaluation_env
37 | self.expl_data_collector = exploration_data_collector
38 | self.eval_data_collector = evaluation_data_collector
39 | self.replay_buffer = replay_buffer
40 | self._start_epoch = 0
41 |
42 | self.post_epoch_funcs = []
43 |
44 | def train(self, start_epoch=0):
45 | self._start_epoch = start_epoch
46 | self._train()
47 |
48 | def _train(self):
49 | """
50 | Train model.
51 | """
52 |         raise NotImplementedError('_train must be implemented by inherited class')
53 |
54 | def _end_epoch(self, epoch):
55 | snapshot = self._get_snapshot()
56 | logger.save_itr_params(epoch, snapshot)
57 | gt.stamp('saving')
58 | self._log_stats(epoch)
59 |
60 | self.expl_data_collector.end_epoch(epoch)
61 | self.eval_data_collector.end_epoch(epoch)
62 | self.replay_buffer.end_epoch(epoch)
63 | self.trainer.end_epoch(epoch)
64 |
65 | for post_epoch_func in self.post_epoch_funcs:
66 | post_epoch_func(self, epoch)
67 |
68 | def _get_snapshot(self):
69 | snapshot = {}
70 | for k, v in self.trainer.get_snapshot().items():
71 | snapshot['trainer/' + k] = v
72 | for k, v in self.expl_data_collector.get_snapshot().items():
73 | snapshot['exploration/' + k] = v
74 | for k, v in self.eval_data_collector.get_snapshot().items():
75 | snapshot['evaluation/' + k] = v
76 | for k, v in self.replay_buffer.get_snapshot().items():
77 | snapshot['replay_buffer/' + k] = v
78 | return snapshot
79 |
80 | def _log_stats(self, epoch):
81 | logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
82 |
83 | """
84 | Replay Buffer
85 | """
86 | logger.record_dict(
87 | self.replay_buffer.get_diagnostics(),
88 | prefix='replay_buffer/'
89 | )
90 |
91 | """
92 | Trainer
93 | """
94 | logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
95 |
96 | """
97 | Exploration
98 | """
99 | logger.record_dict(
100 | self.expl_data_collector.get_diagnostics(),
101 | prefix='exploration/'
102 | )
103 | expl_paths = self.expl_data_collector.get_epoch_paths()
104 | if hasattr(self.expl_env, 'get_diagnostics'):
105 | logger.record_dict(
106 | self.expl_env.get_diagnostics(expl_paths),
107 | prefix='exploration/',
108 | )
109 | logger.record_dict(
110 | eval_util.get_generic_path_information(expl_paths),
111 | prefix="exploration/",
112 | )
113 | """
114 | Evaluation
115 | """
116 | logger.record_dict(
117 | self.eval_data_collector.get_diagnostics(),
118 | prefix='evaluation/',
119 | )
120 | eval_paths = self.eval_data_collector.get_epoch_paths()
121 | if hasattr(self.eval_env, 'get_diagnostics'):
122 | logger.record_dict(
123 | self.eval_env.get_diagnostics(eval_paths),
124 | prefix='evaluation/',
125 | )
126 | logger.record_dict(
127 | eval_util.get_generic_path_information(eval_paths),
128 | prefix="evaluation/",
129 | )
130 |
131 | """
132 | Misc
133 | """
134 | gt.stamp('logging')
135 | logger.record_dict(_get_epoch_timings())
136 | logger.record_tabular('Epoch', epoch)
137 | logger.dump_tabular(with_prefix=False, with_timestamp=False)
138 |
139 | @abc.abstractmethod
140 | def training_mode(self, mode):
141 | """
142 | Set training mode to `mode`.
143 |         :param mode: If True, put the networks in training mode (e.g. enable
144 |         dropout rather than keeping all units active).
145 | """
146 | pass
147 |
--------------------------------------------------------------------------------
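
`post_epoch_funcs` is the extension hook of this class: anything appended to the list is called as `func(self, epoch)` at the end of `_end_epoch`. A minimal sketch (the hook body is illustrative):

```
# Hypothetical post-epoch hook; _end_epoch() calls it as func(algorithm, epoch).
def announce_epoch(algorithm, epoch):
    # `algorithm` is the BaseRLAlgorithm instance, so its trainer,
    # replay_buffer, data collectors, etc. can be inspected here.
    print("finished epoch", epoch)

# Given any concrete BaseRLAlgorithm subclass instance `algorithm`:
#     algorithm.post_epoch_funcs.append(announce_epoch)
#     algorithm.train()
```
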
/examples/skewfit/sawyer_door.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import multiworld.envs.mujoco as mwmj
3 | import rlkit.util.hyperparameter as hyp
4 | from multiworld.envs.mujoco.cameras import sawyer_door_env_camera_v0
5 | from rlkit.launchers.launcher_util import run_experiment
6 | import rlkit.torch.vae.vae_schedules as vae_schedules
7 | from rlkit.launchers.skewfit_experiments import \
8 | skewfit_full_experiment
9 | from rlkit.torch.vae.conv_vae import imsize48_default_architecture
10 |
11 |
12 | if __name__ == "__main__":
13 | variant = dict(
14 | algorithm='Skew-Fit-SAC',
15 | double_algo=False,
16 | online_vae_exploration=False,
17 | imsize=48,
18 | env_id='SawyerDoorHookResetFreeEnv-v1',
19 | init_camera=sawyer_door_env_camera_v0,
20 | skewfit_variant=dict(
21 | save_video=True,
22 | custom_goal_sampler='replay_buffer',
23 | online_vae_trainer_kwargs=dict(
24 | beta=20,
25 | lr=1e-3,
26 | ),
27 | save_video_period=50,
28 | qf_kwargs=dict(
29 | hidden_sizes=[400, 300],
30 | ),
31 | policy_kwargs=dict(
32 | hidden_sizes=[400, 300],
33 | ),
34 | twin_sac_trainer_kwargs=dict(
35 | reward_scale=1,
36 | discount=0.99,
37 | soft_target_tau=1e-3,
38 | target_update_period=1,
39 | use_automatic_entropy_tuning=True,
40 | ),
41 | max_path_length=100,
42 | algo_kwargs=dict(
43 | batch_size=1024,
44 | num_epochs=170,
45 | num_eval_steps_per_epoch=500,
46 | num_expl_steps_per_train_loop=500,
47 | num_trains_per_train_loop=1000,
48 | min_num_steps_before_training=10000,
49 | vae_training_schedule=vae_schedules.custom_schedule,
50 | oracle_data=False,
51 | vae_save_period=50,
52 | parallel_vae_train=False,
53 | ),
54 | replay_buffer_kwargs=dict(
55 | start_skew_epoch=10,
56 | max_size=int(100000),
57 | fraction_goals_rollout_goals=0.2,
58 | fraction_goals_env_goals=0.5,
59 | exploration_rewards_type='None',
60 | vae_priority_type='vae_prob',
61 | priority_function_kwargs=dict(
62 | sampling_method='importance_sampling',
63 | decoder_distribution='gaussian_identity_variance',
64 | num_latents_to_sample=10,
65 | ),
66 | power=-0.5,
67 | relabeling_goal_sampling_mode='custom_goal_sampler',
68 | ),
69 | exploration_goal_sampling_mode='custom_goal_sampler',
70 | evaluation_goal_sampling_mode='presampled',
71 | training_mode='train',
72 | testing_mode='test',
73 | reward_params=dict(
74 | type='latent_distance',
75 | ),
76 | observation_key='latent_observation',
77 | desired_goal_key='latent_desired_goal',
78 | presampled_goals_path=osp.join(
79 | osp.dirname(mwmj.__file__),
80 | "goals",
81 | "door_goals.npy",
82 | ),
83 | presample_goals=True,
84 | vae_wrapped_env_kwargs=dict(
85 | sample_from_true_prior=True,
86 | ),
87 | ),
88 | train_vae_variant=dict(
89 | representation_size=16,
90 | beta=20,
91 | num_epochs=0,
92 | dump_skew_debug_plots=False,
93 | decoder_activation='gaussian',
94 | generate_vae_dataset_kwargs=dict(
95 | N=2,
96 | test_p=.9,
97 | use_cached=True,
98 | show=False,
99 | oracle_dataset=False,
100 | n_random_steps=1,
101 | non_presampled_goal_img_is_garbage=True,
102 | ),
103 | vae_kwargs=dict(
104 | decoder_distribution='gaussian_identity_variance',
105 | input_channels=3,
106 | architecture=imsize48_default_architecture,
107 | ),
108 | algo_kwargs=dict(
109 | lr=1e-3,
110 | ),
111 | save_period=1,
112 | ),
113 | )
114 |
115 | search_space = {
116 | }
117 | sweeper = hyp.DeterministicHyperparameterSweeper(
118 | search_space, default_parameters=variant,
119 | )
120 |
121 | n_seeds = 1
122 | mode = 'local'
123 | exp_prefix = 'dev-{}'.format(
124 | __file__.replace('/', '-').replace('_', '-').split('.')[0]
125 | )
126 |
127 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
128 | for _ in range(n_seeds):
129 | run_experiment(
130 | skewfit_full_experiment,
131 | exp_prefix=exp_prefix,
132 | mode=mode,
133 | variant=variant,
134 | use_gpu=True,
135 | )
136 |
137 |
--------------------------------------------------------------------------------
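
With `search_space = {}`, the sweep in the script above degenerates to a single run of the default `variant`; filling it in produces one run per combination of values. A self-contained sketch of that behavior, assuming a deterministic sweep is simply a Cartesian product over the listed values (this illustrates the idea, not rlkit's `DeterministicHyperparameterSweeper` API):

import copy
import itertools


def iterate_variants(search_space, default_variant):
    # Yield one variant per combination of the values in search_space,
    # overriding the defaults for each combination.
    keys = list(search_space.keys())
    for values in itertools.product(*(search_space[k] for k in keys)):
        variant = copy.deepcopy(default_variant)
        variant.update(dict(zip(keys, values)))
        yield variant


defaults = dict(algorithm='Skew-Fit-SAC', imsize=48)

# An empty search space, as in the script above, yields only the defaults.
assert list(iterate_variants({}, defaults)) == [defaults]

# A populated one yields one variant per combination.
for v in iterate_variants(dict(seed=[0, 1, 2]), defaults):
    print(v['seed'], v['algorithm'])
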
/examples/her/her_td3_multiworld_sawyer_reach.py:
--------------------------------------------------------------------------------
1 | """
2 | This should result in an average return of ~3000 by the end of training.
3 |
4 | Usually hits 3000 around epoch 80-100. Within a seed, the performance will be
5 | a bit noisy from one epoch to the next (occasionally dipping down to ~2000).
6 |
7 | Note that one epoch = 5k steps, so 200 epochs = 1 million steps.
8 | """
9 | import gym
10 |
11 | import rlkit.torch.pytorch_util as ptu
12 | from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
13 | from rlkit.exploration_strategies.base import \
14 | PolicyWrappedWithExplorationStrategy
15 | from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import (
16 | GaussianAndEpislonStrategy
17 | )
18 | from rlkit.launchers.launcher_util import setup_logger
19 | from rlkit.samplers.data_collector import GoalConditionedPathCollector
20 | from rlkit.torch.her.her import HERTrainer
21 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy
22 | from rlkit.torch.td3.td3 import TD3Trainer
23 | from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
24 |
25 |
26 | def experiment(variant):
27 | import multiworld
28 | multiworld.register_all_envs()
29 | eval_env = gym.make('SawyerReachXYZEnv-v0')
30 | expl_env = gym.make('SawyerReachXYZEnv-v0')
31 | observation_key = 'state_observation'
32 | desired_goal_key = 'state_desired_goal'
33 | achieved_goal_key = desired_goal_key.replace("desired", "achieved")
34 | es = GaussianAndEpislonStrategy(
35 | action_space=expl_env.action_space,
36 | max_sigma=.2,
37 | min_sigma=.2, # constant sigma
38 | epsilon=.3,
39 | )
40 | obs_dim = expl_env.observation_space.spaces['observation'].low.size
41 | goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
42 | action_dim = expl_env.action_space.low.size
43 | qf1 = FlattenMlp(
44 | input_size=obs_dim + goal_dim + action_dim,
45 | output_size=1,
46 | **variant['qf_kwargs']
47 | )
48 | qf2 = FlattenMlp(
49 | input_size=obs_dim + goal_dim + action_dim,
50 | output_size=1,
51 | **variant['qf_kwargs']
52 | )
53 | target_qf1 = FlattenMlp(
54 | input_size=obs_dim + goal_dim + action_dim,
55 | output_size=1,
56 | **variant['qf_kwargs']
57 | )
58 | target_qf2 = FlattenMlp(
59 | input_size=obs_dim + goal_dim + action_dim,
60 | output_size=1,
61 | **variant['qf_kwargs']
62 | )
63 | policy = TanhMlpPolicy(
64 | input_size=obs_dim + goal_dim,
65 | output_size=action_dim,
66 | **variant['policy_kwargs']
67 | )
68 | target_policy = TanhMlpPolicy(
69 | input_size=obs_dim + goal_dim,
70 | output_size=action_dim,
71 | **variant['policy_kwargs']
72 | )
73 | expl_policy = PolicyWrappedWithExplorationStrategy(
74 | exploration_strategy=es,
75 | policy=policy,
76 | )
77 | replay_buffer = ObsDictRelabelingBuffer(
78 | env=eval_env,
79 | observation_key=observation_key,
80 | desired_goal_key=desired_goal_key,
81 | achieved_goal_key=achieved_goal_key,
82 | **variant['replay_buffer_kwargs']
83 | )
84 | trainer = TD3Trainer(
85 | policy=policy,
86 | qf1=qf1,
87 | qf2=qf2,
88 | target_qf1=target_qf1,
89 | target_qf2=target_qf2,
90 | target_policy=target_policy,
91 | **variant['trainer_kwargs']
92 | )
93 | trainer = HERTrainer(trainer)
94 | eval_path_collector = GoalConditionedPathCollector(
95 | eval_env,
96 | policy,
97 | observation_key=observation_key,
98 | desired_goal_key=desired_goal_key,
99 | )
100 | expl_path_collector = GoalConditionedPathCollector(
101 | expl_env,
102 | expl_policy,
103 | observation_key=observation_key,
104 | desired_goal_key=desired_goal_key,
105 | )
106 | algorithm = TorchBatchRLAlgorithm(
107 | trainer=trainer,
108 | exploration_env=expl_env,
109 | evaluation_env=eval_env,
110 | exploration_data_collector=expl_path_collector,
111 | evaluation_data_collector=eval_path_collector,
112 | replay_buffer=replay_buffer,
113 | **variant['algo_kwargs']
114 | )
115 | algorithm.to(ptu.device)
116 | algorithm.train()
117 |
118 |
119 | if __name__ == "__main__":
120 | variant = dict(
121 | algo_kwargs=dict(
122 | num_epochs=100,
123 | max_path_length=50,
124 | batch_size=128,
125 | num_eval_steps_per_epoch=1000,
126 | num_expl_steps_per_train_loop=1000,
127 | num_trains_per_train_loop=1000,
128 | ),
129 | trainer_kwargs=dict(
130 | discount=0.99,
131 | ),
132 | replay_buffer_kwargs=dict(
133 | max_size=100000,
134 | fraction_goals_rollout_goals=0.2,
135 | fraction_goals_env_goals=0.0,
136 | ),
137 | qf_kwargs=dict(
138 | hidden_sizes=[400, 300],
139 | ),
140 | policy_kwargs=dict(
141 | hidden_sizes=[400, 300],
142 | ),
143 | )
144 | setup_logger('her-td3-sawyer-experiment', variant=variant)
145 | experiment(variant)
146 |
--------------------------------------------------------------------------------
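
The `fraction_goals_*` entries in `replay_buffer_kwargs` above control HER relabeling. On one common reading of `ObsDictRelabelingBuffer` (treat the exact semantics as an assumption), `fraction_goals_rollout_goals` keeps the goal actually used during the rollout, `fraction_goals_env_goals` swaps in a goal resampled from the environment, and the remainder is relabeled with future achieved goals from the same trajectory. The arithmetic for the batch size in this example:

# Rough split of one sampled batch under the fractions in the variant above.
# The relabeling semantics described here are an assumption about
# ObsDictRelabelingBuffer; this only illustrates the arithmetic.
batch_size = 128
fraction_rollout = 0.2   # keep the goal used during the rollout
fraction_env = 0.0       # resample a goal from the environment's goal space

n_rollout = int(batch_size * fraction_rollout)   # 25 transitions keep their goal
n_env = int(batch_size * fraction_env)           # 0 transitions get env-sampled goals
n_future = batch_size - n_rollout - n_env        # 103 transitions get "future" goals
print(n_rollout, n_env, n_future)
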
/rlkit/envs/assets/reacher_7dof.xml:
--------------------------------------------------------------------------------
(XML body not captured in this extraction: MuJoCo model definition for a 7-DoF reacher arm.)
--------------------------------------------------------------------------------
/rlkit/envs/assets/low_gear_ratio_ant.xml:
--------------------------------------------------------------------------------
(XML body not captured in this extraction: MuJoCo model definition for an ant with a low gear ratio.)
--------------------------------------------------------------------------------
/rlkit/envs/mujoco_image_env.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from collections import deque
6 |
7 | from gym import Env
8 | from gym.spaces import Box
9 |
10 | from rlkit.envs.wrappers import ProxyEnv
11 |
12 |
13 | class ImageMujocoEnv(ProxyEnv, Env):
14 | def __init__(self,
15 | wrapped_env,
16 | imsize=32,
17 | keep_prev=0,
18 | init_camera=None,
19 | camera_name=None,
20 | transpose=False,
21 | grayscale=False,
22 | normalize=False,
23 | ):
24 | import mujoco_py
25 | super().__init__(wrapped_env)
26 |
27 | self.imsize = imsize
28 | if grayscale:
29 | self.image_length = self.imsize * self.imsize
30 | else:
31 | self.image_length = 3 * self.imsize * self.imsize
32 | # This is torch format rather than PIL image
33 | self.image_shape = (self.imsize, self.imsize)
34 | # Flattened past image queue
35 | self.history_length = keep_prev + 1
36 | self.history = deque(maxlen=self.history_length)
37 | # init camera
38 | if init_camera is not None:
39 | sim = self._wrapped_env.sim
40 | viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1)
41 | init_camera(viewer.cam)
42 | sim.add_render_context(viewer)
43 | self.camera_name = camera_name # None means default camera
44 | self.transpose = transpose
45 | self.grayscale = grayscale
46 | self.normalize = normalize
47 | self._render_local = False
48 |
49 | self.observation_space = Box(low=0.0,
50 | high=1.0,
51 | shape=(
52 | self.image_length * self.history_length,))
53 |
54 | def step(self, action):
55 | # image observation get returned as a flattened 1D array
56 | true_state, reward, done, info = super().step(action)
57 |
58 | observation = self._image_observation()
59 | self.history.append(observation)
60 | history = self._get_history().flatten()
61 | full_obs = self._get_obs(history, true_state)
62 | return full_obs, reward, done, info
63 |
64 | def reset(self, **kwargs):
65 | true_state = super().reset(**kwargs)
66 | self.history = deque(maxlen=self.history_length)
67 |
68 | observation = self._image_observation()
69 | self.history.append(observation)
70 | history = self._get_history().flatten()
71 | full_obs = self._get_obs(history, true_state)
72 | return full_obs
73 |
74 | def get_image(self):
75 | """TODO: this should probably consider history"""
76 | return self._image_observation()
77 |
78 | def _get_obs(self, history_flat, true_state):
79 | # adds extra information from true_state into to the image observation.
80 | # Used in ImageWithObsEnv.
81 | return history_flat
82 |
83 | def _image_observation(self):
84 | # returns the image as a torch format np array
85 | image_obs = self._wrapped_env.sim.render(width=self.imsize,
86 | height=self.imsize,
87 | camera_name=self.camera_name)
88 | if self._render_local:
89 | cv2.imshow('env', image_obs)
90 | cv2.waitKey(1)
91 | if self.grayscale:
92 | image_obs = Image.fromarray(image_obs).convert('L')
93 | image_obs = np.array(image_obs)
94 | if self.normalize:
95 | image_obs = image_obs / 255.0
96 | if self.transpose:
97 | image_obs = image_obs.transpose()
98 | return image_obs
99 |
100 | def _get_history(self):
101 | observations = list(self.history)
102 |
103 | obs_count = len(observations)
104 | for _ in range(self.history_length - obs_count):
105 | dummy = np.zeros(self.image_shape)
106 | observations.append(dummy)
107 | return np.c_[observations]
108 |
109 | def retrieve_images(self):
110 | # returns images in unflattened PIL format
111 | images = []
112 | for image_obs in self.history:
113 | pil_image = self.torch_to_pil(torch.Tensor(image_obs))
114 | images.append(pil_image)
115 | return images
116 |
117 | def split_obs(self, obs):
118 | # splits observation into image input and true observation input
119 | imlength = self.image_length * self.history_length
120 | obs_length = self.observation_space.low.size
121 | obs = obs.view(-1, obs_length)
122 | image_obs = obs.narrow(start=0,
123 | length=imlength,
124 | dimension=1)
125 | if obs_length == imlength:
126 | return image_obs, None
127 |
128 | fc_obs = obs.narrow(start=imlength,
129 | length=obs.shape[1] - imlength,
130 | dimension=1)
131 | return image_obs, fc_obs
132 |
133 | def enable_render(self):
134 | self._render_local = True
135 |
136 |
137 | class ImageMujocoWithObsEnv(ImageMujocoEnv):
138 | def __init__(self, env, **kwargs):
139 | super().__init__(env, **kwargs)
140 | self.observation_space = Box(low=0.0,
141 | high=1.0,
142 | shape=(
143 | self.image_length * self.history_length +
144 | self.wrapped_env.obs_dim,))
145 |
146 | def _get_obs(self, history_flat, true_state):
147 | return np.concatenate([history_flat,
148 | true_state])
--------------------------------------------------------------------------------
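
A rough usage sketch for `ImageMujocoEnv` above: wrap a MuJoCo-backed env so that `reset()`/`step()` return a flattened history of rendered frames instead of the raw state. `make_base_mujoco_env` is a hypothetical placeholder for constructing any env that exposes a `.sim` handle; the keyword values below simply echo the constructor's options.

from rlkit.envs.mujoco_image_env import ImageMujocoEnv

# Placeholder: any MuJoCo-backed gym env with a `.sim` attribute works here.
base_env = make_base_mujoco_env()  # hypothetical constructor

env = ImageMujocoEnv(
    base_env,
    imsize=32,       # rendered frames are 32x32
    keep_prev=1,     # keep one previous frame, so history_length == 2
    grayscale=True,  # per-frame length becomes imsize * imsize
    normalize=True,  # scale pixel values into [0, 1]
)

obs = env.reset()   # flattened frame history, shape (2 * 32 * 32,)
obs, reward, done, info = env.step(env.action_space.sample())
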
/rlkit/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | from gym import Env
4 | from gym.spaces import Box
5 | from gym.spaces import Discrete
6 |
7 | from collections import deque
8 |
9 |
10 | class ProxyEnv(Env):
11 | def __init__(self, wrapped_env):
12 | self._wrapped_env = wrapped_env
13 | self.action_space = self._wrapped_env.action_space
14 | self.observation_space = self._wrapped_env.observation_space
15 |
16 | @property
17 | def wrapped_env(self):
18 | return self._wrapped_env
19 |
20 | def reset(self, **kwargs):
21 | return self._wrapped_env.reset(**kwargs)
22 |
23 | def step(self, action):
24 | return self._wrapped_env.step(action)
25 |
26 | def render(self, *args, **kwargs):
27 | return self._wrapped_env.render(*args, **kwargs)
28 |
29 | @property
30 | def horizon(self):
31 | return self._wrapped_env.horizon
32 |
33 | def terminate(self):
34 | if hasattr(self.wrapped_env, "terminate"):
35 | self.wrapped_env.terminate()
36 |
37 | def __getattr__(self, attr):
38 | if attr == '_wrapped_env':
39 | raise AttributeError()
40 | return getattr(self._wrapped_env, attr)
41 |
42 | def __getstate__(self):
43 | """
44 | This is useful to override in case the wrapped env has some funky
45 | __getstate__ that doesn't play well with overriding __getattr__.
46 |
47 | The main problematic case is/was gym's EzPickle serialization scheme.
48 | :return:
49 | """
50 | return self.__dict__
51 |
52 | def __setstate__(self, state):
53 | self.__dict__.update(state)
54 |
55 | def __str__(self):
56 | return '{}({})'.format(type(self).__name__, self.wrapped_env)
57 |
58 |
59 | class HistoryEnv(ProxyEnv, Env):
60 | def __init__(self, wrapped_env, history_len):
61 | super().__init__(wrapped_env)
62 | self.history_len = history_len
63 |
64 | high = np.inf * np.ones(
65 | self.history_len * self.observation_space.low.size)
66 | low = -high
67 | self.observation_space = Box(low=low,
68 | high=high,
69 | )
70 | self.history = deque(maxlen=self.history_len)
71 |
72 | def step(self, action):
73 | state, reward, done, info = super().step(action)
74 | self.history.append(state)
75 | flattened_history = self._get_history().flatten()
76 | return flattened_history, reward, done, info
77 |
78 | def reset(self, **kwargs):
79 |         state = super().reset(**kwargs)
80 | self.history = deque(maxlen=self.history_len)
81 | self.history.append(state)
82 | flattened_history = self._get_history().flatten()
83 | return flattened_history
84 |
85 | def _get_history(self):
86 | observations = list(self.history)
87 |
88 | obs_count = len(observations)
89 | for _ in range(self.history_len - obs_count):
90 | dummy = np.zeros(self._wrapped_env.observation_space.low.size)
91 | observations.append(dummy)
92 | return np.c_[observations]
93 |
94 |
95 | class DiscretizeEnv(ProxyEnv, Env):
96 | def __init__(self, wrapped_env, num_bins):
97 | super().__init__(wrapped_env)
98 | low = self.wrapped_env.action_space.low
99 | high = self.wrapped_env.action_space.high
100 | action_ranges = [
101 | np.linspace(low[i], high[i], num_bins)
102 | for i in range(len(low))
103 | ]
104 | self.idx_to_continuous_action = [
105 | np.array(x) for x in itertools.product(*action_ranges)
106 | ]
107 | self.action_space = Discrete(len(self.idx_to_continuous_action))
108 |
109 | def step(self, action):
110 | continuous_action = self.idx_to_continuous_action[action]
111 | return super().step(continuous_action)
112 |
113 |
114 | class NormalizedBoxEnv(ProxyEnv):
115 | """
116 |     Normalize actions to [-1, 1].
117 |
118 | Optionally normalize observations and scale reward.
119 | """
120 |
121 | def __init__(
122 | self,
123 | env,
124 | reward_scale=1.,
125 | obs_mean=None,
126 | obs_std=None,
127 | ):
128 | ProxyEnv.__init__(self, env)
129 | self._should_normalize = not (obs_mean is None and obs_std is None)
130 | if self._should_normalize:
131 | if obs_mean is None:
132 | obs_mean = np.zeros_like(env.observation_space.low)
133 | else:
134 | obs_mean = np.array(obs_mean)
135 | if obs_std is None:
136 | obs_std = np.ones_like(env.observation_space.low)
137 | else:
138 | obs_std = np.array(obs_std)
139 | self._reward_scale = reward_scale
140 | self._obs_mean = obs_mean
141 | self._obs_std = obs_std
142 | ub = np.ones(self._wrapped_env.action_space.shape)
143 | self.action_space = Box(-1 * ub, ub)
144 |
145 | def estimate_obs_stats(self, obs_batch, override_values=False):
146 | if self._obs_mean is not None and not override_values:
147 | raise Exception("Observation mean and std already set. To "
148 | "override, set override_values to True.")
149 | self._obs_mean = np.mean(obs_batch, axis=0)
150 | self._obs_std = np.std(obs_batch, axis=0)
151 |
152 | def _apply_normalize_obs(self, obs):
153 | return (obs - self._obs_mean) / (self._obs_std + 1e-8)
154 |
155 | def step(self, action):
156 | lb = self._wrapped_env.action_space.low
157 | ub = self._wrapped_env.action_space.high
158 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
159 | scaled_action = np.clip(scaled_action, lb, ub)
160 |
161 | wrapped_step = self._wrapped_env.step(scaled_action)
162 | next_obs, reward, done, info = wrapped_step
163 | if self._should_normalize:
164 | next_obs = self._apply_normalize_obs(next_obs)
165 | return next_obs, reward * self._reward_scale, done, info
166 |
167 | def __str__(self):
168 | return "Normalized: %s" % self._wrapped_env
169 |
170 |
--------------------------------------------------------------------------------
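
`NormalizedBoxEnv` keeps the agent acting in [-1, 1] and rescales into the wrapped env's true action bounds before stepping. A small usage sketch; the gym id is only an example and may differ across gym versions.

import gym
import numpy as np

from rlkit.envs.wrappers import NormalizedBoxEnv

env = NormalizedBoxEnv(gym.make('Pendulum-v0'))

# The exposed action space is always the unit box.
assert np.allclose(env.action_space.low, -1.0)
assert np.allclose(env.action_space.high, 1.0)

obs = env.reset()
# An action of +1 maps to the wrapped env's upper bound, -1 to its lower bound.
obs, reward, done, info = env.step(np.ones(env.action_space.shape))
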
/examples/skewfit/sawyer_push.py:
--------------------------------------------------------------------------------
1 | import rlkit.util.hyperparameter as hyp
2 | from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in
3 | from rlkit.launchers.launcher_util import run_experiment
4 | import rlkit.torch.vae.vae_schedules as vae_schedules
5 | from rlkit.launchers.skewfit_experiments import skewfit_full_experiment
6 | from rlkit.torch.vae.conv_vae import imsize48_default_architecture
7 |
8 |
9 | if __name__ == "__main__":
10 | variant = dict(
11 | algorithm='Skew-Fit',
12 | double_algo=False,
13 | online_vae_exploration=False,
14 | imsize=48,
15 | init_camera=sawyer_init_camera_zoomed_in,
16 | env_id='SawyerPushNIPSEasy-v0',
17 | skewfit_variant=dict(
18 | save_video=True,
19 | custom_goal_sampler='replay_buffer',
20 | online_vae_trainer_kwargs=dict(
21 | beta=20,
22 | lr=1e-3,
23 | ),
24 | save_video_period=100,
25 | qf_kwargs=dict(
26 | hidden_sizes=[400, 300],
27 | ),
28 | policy_kwargs=dict(
29 | hidden_sizes=[400, 300],
30 | ),
31 | vf_kwargs=dict(
32 | hidden_sizes=[400, 300],
33 | ),
34 | max_path_length=50,
35 | algo_kwargs=dict(
36 | batch_size=1024,
37 | num_epochs=1000,
38 | num_eval_steps_per_epoch=500,
39 | num_expl_steps_per_train_loop=500,
40 | num_trains_per_train_loop=1000,
41 | min_num_steps_before_training=10000,
42 | vae_training_schedule=vae_schedules.custom_schedule_2,
43 | oracle_data=False,
44 | vae_save_period=50,
45 | parallel_vae_train=False,
46 | ),
47 | twin_sac_trainer_kwargs=dict(
48 | discount=0.99,
49 | reward_scale=1,
50 | soft_target_tau=1e-3,
51 | target_update_period=1, # 1
52 | use_automatic_entropy_tuning=True,
53 | ),
54 | replay_buffer_kwargs=dict(
55 | start_skew_epoch=10,
56 | max_size=int(100000),
57 | fraction_goals_rollout_goals=0.2,
58 | fraction_goals_env_goals=0.5,
59 | exploration_rewards_type='None',
60 | vae_priority_type='vae_prob',
61 | priority_function_kwargs=dict(
62 | sampling_method='importance_sampling',
63 | decoder_distribution='gaussian_identity_variance',
64 | num_latents_to_sample=10,
65 | ),
66 | power=-1,
67 | relabeling_goal_sampling_mode='vae_prior',
68 | ),
69 | exploration_goal_sampling_mode='vae_prior',
70 | evaluation_goal_sampling_mode='reset_of_env',
71 | normalize=False,
72 | render=False,
73 | exploration_noise=0.0,
74 | exploration_type='ou',
75 | training_mode='train',
76 | testing_mode='test',
77 | reward_params=dict(
78 | type='latent_distance',
79 | ),
80 | observation_key='latent_observation',
81 | desired_goal_key='latent_desired_goal',
82 | vae_wrapped_env_kwargs=dict(
83 | sample_from_true_prior=True,
84 | ),
85 | ),
86 | train_vae_variant=dict(
87 | representation_size=4,
88 | beta=20,
89 | num_epochs=0,
90 | dump_skew_debug_plots=False,
91 | decoder_activation='gaussian',
92 | generate_vae_dataset_kwargs=dict(
93 | N=40,
94 | test_p=.9,
95 | use_cached=False,
96 | show=False,
97 | oracle_dataset=True,
98 | oracle_dataset_using_set_to_goal=True,
99 | n_random_steps=100,
100 | non_presampled_goal_img_is_garbage=True,
101 | ),
102 | vae_kwargs=dict(
103 | input_channels=3,
104 | architecture=imsize48_default_architecture,
105 | decoder_distribution='gaussian_identity_variance',
106 | ),
107 | # TODO: why the redundancy?
108 | algo_kwargs=dict(
109 | start_skew_epoch=5000,
110 | is_auto_encoder=False,
111 | batch_size=64,
112 | lr=1e-3,
113 | skew_config=dict(
114 | method='vae_prob',
115 | power=-1,
116 | ),
117 | skew_dataset=True,
118 | priority_function_kwargs=dict(
119 | decoder_distribution='gaussian_identity_variance',
120 | sampling_method='importance_sampling',
121 | num_latents_to_sample=10,
122 | ),
123 | use_parallel_dataloading=False,
124 | ),
125 |
126 | save_period=25,
127 | ),
128 | )
129 | search_space = {}
130 | sweeper = hyp.DeterministicHyperparameterSweeper(
131 | search_space, default_parameters=variant,
132 | )
133 |
134 | n_seeds = 1
135 | mode = 'local'
136 | exp_prefix = 'dev-{}'.format(
137 | __file__.replace('/', '-').replace('_', '-').split('.')[0]
138 | )
139 |
140 | n_seeds = 3
141 | mode = 'ec2'
142 | exp_prefix = 'rlkit-skew-fit-pusher-reference-sample-from-true-prior-take2'
143 |
144 | for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
145 | for _ in range(n_seeds):
146 | run_experiment(
147 | skewfit_full_experiment,
148 | exp_prefix=exp_prefix,
149 | mode=mode,
150 | variant=variant,
151 | use_gpu=True,
152 | num_exps_per_instance=3,
153 | gcp_kwargs=dict(
154 | terminate=True,
155 | zone='us-east1-c',
156 | gpu_kwargs=dict(
157 | gpu_model='nvidia-tesla-k80',
158 | num_gpu=1,
159 | )
160 | )
161 | )
162 |
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/path_collector.py:
--------------------------------------------------------------------------------
1 | from collections import deque, OrderedDict
2 |
3 | from rlkit.core.eval_util import create_stats_ordered_dict
4 | from rlkit.samplers.rollout_functions import rollout, multitask_rollout
5 | from rlkit.samplers.data_collector.base import PathCollector
6 |
7 |
8 | class MdpPathCollector(PathCollector):
9 | def __init__(
10 | self,
11 | env,
12 | policy,
13 | max_num_epoch_paths_saved=None,
14 | render=False,
15 | render_kwargs=None,
16 | ):
17 | if render_kwargs is None:
18 | render_kwargs = {}
19 | self._env = env
20 | self._policy = policy
21 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
22 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
23 | self._render = render
24 | self._render_kwargs = render_kwargs
25 |
26 | self._num_steps_total = 0
27 | self._num_paths_total = 0
28 |
29 | def collect_new_paths(
30 | self,
31 | max_path_length,
32 | num_steps,
33 | discard_incomplete_paths,
34 | ):
35 | paths = []
36 | num_steps_collected = 0
37 | while num_steps_collected < num_steps:
38 | max_path_length_this_loop = min( # Do not go over num_steps
39 | max_path_length,
40 | num_steps - num_steps_collected,
41 | )
42 | path = rollout(
43 | self._env,
44 | self._policy,
45 |                 max_path_length=max_path_length_this_loop,
46 |                 render=self._render, render_kwargs=self._render_kwargs)
47 | path_len = len(path['actions'])
48 | if (
49 | path_len != max_path_length
50 | and not path['terminals'][-1]
51 | and discard_incomplete_paths
52 | ):
53 | break
54 | num_steps_collected += path_len
55 | paths.append(path)
56 | self._num_paths_total += len(paths)
57 | self._num_steps_total += num_steps_collected
58 | self._epoch_paths.extend(paths)
59 | return paths
60 |
61 | def get_epoch_paths(self):
62 | return self._epoch_paths
63 |
64 | def end_epoch(self, epoch):
65 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
66 |
67 | def get_diagnostics(self):
68 | path_lens = [len(path['actions']) for path in self._epoch_paths]
69 | stats = OrderedDict([
70 | ('num steps total', self._num_steps_total),
71 | ('num paths total', self._num_paths_total),
72 | ])
73 | stats.update(create_stats_ordered_dict(
74 | "path length",
75 | path_lens,
76 | always_show_all_stats=True,
77 | ))
78 | return stats
79 |
80 | def get_snapshot(self):
81 | return dict(
82 | env=self._env,
83 | policy=self._policy,
84 | )
85 |
86 |
87 | class GoalConditionedPathCollector(PathCollector):
88 | def __init__(
89 | self,
90 | env,
91 | policy,
92 | max_num_epoch_paths_saved=None,
93 | render=False,
94 | render_kwargs=None,
95 | observation_key='observation',
96 | desired_goal_key='desired_goal',
97 | ):
98 | if render_kwargs is None:
99 | render_kwargs = {}
100 | self._env = env
101 | self._policy = policy
102 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
103 | self._render = render
104 | self._render_kwargs = render_kwargs
105 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
106 | self._observation_key = observation_key
107 | self._desired_goal_key = desired_goal_key
108 |
109 | self._num_steps_total = 0
110 | self._num_paths_total = 0
111 |
112 | def collect_new_paths(
113 | self,
114 | max_path_length,
115 | num_steps,
116 | discard_incomplete_paths,
117 | ):
118 | paths = []
119 | num_steps_collected = 0
120 | while num_steps_collected < num_steps:
121 | max_path_length_this_loop = min( # Do not go over num_steps
122 | max_path_length,
123 | num_steps - num_steps_collected,
124 | )
125 | path = multitask_rollout(
126 | self._env,
127 | self._policy,
128 | max_path_length=max_path_length_this_loop,
129 | render=self._render,
130 | render_kwargs=self._render_kwargs,
131 | observation_key=self._observation_key,
132 | desired_goal_key=self._desired_goal_key,
133 | return_dict_obs=True,
134 | )
135 | path_len = len(path['actions'])
136 | if (
137 | path_len != max_path_length
138 | and not path['terminals'][-1]
139 | and discard_incomplete_paths
140 | ):
141 | break
142 | num_steps_collected += path_len
143 | paths.append(path)
144 | self._num_paths_total += len(paths)
145 | self._num_steps_total += num_steps_collected
146 | self._epoch_paths.extend(paths)
147 | return paths
148 |
149 | def get_epoch_paths(self):
150 | return self._epoch_paths
151 |
152 | def end_epoch(self, epoch):
153 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
154 |
155 | def get_diagnostics(self):
156 | path_lens = [len(path['actions']) for path in self._epoch_paths]
157 | stats = OrderedDict([
158 | ('num steps total', self._num_steps_total),
159 | ('num paths total', self._num_paths_total),
160 | ])
161 | stats.update(create_stats_ordered_dict(
162 | "path length",
163 | path_lens,
164 | always_show_all_stats=True,
165 | ))
166 | return stats
167 |
168 | def get_snapshot(self):
169 | return dict(
170 | env=self._env,
171 | policy=self._policy,
172 | observation_key=self._observation_key,
173 | desired_goal_key=self._desired_goal_key,
174 | )
175 |
--------------------------------------------------------------------------------
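
A short sketch of driving `MdpPathCollector` by hand, the way the batch RL loop does: request roughly `num_steps` of experience, capped per rollout by `max_path_length`. The policy below is a minimal stand-in that only satisfies the `get_action()`/`reset()` contract that `rollout()` relies on; the gym id is just an example.

import gym

from rlkit.samplers.data_collector.path_collector import MdpPathCollector


class UniformRandomPolicy:
    """Minimal stand-in policy: samples uniformly from the action space."""

    def __init__(self, action_space):
        self.action_space = action_space

    def get_action(self, observation):
        # Return (action, agent_info), the pair rollout() expects.
        return self.action_space.sample(), {}

    def reset(self):
        pass


env = gym.make('Pendulum-v0')  # example id; substitute any gym env
collector = MdpPathCollector(env, UniformRandomPolicy(env.action_space))

paths = collector.collect_new_paths(
    max_path_length=200,            # cap on any single rollout
    num_steps=1000,                 # total environment steps for this call
    discard_incomplete_paths=True,  # drop a trailing partial, non-terminal path
)
print(len(paths), collector.get_diagnostics()['num steps total'])
collector.end_epoch(0)              # clear the per-epoch path cache
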