├── rlkit
├── __init__.py
├── demos
│ ├── __init__.py
│ ├── source
│ │ ├── __init__.py
│ │ ├── demo_source.py
│ │ ├── path_loader.py
│ │ └── hand_demo_source.py
│ ├── spacemouse
│ │ ├── __init__.py
│ │ ├── input_client.py
│ │ └── README.md
│ ├── her_td3bc.py
│ ├── play_demo.py
│ └── her_bc.py
├── torch
│ ├── __init__.py
│ ├── sac
│ │ ├── __init__.py
│ │ └── policies
│   │       └── base.py
│ ├── data_management
│ │ ├── __init__.py
│ │ └── normalizer.py
│ ├── networks
│ │ ├── stochastic
│ │ │ └── __init__.py
│ │ ├── custom.py
│ │ ├── linear_transform.py
│ │ ├── ae_tanh_policy.py
│ │ ├── __init__.py
│ │ ├── image_state.py
│ │ ├── two_headed_mlp.py
│ │ └── experimental.py
│ ├── transforms
│ │ ├── __init__.py
│ │ └── _pil_constants.py
│ ├── her
│ │ └── her.py
│ └── core.py
├── utils
│ ├── __init__.py
│ ├── val_util.py
│ ├── experiment_util.py
│ ├── device_util.py
│ ├── logging.config
│ ├── logging.py
│ ├── timer.py
│ ├── path_builder.py
│ └── image_util.py
├── learning
│ ├── __init__.py
│ └── online_offline_split_replay_buffer.py
├── networks
│ ├── __init__.py
│ ├── utils.py
│ └── gaussian_policy.py
├── samplers
│ ├── __init__.py
│ ├── data_collector
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── contextual_path_collector.py
│ │ └── joint_path_collector.py
│ ├── in_place.py
│ └── util.py
├── data_management
│ ├── __init__.py
│ ├── external
│ │ ├── __init__.py
│ │ └── bair_dataset
│   │       └── __init__.py
│ ├── wrappers
│ │ ├── concat_to_obs_wrapper.py
│ │ ├── proxy_buffer.py
│ │ └── replay_buffer_wrapper.py
│ ├── path_builder.py
│ ├── dataset_logger_fn.py
│ ├── images.py
│ ├── tau_replay_buffer.py
│ ├── ocm_subtraj_replay_buffer.py
│ ├── split_buffer.py
│ └── replay_buffer.py
├── envs
│ ├── memory
│ │ ├── __init__.py
│ │ └── hidden_cartpole.py
│ ├── mujoco
│ │ ├── __init__.py
│ │ ├── point_env.py
│ │ ├── twod_point_random_init.py
│ │ ├── reacher_env.py
│ │ ├── twod_point.py
│ │ ├── twod_maze.py
│ │ ├── oned_point.py
│ │ ├── hopper_env.py
│ │ ├── ant.py
│ │ ├── mujoco_env_murtaza.py
│ │ ├── mujoco_env.py
│ │ └── pusher.py
│ ├── pygame
│ │ └── __init__.py
│ ├── simple
│ │ ├── __init__.py
│ │ └── point.py
│ ├── pearl_envs
│ │ ├── rand_param_envs
│ │ │ ├── __init__.py
│ │ │ ├── hopper_rand_params.py
│ │ │ ├── walker2d_rand_params.py
│ │ │ └── pr2_env_reach.py
│ │ ├── hopper_rand_params_wrapper.py
│ │ ├── walker_rand_params_wrapper.py
│ │ ├── ant_normal.py
│ │ ├── half_cheetah.py
│ │ ├── ant_multitask_base.py
│ │ ├── ant_goal.py
│ │ ├── ant_dir.py
│ │ ├── mujoco_env.py
│ │ ├── __init__.py
│ │ ├── ant.py
│ │ ├── humanoid_dir.py
│ │ ├── half_cheetah_vel.py
│ │ └── half_cheetah_dir.py
│ ├── base.py
│ ├── contextual
│ │ ├── __init__.py
│ │ └── task_conditioned.py
│ ├── sawyer_parallel_hack.py
│ ├── images
│ │ ├── __init__.py
│ │ ├── text_renderer.py
│ │ ├── env_renderer.py
│ │ └── insert_image_env.py
│ ├── __init__.py
│ ├── wrappers
│ │ ├── image_mujoco_env_with_obs.py
│ │ ├── reward_wrapper_env.py
│ │ ├── discretize_env.py
│ │ ├── __init__.py
│ │ ├── modular_env.py
│ │ ├── flat_to_dict.py
│ │ ├── stack_observation_env.py
│ │ ├── history_env.py
│ │ └── normalized_box_env.py
│ ├── assets
│ │ ├── oned_point.xml
│ │ ├── twod_point.xml
│ │ ├── twod_point_random_init.xml
│ │ ├── twod_maze.xml
│ │ ├── water_maze.xml
│ │ └── small_water_maze.xml
│ ├── gridcraft
│ │ ├── utils.py
│ │ └── custom_test.py
│ ├── env_utils.py
│ ├── time_limited_env.py
│ ├── supervised_learning_env.py
│ ├── robosuite_wrapper.py
│ ├── gripper_state_wrapper.py
│ ├── proxy_env.py
│ ├── make_env.py
│ └── dual_encoder_wrapper.py
├── visualization
│ ├── __init__.py
│ └── plotter.py
├── core
│ ├── __init__.py
│ ├── trainer.py
│ ├── loss.py
│ ├── distribution.py
│ ├── util.py
│ ├── ray_csv_logger.py
│ ├── timer.py
│ └── serializeable.py
├── util
│ ├── gym_util.py
│ ├── random_util.py
│ ├── __init__.py
│ ├── ray_util.py
│ ├── inspect_q_util.py
│ └── tensorboard_logger.py
├── launchers
│ ├── __init__.py
│ ├── contextual
│ │ └── util.py
│ └── doodad_wrapper.py
├── policies
│ ├── base.py
│ └── action_repeat.py
├── exploration_strategies
│ ├── noop.py
│ ├── epsilon_greedy.py
│ ├── gaussian_strategy.py
│ ├── gaussian_and_epsilon.py
│ ├── ou_strategy.py
│ └── base.py
├── stable_contrastive_rl.png
└── LICENSE
/rlkit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/demos/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/learning/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/samplers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/sac/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/data_management/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/demos/source/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/envs/memory/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/envs/pygame/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/envs/simple/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/visualization/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/demos/spacemouse/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/data_management/external/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/data_management/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/networks/stochastic/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/rand_param_envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/data_management/external/bair_dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rlkit/torch/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .transforms import *
--------------------------------------------------------------------------------
/rlkit/torch/networks/custom.py:
--------------------------------------------------------------------------------
1 | """
2 | Random networks
3 | """
4 |
--------------------------------------------------------------------------------
/stable_contrastive_rl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chongyi-zheng/stable_contrastive_rl/HEAD/stable_contrastive_rl.png
--------------------------------------------------------------------------------
/rlkit/demos/source/demo_source.py:
--------------------------------------------------------------------------------
1 | class DemoSource:
2 | def load_paths(self):
3 | """Should return a list of paths in PathBuilder format"""
4 | return [{}, ]
5 |
--------------------------------------------------------------------------------
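For illustration, a concrete DemoSource might load demonstrations from a pickle file, much as play_demo.py further below does. This is a sketch only; the PickleDemoSource name and the pickle layout are assumptions, not code from the repo.

import pickle

from rlkit.demos.source.demo_source import DemoSource


class PickleDemoSource(DemoSource):
    """Sketch: serve paths stored as a pickled list of PathBuilder-style dicts."""

    def __init__(self, demo_path):
        self.demo_path = demo_path

    def load_paths(self):
        # Each entry should be a dict with keys such as "observations",
        # "actions", "rewards", etc., matching the PathBuilder format.
        with open(self.demo_path, "rb") as f:
            return pickle.load(f)
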
/rlkit/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General classes, functions, and utilities used throughout railrl.
3 | """
4 | from rlkit.core.logging import logger, setup_logger
5 |
6 | __all__ = ["logger", "setup_logger"]
7 |
--------------------------------------------------------------------------------
/rlkit/envs/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class RolloutEnv(object):
5 | """ Environment that supports full rollouts."""
6 |
7 | @abc.abstractmethod
8 | def rollout(self, *args, **kwargs):
9 | pass
10 |
--------------------------------------------------------------------------------
/rlkit/util/gym_util.py:
--------------------------------------------------------------------------------
1 | from gym.envs import registration
2 |
3 |
4 | def get_class_and_kwargs(spec_or_id):
5 | if isinstance(spec_or_id, registration.EnvSpec):
6 | spec = spec_or_id
7 | else:
8 | spec = registration.spec(spec_or_id)
9 | return registration.load(spec._entry_point), spec._kwargs
--------------------------------------------------------------------------------
/rlkit/envs/contextual/__init__.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.contextual.contextual_env import (
2 | ContextualEnv,
3 | ContextualRewardFn,
4 | delete_info,
5 | insert_reward,
6 | )
7 |
8 | __all__ = [
9 | 'ContextualEnv',
10 | 'ContextualRewardFn',
11 | 'delete_info',
12 | 'insert_reward',
13 | ]
--------------------------------------------------------------------------------
/rlkit/torch/networks/linear_transform.py:
--------------------------------------------------------------------------------
1 | from rlkit.torch.core import PyTorchModule
2 |
3 |
4 | class LinearTransform(PyTorchModule):
5 | def __init__(self, m, b):
6 | super().__init__()
7 | self.m = m
8 | self.b = b
9 |
10 | def __call__(self, t):
11 | return self.m * t + self.b
12 |
--------------------------------------------------------------------------------
/rlkit/utils/val_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def preprocess_image(data):
5 | _shape = list(data.shape[:-3])
6 | data = np.reshape(data, [-1, 3, 48, 48])
7 | data = np.transpose(data, [0, 1, 3, 2])
8 | data = np.reshape(data, _shape + [3, 48, 48])
9 | data = data - 0.5
10 | return data
11 |
--------------------------------------------------------------------------------
/rlkit/envs/sawyer_parallel_hack.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 |
4 | idx = 0
5 | while idx < len(sys.path):
6 | if 'sawyer_control' in sys.path[idx]:
7 | warnings.warn("Undoing ros generated __init__ for parallel until a better fix is found")
8 | del sys.path[idx]
9 | else:
10 | idx += 1
11 |
12 |
13 |
--------------------------------------------------------------------------------
/rlkit/core/trainer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Trainer(object, metaclass=abc.ABCMeta):
5 | @abc.abstractmethod
6 | def train(self, data):
7 | pass
8 |
9 | def end_epoch(self, epoch):
10 | pass
11 |
12 | def get_snapshot(self):
13 | return {}
14 |
15 | def get_diagnostics(self):
16 | return {}
17 |
--------------------------------------------------------------------------------
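As a sketch of the interface above, a minimal Trainer that performs one supervised update per train() call could look as follows. The network/optimizer wiring and the batch keys are assumptions for illustration, not repo code.

import torch
import torch.nn.functional as F

from rlkit.core.trainer import Trainer


class SupervisedTrainer(Trainer):
    """Sketch only: one gradient step of MSE regression per train() call."""

    def __init__(self, network, lr=1e-3):
        self.network = network
        self.optimizer = torch.optim.Adam(network.parameters(), lr=lr)
        self._last_loss = float("nan")

    def train(self, data):
        # `data` is assumed to be a dict of tensors with "inputs" and "targets".
        loss = F.mse_loss(self.network(data["inputs"]), data["targets"])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._last_loss = loss.item()

    def get_diagnostics(self):
        return {"trainer/loss": self._last_loss}
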
/rlkit/launchers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains 'launchers', which are self-contained functions that take
3 | one dictionary and run a full experiment. The dictionary configures the
4 | experiment.
5 |
6 | It is important that the functions are completely self-contained (i.e. they
7 | import their own modules) so that they can be serialized.
8 | """
9 |
--------------------------------------------------------------------------------
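A minimal sketch of the launcher pattern the docstring describes: one function, one variant dictionary, and all imports inside the function body so the function stays self-contained and serializable. The function name and variant keys are illustrative.

def example_launcher(variant):
    # Imports live inside the launcher so it can be pickled and shipped
    # to a remote worker, as the docstring above requires.
    import numpy as np

    rng = np.random.default_rng(variant.get("seed", 0))
    data = rng.normal(size=(variant["num_samples"], variant["dim"]))
    print("mean sample norm:", np.linalg.norm(data, axis=1).mean())


example_launcher(dict(seed=0, num_samples=128, dim=3))
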
/rlkit/util/random_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def random_point_in_circle(angle_range=(0, 2*np.pi), radius=(0, 25)):
5 | angle = np.random.uniform(*angle_range)
6 | radius = radius if np.isscalar(radius) else np.random.uniform(*radius)
7 | x, y = np.cos(angle) * radius, np.sin(angle) * radius
8 | point = np.array([x, y])
9 | return point
10 |
--------------------------------------------------------------------------------
/rlkit/utils/experiment_util.py:
--------------------------------------------------------------------------------
1 | # import sys
2 | # import os
3 | # import stat
4 |
5 |
6 | from rlkit.launchers import launcher_util as lu
7 |
8 |
9 | def run_variant(experiment, variant):
10 | launcher_config = variant.get("launcher_config")
11 | lu.run_experiment(
12 | experiment,
13 | variant=variant,
14 | **launcher_config,
15 | )
16 |
--------------------------------------------------------------------------------
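Hypothetical usage of run_variant: the variant carries both experiment settings and a launcher_config that is splatted into launcher_util.run_experiment. The exp_prefix key is an assumption carried over from upstream rlkit; check launcher_util in this repo for the arguments it actually accepts.

from rlkit.utils.experiment_util import run_variant


def experiment(variant):
    print("algo kwargs:", variant["algo_kwargs"])


variant = dict(
    algo_kwargs=dict(num_epochs=100),
    launcher_config=dict(
        exp_prefix="example",  # assumed kwarg of run_experiment
    ),
)
run_variant(experiment, variant)
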
/rlkit/envs/images/__init__.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.images.renderer import Renderer
2 | from rlkit.envs.images.env_renderer import EnvRenderer, GymEnvRenderer
3 | from rlkit.envs.images.insert_image_env import InsertImageEnv, InsertImagesEnv
4 |
5 | __all__ = [
6 | 'Renderer',
7 | 'EnvRenderer',
8 | 'GymEnvRenderer',
9 | 'InsertImageEnv',
10 | 'InsertImagesEnv',
11 | ]
12 |
13 |
14 |
--------------------------------------------------------------------------------
/rlkit/util/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Modules in here are self-contained and depend only on core libraries
3 | like numpy and pythonplusplus. However, they should NOT depend on things that
4 | are specific to railrl or rllab.
5 |
6 | The only exception is when an external dependency is explicit in the name of
7 | the module, e.g. rllab_util can depend on rllab. hyperopt can depend on
8 | hyperopt, etc.
9 | """
--------------------------------------------------------------------------------
/rlkit/utils/device_util.py:
--------------------------------------------------------------------------------
1 | from absl import logging # NOQA
2 |
3 | import torch
4 |
5 | import rlkit.torch.pytorch_util as ptu
6 |
7 |
8 | def set_device(use_gpu, gpu_id=None):
9 | if use_gpu:
10 | assert torch.cuda.is_available()
11 |
12 | if gpu_id is None:
13 | gpu_id = 0
14 |
15 | else:
16 | gpu_id = -1
17 |
18 | ptu.set_gpu_mode(mode=use_gpu, gpu_id=gpu_id)
19 | logging.info('Device: %r', ptu.device)
20 |
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/__init__.py:
--------------------------------------------------------------------------------
1 | from rlkit.samplers.data_collector.base import (
2 | DataCollector,
3 | PathCollector,
4 | StepCollector,
5 | )
6 | from rlkit.samplers.data_collector.path_collector import (
7 | MdpPathCollector,
8 | ObsDictPathCollector,
9 | GoalConditionedPathCollector,
10 | VAEWrappedEnvPathCollector,
11 | )
12 | from rlkit.samplers.data_collector.step_collector import (
13 | GoalConditionedStepCollector
14 | )
15 |
--------------------------------------------------------------------------------
/rlkit/utils/logging.config:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root
3 |
4 | [handlers]
5 | keys=consoleHandler
6 |
7 | [formatters]
8 | keys=simpleFormatter
9 |
10 | [logger_root]
11 | level=DEBUG
12 | handlers=consoleHandler
13 |
14 | [handler_consoleHandler]
15 | class=StreamHandler
16 | level=DEBUG
17 | formatter=simpleFormatter
18 | args=(sys.stdout,)
19 |
20 | [formatter_simpleFormatter]
21 | format=[%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)s] %(message)s
22 | datefmt=%m-%d %H:%M:%S
23 |
--------------------------------------------------------------------------------
/rlkit/policies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | class Policy(object, metaclass=abc.ABCMeta):
4 | """
5 | General policy interface.
6 | """
7 | @abc.abstractmethod
8 | def get_action(self, observation):
9 | """
10 |
11 | :param observation:
12 | :return: action, debug_dictionary
13 | """
14 | pass
15 |
16 | def reset(self):
17 | pass
18 |
19 |
20 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta):
21 | def set_num_steps_total(self, t):
22 | pass
23 |
--------------------------------------------------------------------------------
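To make the contract concrete, here is a sketch of the smallest possible Policy: get_action returns an (action, debug_dictionary) pair. The UniformRandomPolicy name and the Box space are illustrative, not repo code.

from gym.spaces import Box

from rlkit.policies.base import Policy


class UniformRandomPolicy(Policy):
    """Sketch: ignore the observation and sample uniformly from the action space."""

    def __init__(self, action_space):
        self.action_space = action_space

    def get_action(self, observation):
        return self.action_space.sample(), {}


policy = UniformRandomPolicy(Box(low=-1.0, high=1.0, shape=(2,)))
action, debug_dict = policy.get_action(observation=None)
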
/rlkit/utils/logging.py:
--------------------------------------------------------------------------------
1 | """Logging utilites.
2 | """
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import sys # NOQA
9 | import os.path
10 | import logging
11 | import logging.config
12 |
13 |
14 | try:
15 | # if 'absl' not in sys.modules:
16 | config_path = os.path.join(os.path.dirname(__file__), 'logging.config')
17 | logging.config.fileConfig(config_path)
18 | except Exception:
19 | print('Unable to set the formatters for logging.')
20 |
21 |
22 | logger = logging.getLogger('root')
23 |
--------------------------------------------------------------------------------
/rlkit/core/loss.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from collections import OrderedDict
3 |
4 |
5 | LossStatistics = OrderedDict
6 |
7 |
8 | class LossFunction(object, metaclass=abc.ABCMeta):
9 | @abc.abstractmethod
10 | def compute_loss(self, batch, skip_statistics=False, **kwargs):
11 | """Returns loss and statistics given a batch of data.
12 | batch : Data to compute loss of
13 | skip_statistics: Whether statistics should be calculated. If True, then
14 | an empty dict is returned for the statistics.
15 |
16 | Returns: (loss, stats) tuple.
17 | """
18 | pass
19 |
--------------------------------------------------------------------------------
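A sketch of a concrete LossFunction following the docstring above: it returns a (loss, stats) tuple and only fills the statistics when skip_statistics is False. The behavioral-cloning objective, the batch keys, and the callable policy module are assumptions for the example.

import torch.nn.functional as F

from rlkit.core.loss import LossFunction, LossStatistics


class BCLoss(LossFunction):
    """Sketch: mean-squared error between policy output and demonstrated actions."""

    def __init__(self, policy):
        self.policy = policy  # assumed to be a callable torch module

    def compute_loss(self, batch, skip_statistics=False, **kwargs):
        predicted_actions = self.policy(batch["observations"])
        loss = F.mse_loss(predicted_actions, batch["actions"])
        stats = LossStatistics()
        if not skip_statistics:
            stats["bc/mse"] = loss.item()
        return loss, stats
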
/rlkit/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # from envs.env_utils import register_environments
2 | from gym.envs.registration import register
3 |
4 | register(id='OneDPoint-v0', entry_point='railrl.envs.oned_point:OneDPoint')
5 | register(id='TwoDPoint-v0', entry_point='railrl.envs.twod_point:TwoDPoint')
6 | register(id='TwoDPointRandomInit-v0',
7 | entry_point='railrl.envs.twod_point_random_init:TwoDPointRandomInit')
8 | register(id='TwoDMaze-v0', entry_point='railrl.envs.twod_maze:TwoDMaze')
9 | register(id='WaterMaze-v0', entry_point='railrl.envs.water_maze:WaterMaze')
10 | register(id='Sawyer-v0', entry_point='railrl.envs.sawyer_env:SawyerEnv')
11 |
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/image_mujoco_env_with_obs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from rlkit.envs.wrappers.image_mujoco_env import ImageMujocoEnv
5 |
6 |
7 | class ImageMujocoWithObsEnv(ImageMujocoEnv):
8 | def __init__(self, env, **kwargs):
9 | super().__init__(env, **kwargs)
10 | self.observation_space = Box(
11 | low=0.0,
12 | high=1.0,
13 | shape=(self.image_length * self.history_length
14 | + self.wrapped_env.obs_dim,))
15 |
16 | def _get_obs(self, history_flat, true_state):
17 | return np.concatenate([history_flat, true_state])
18 |
--------------------------------------------------------------------------------
/rlkit/util/ray_util.py:
--------------------------------------------------------------------------------
1 | from ray import serialization, utils
2 |
3 |
4 | def set_serialization_mode_to_pickle(cls):
5 | """
6 | Whenever this class is serialized by ray, it will default to using pickle
7 | serialization (__setstate__ and __getstate__)
8 |
9 | WARNING: This will only work if the driver is serializing and workers
10 | are de-serializing.
11 |
12 | :param cls: class instance or the Class itself
13 | """
14 | if cls not in serialization.type_to_class_id:
15 | serialization.add_class_to_whitelist(
16 | cls,
17 | utils.random_string(),
18 | pickle=True,
19 | )
20 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/noop.py:
--------------------------------------------------------------------------------
1 | from rlkit.exploration_strategies.base import RawExplorationStrategy
2 | import numpy as np
3 |
4 |
5 | class NoopStrategy(RawExplorationStrategy):
6 | """
7 | Exploration strategy that does nothing other than clip the action.
8 | """
9 |
10 | def __init__(self, action_space, **kwargs):
11 | self.action_space = action_space
12 |
13 | def get_action(self, t, observation, policy, **kwargs):
14 | return policy.get_action(observation)
15 |
16 | def get_action_from_raw_action(self, action, **kwargs):
17 | return np.clip(action, self.action_space.low, self.action_space.high)
18 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.pearl_envs.rand_param_envs.hopper_rand_params import HopperRandParamsEnv
2 |
3 |
4 | class HopperRandParamsWrappedEnv(HopperRandParamsEnv):
5 | def __init__(self, n_tasks=2, randomize_tasks=True):
6 | super(HopperRandParamsWrappedEnv, self).__init__()
7 | self.tasks = self.sample_tasks(n_tasks)
8 | self.reset_task(0)
9 |
10 | def get_all_task_idx(self):
11 | return range(len(self.tasks))
12 |
13 | def reset_task(self, idx):
14 | self._task = self.tasks[idx]
15 | self._goal = idx
16 | self.set_task(self._task)
17 | self.reset()
18 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/walker_rand_params_wrapper.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import Walker2DRandParamsEnv
2 |
3 |
4 | class WalkerRandParamsWrappedEnv(Walker2DRandParamsEnv):
5 | def __init__(self, n_tasks=2, randomize_tasks=True):
6 | super(WalkerRandParamsWrappedEnv, self).__init__()
7 | self.tasks = self.sample_tasks(n_tasks)
8 | self.reset_task(0)
9 |
10 | def get_all_task_idx(self):
11 | return range(len(self.tasks))
12 |
13 | def reset_task(self, idx):
14 | self._task = self.tasks[idx]
15 | self._goal = idx
16 | self.set_task(self._task)
17 | self.reset()
18 |
--------------------------------------------------------------------------------
/rlkit/demos/her_td3bc.py:
--------------------------------------------------------------------------------
1 | from rlkit.data_management.obs_dict_replay_buffer import \
2 | ObsDictRelabelingBuffer
3 | from rlkit.torch.her.her import HER
4 | from rlkit.demos.td3_bc import TD3BC
5 |
6 |
7 | class HerTD3BC(HER, TD3BC):
8 | def __init__(
9 | self,
10 | *args,
11 | td3_kwargs,
12 | her_kwargs,
13 | base_kwargs,
14 | **kwargs
15 | ):
16 | HER.__init__(
17 | self,
18 | **her_kwargs,
19 | )
20 | TD3BC.__init__(self, *args, **kwargs, **td3_kwargs, **base_kwargs)
21 | assert isinstance(
22 | self.replay_buffer, ObsDictRelabelingBuffer
23 | )
24 |
--------------------------------------------------------------------------------
/rlkit/networks/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | @torch.no_grad()
8 | def variance_scaling_init_(tensor, scale=1, mode="fan_avg", distribution="uniform"):
9 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
10 |
11 | if mode == "fan_in":
12 | scale /= fan_in
13 |
14 | elif mode == "fan_out":
15 | scale /= fan_out
16 |
17 | else:
18 | scale /= (fan_in + fan_out) / 2
19 |
20 | if distribution == "normal":
21 | std = math.sqrt(scale)
22 |
23 | return tensor.normal_(0, std)
24 |
25 | else:
26 | bound = math.sqrt(3 * scale)
27 |
28 | return tensor.uniform_(-bound, bound)
29 |
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/reward_wrapper_env.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.proxy_env import ProxyEnv
2 |
3 |
4 | class RewardWrapperEnv(ProxyEnv):
5 | """Substitute a different reward function"""
6 |
7 | def __init__(
8 | self,
9 | env,
10 | compute_reward_fn,
11 | ):
12 | ProxyEnv.__init__(self, env)
13 | self.spec = env.spec # hack for hand envs
14 | self.compute_reward_fn = compute_reward_fn
15 |
16 | def step(self, action):
17 | next_obs, reward, done, info = self._wrapped_env.step(action)
18 | info["env_reward"] = reward
19 | reward = self.compute_reward_fn(next_obs, reward, done, info)
20 | return next_obs, reward, done, info
21 |
--------------------------------------------------------------------------------
/rlkit/demos/play_demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import sys
4 | # print(sys.path)
5 | sys.path.remove("/opt/ros/kinetic/lib/python2.7/dist-packages")
6 |
7 | import cv2
8 | import sys
9 | import pickle
10 |
11 | def play_demos(path):
12 | data = pickle.load(open(path, "rb"))
13 | # data = np.load(path, allow_pickle=True)
14 |
15 | for traj in data:
16 | obs = traj["observations"]
17 |
18 | for o in obs:
19 | img = o["image_observation"].reshape(3, 500, 300)[:, 60:, :240].transpose()
20 | img = img[:, :, ::-1]
21 | cv2.imshow('window', img)
22 | cv2.waitKey(100)
23 |
24 | if __name__ == '__main__':
25 | demo_path = sys.argv[1]
26 | play_demos(demo_path)
27 |
--------------------------------------------------------------------------------
/rlkit/demos/spacemouse/input_client.py:
--------------------------------------------------------------------------------
1 | """
2 | Should be run on a machine connected to a spacemouse
3 | """
4 |
5 | from robosuite.devices import SpaceMouse
6 | import time
7 | import Pyro4
8 | from rlkit.launchers import config
9 | # HOSTNAME = config.SPACEMOUSE_HOSTNAME
10 | HOSTNAME = "192.168.1.4"
11 |
12 | Pyro4.config.SERIALIZERS_ACCEPTED = set(
13 | ['pickle', 'json', 'marshal', 'serpent'])
14 | Pyro4.config.SERIALIZER = 'pickle'
15 |
16 | nameserver = Pyro4.locateNS(host=HOSTNAME)
17 | uri = nameserver.lookup("example.greeting")
18 | device_state = Pyro4.Proxy(uri)
19 | device = SpaceMouse()
20 | while True:
21 | state = device.get_controller_state()
22 | print(state)
23 | time.sleep(0.1)
24 | device_state.set_state(state)
25 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/ant_normal.py:
--------------------------------------------------------------------------------
1 | from gym.envs.mujoco import AntEnv
2 |
3 |
4 | class AntNormal(AntEnv):
5 | def __init__(
6 | self,
7 | *args,
8 |             n_tasks=2, # number of distinct tasks in this domain, should equal sum of train and eval tasks
9 | randomize_tasks=True, # shuffle the tasks after creating them
10 | **kwargs
11 | ):
12 | self.tasks = [0 for _ in range(n_tasks)]
13 | self._goal = 0
14 | super().__init__(*args, **kwargs)
15 |
16 | def get_all_task_idx(self):
17 | return self.tasks
18 |
19 | def reset_task(self, idx):
20 | # not tasks. just give the same reward every time step.
21 | pass
22 |
23 | def sample_tasks(self, num_tasks):
24 | return [0 for _ in range(num_tasks)]
25 |
--------------------------------------------------------------------------------
/rlkit/data_management/wrappers/concat_to_obs_wrapper.py:
--------------------------------------------------------------------------------
1 | from rlkit.data_management.wrappers.proxy_buffer import ProxyBuffer
2 | import numpy as np
3 |
4 | class ConcatToObsWrapper(ProxyBuffer):
5 | def __init__(self, replay_buffer, keys_to_concat):
6 | """keys_to_concat: list of strings"""
7 | super().__init__(replay_buffer)
8 | self.keys_to_concat = keys_to_concat
9 |
10 | def random_batch(self, batch_size):
11 | batch = self._wrapped_buffer.random_batch(batch_size)
12 | obs = batch['observations']
13 | next_obs = batch['next_observations']
14 | to_concat = [batch[key] for key in self.keys_to_concat]
15 | batch['observations'] = np.concatenate([obs] + to_concat, axis=1)
16 | batch['next_observations'] = np.concatenate([next_obs] + to_concat, axis=1)
17 | return batch
18 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/epsilon_greedy.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from rlkit.exploration_strategies.base import RawExplorationStrategy
4 |
5 |
6 | class EpsilonGreedy(RawExplorationStrategy):
7 | """
8 | Take a random discrete action with some probability.
9 | """
10 | def __init__(self, action_space, prob_random_action=0.1):
11 | self.prob_random_action = prob_random_action
12 | self.action_space = action_space
13 |
14 | def get_action(self, t, policy, *args, **kwargs):
15 | action, agent_info = policy.get_action(*args, **kwargs)
16 | return self.get_action_from_raw_action(action), agent_info
17 |
18 | def get_action_from_raw_action(self, action, **kwargs):
19 | if random.random() <= self.prob_random_action:
20 | return self.action_space.sample()
21 | return action
22 |
--------------------------------------------------------------------------------
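Usage sketch for EpsilonGreedy: it forwards extra arguments to the wrapped policy's get_action and then, with the given probability, replaces the action with a random sample. The ConstantPolicy below is a toy stand-in and Discrete(4) is an arbitrary action space chosen for illustration.

from gym.spaces import Discrete

from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy


class ConstantPolicy:
    """Toy policy: always proposes action 0."""

    def get_action(self, observation):
        return 0, {}


strategy = EpsilonGreedy(Discrete(4), prob_random_action=0.2)
action, agent_info = strategy.get_action(
    t=0, policy=ConstantPolicy(), observation=None)
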
/rlkit/demos/source/path_loader.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | import torch
5 | import torch.optim as optim
6 | from torch import nn as nn
7 | import torch.nn.functional as F
8 |
9 | import rlkit.torch.pytorch_util as ptu
10 | from rlkit.core.eval_util import create_stats_ordered_dict
11 | from rlkit.torch.torch_rl_algorithm import TorchTrainer
12 |
13 | from rlkit.util.io import load_local_or_remote_file
14 |
15 | import random
16 | from rlkit.torch.core import np_to_pytorch_batch
17 | from rlkit.data_management.path_builder import PathBuilder
18 |
19 | # import matplotlib
20 | # matplotlib.use('TkAgg')
21 | # import matplotlib.pyplot as plt
22 |
23 | from rlkit.core import logger
24 |
25 | import glob
26 |
27 | class PathLoader:
28 | """
29 | Loads demonstrations and/or off-policy data into a Trainer
30 | """
31 |
32 | def load_demos(self, ):
33 | pass
34 |
--------------------------------------------------------------------------------
/rlkit/policies/action_repeat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.policies.base import Policy
4 |
5 |
6 | class ActionRepeatPolicy(Policy):
7 | """
8 | General policy interface.
9 | """
10 | def __init__(self, policy, repeat_prob=.5):
11 | self._policy = policy
12 | self._repeat_prob = repeat_prob
13 | self._last_action = None
14 |
15 | def get_action(self, observation):
16 | """
17 |
18 | :param observation:
19 | :return: action, debug_dictionary
20 | """
21 | action = self._policy.get_action(observation)
22 | if (
23 | self._last_action is not None
24 | and np.random.uniform() <= self._repeat_prob
25 | ):
26 | action = self._last_action
27 | self._last_action = action
28 | return self._last_action
29 |
30 | def reset(self):
31 | self._last_action = None
--------------------------------------------------------------------------------
/rlkit/envs/assets/oned_point.xml:
--------------------------------------------------------------------------------
(XML markup was stripped when this dump was generated; the original MuJoCo asset is not reproduced here.)
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/point_env.py:
--------------------------------------------------------------------------------
1 | from gym import Env
2 | from gym.spaces import Box
3 |
4 | import numpy as np
5 |
6 |
7 | class PointEnv(Env):
8 | @property
9 | def observation_space(self):
10 | return Box(low=-np.inf, high=np.inf, shape=(2,))
11 |
12 | @property
13 | def action_space(self):
14 | return Box(low=-0.1, high=0.1, shape=(2,))
15 |
16 | def reset(self):
17 | self._state = np.random.uniform(-1, 1, size=(2,))
18 | observation = np.copy(self._state)
19 | return observation
20 |
21 | def step(self, action):
22 | self._state = self._state + action
23 | x, y = self._state
24 | reward = - (x ** 2 + y ** 2) ** 0.5
25 | done = abs(x) < 0.01 and abs(y) < 0.01
26 | next_observation = np.copy(self._state)
27 |         return next_observation, reward, done, {}
28 |
29 | def render(self, **kwargs):
30 | print('current state:', self._state)
31 |
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class DataCollector(object, metaclass=abc.ABCMeta):
5 | def end_epoch(self, epoch):
6 | pass
7 |
8 | def get_diagnostics(self):
9 | return {}
10 |
11 | def get_snapshot(self):
12 | return {}
13 |
14 | @abc.abstractmethod
15 | def get_epoch_paths(self):
16 | pass
17 |
18 |
19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta):
20 | @abc.abstractmethod
21 | def collect_new_paths(
22 | self,
23 | max_path_length,
24 | num_steps,
25 | discard_incomplete_paths,
26 | ):
27 | pass
28 |
29 |
30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta):
31 | @abc.abstractmethod
32 | def collect_new_steps(
33 | self,
34 | max_path_length,
35 | num_steps,
36 | discard_incomplete_paths,
37 | ):
38 | pass
39 |
--------------------------------------------------------------------------------
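To illustrate the PathCollector contract, here is a deliberately bare-bones collector. It assumes the old 4-tuple gym step API used throughout this repo and a policy whose get_action returns (action, info), and it ignores discard_incomplete_paths; it is a sketch, not the repo's MdpPathCollector.

from rlkit.samplers.data_collector.base import PathCollector


class SimplePathCollector(PathCollector):
    def __init__(self, env, policy):
        self._env = env
        self._policy = policy
        self._epoch_paths = []

    def collect_new_paths(self, max_path_length, num_steps,
                          discard_incomplete_paths):
        paths, steps = [], 0
        while steps < num_steps:
            obs = self._env.reset()
            path = dict(observations=[], actions=[], rewards=[])
            for _ in range(max_path_length):
                action, _ = self._policy.get_action(obs)
                next_obs, reward, done, _ = self._env.step(action)
                path["observations"].append(obs)
                path["actions"].append(action)
                path["rewards"].append(reward)
                obs = next_obs
                steps += 1
                if done:
                    break
            paths.append(path)
        self._epoch_paths.extend(paths)
        return paths

    def get_epoch_paths(self):
        return self._epoch_paths

    def end_epoch(self, epoch):
        self._epoch_paths = []
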
/rlkit/envs/wrappers/discretize_env.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import numpy as np
4 | from gym import Env
5 | from gym.spaces import Discrete
6 |
7 | from rlkit.envs.proxy_env import ProxyEnv
8 |
9 |
10 | class DiscretizeEnv(ProxyEnv, Env):
11 | def __init__(self, wrapped_env, num_bins):
12 | super().__init__(wrapped_env)
13 | low = self.wrapped_env.action_space.low
14 | high = self.wrapped_env.action_space.high
15 | action_ranges = [
16 | np.linspace(low[i], high[i], num_bins)
17 | for i in range(len(low))
18 | ]
19 | self.idx_to_continuous_action = [
20 | np.array(x) for x in itertools.product(*action_ranges)
21 | ]
22 | self.action_space = Discrete(len(self.idx_to_continuous_action))
23 |
24 | def step(self, action):
25 | continuous_action = self.idx_to_continuous_action[action]
26 | return super().step(continuous_action)
27 |
28 |
29 |
--------------------------------------------------------------------------------
/rlkit/torch/transforms/_pil_constants.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | from PIL import Image
3 |
4 | # See https://pillow.readthedocs.io/en/stable/releasenotes/9.1.0.html#deprecations
5 | # TODO: Remove this file once PIL minimal version is >= 9.1
6 |
7 | if tuple(int(part) for part in PIL.__version__.split(".")) >= (9, 1):
8 | BICUBIC = Image.Resampling.BICUBIC
9 | BILINEAR = Image.Resampling.BILINEAR
10 | LINEAR = Image.Resampling.BILINEAR
11 | NEAREST = Image.Resampling.NEAREST
12 |
13 | AFFINE = Image.Transform.AFFINE
14 | FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT
15 | FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM
16 | PERSPECTIVE = Image.Transform.PERSPECTIVE
17 | else:
18 | BICUBIC = Image.BICUBIC
19 | BILINEAR = Image.BILINEAR
20 | NEAREST = Image.NEAREST
21 | LINEAR = Image.LINEAR
22 |
23 | AFFINE = Image.AFFINE
24 | FLIP_LEFT_RIGHT = Image.FLIP_LEFT_RIGHT
25 | FLIP_TOP_BOTTOM = Image.FLIP_TOP_BOTTOM
26 | PERSPECTIVE = Image.PERSPECTIVE
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | """Imports wrappers
2 |
3 | TODO: DEPRECATE. This pattern imports unnecessary modules, creating unnecessary
4 | software dependencies. Just import the specific module instead.
5 | """
6 |
7 | from rlkit.envs.wrappers.discretize_env import DiscretizeEnv
8 | from rlkit.envs.wrappers.history_env import HistoryEnv
9 | # from rlkit.envs.wrappers.image_mujoco_env import ImageMujocoEnv
10 | # from rlkit.envs.wrappers.image_mujoco_env_with_obs import ImageMujocoWithObsEnv
11 | from rlkit.envs.wrappers.normalized_box_env import NormalizedBoxEnv
12 | from rlkit.envs.proxy_env import ProxyEnv
13 | from rlkit.envs.wrappers.reward_wrapper_env import RewardWrapperEnv
14 | from rlkit.envs.wrappers.stack_observation_env import StackObservationEnv
15 |
16 |
17 | __all__ = [
18 | 'DiscretizeEnv',
19 | 'HistoryEnv',
20 | # 'ImageMujocoEnv',
21 | # 'ImageMujocoWithObsEnv',
22 | 'NormalizedBoxEnv',
23 | 'ProxyEnv',
24 | 'RewardWrapperEnv',
25 | 'StackObservationEnv',
26 | ]
27 |
--------------------------------------------------------------------------------
/rlkit/torch/her/her.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from rlkit.torch.torch_rl_algorithm import TorchTrainer
4 |
5 |
6 | class HERTrainer(TorchTrainer):
7 | def __init__(self, base_trainer: TorchTrainer):
8 | super().__init__()
9 | self._base_trainer = base_trainer
10 |
11 | def train_from_torch(self, batch):
12 | obs = batch['observations']
13 | next_obs = batch['next_observations']
14 | goals = batch['resampled_goals']
15 | batch['observations'] = torch.cat((obs, goals), dim=1)
16 | batch['next_observations'] = torch.cat((next_obs, goals), dim=1)
17 | self._base_trainer.train_from_torch(batch)
18 |
19 | def get_diagnostics(self):
20 | return self._base_trainer.get_diagnostics()
21 |
22 | def end_epoch(self, epoch):
23 | self._base_trainer.end_epoch(epoch)
24 |
25 | @property
26 | def networks(self):
27 | return self._base_trainer.networks
28 |
29 | def get_snapshot(self):
30 | return self._base_trainer.get_snapshot()
31 |
--------------------------------------------------------------------------------
/rlkit/envs/gridcraft/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def flat_to_one_hot(val, ndim):
4 | """
5 |
6 | >>> flat_to_one_hot(2, ndim=4)
7 | array([ 0., 0., 1., 0.])
8 | >>> flat_to_one_hot(4, ndim=5)
9 | array([ 0., 0., 0., 0., 1.])
10 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
11 | array([[ 0., 0., 1., 0., 0.],
12 | [ 0., 0., 0., 0., 1.],
13 | [ 0., 0., 0., 1., 0.]])
14 | """
15 |     shape = np.array(val).shape
16 | v = np.zeros(shape + (ndim,))
17 | if len(shape) == 1:
18 | v[np.arange(shape[0]), val] = 1.0
19 | else:
20 | v[val] = 1.0
21 | return v
22 |
23 | def one_hot_to_flat(val):
24 | """
25 | >>> one_hot_to_flat(np.array([0,0,0,0,1]))
26 | 4
27 | >>> one_hot_to_flat(np.array([0,0,1,0]))
28 | 2
29 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
30 | array([2, 0, 1])
31 | """
32 | idxs = np.array(np.where(val == 1.0))[-1]
33 | if len(val.shape) == 1:
34 | return int(idxs)
35 | return idxs
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/twod_point_random_init.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.envs.mujoco.mujoco_env import MujocoEnv
4 |
5 | TARGET = np.array([0.2, 0])
6 |
7 |
8 | class TwoDPointRandomInit(MujocoEnv):
9 | def __init__(self):
10 | self.init_serialization(locals())
11 | super().__init__('twod_point.xml')
12 |
13 | def _step(self, a):
14 | self.do_simulation(a, self.frame_skip)
15 | ob = self._get_obs()
16 | pos = ob[0:2]
17 | dist = np.linalg.norm(pos - TARGET)
18 | reward = - (dist + 1e-2*np.linalg.norm(a))
19 | done = False
20 | return ob, reward, done, {}
21 |
22 | def reset_model(self):
23 | qpos = self.np_random.uniform(size=self.model.nq, low=-0.25, high=0.25)
24 | qvel = self.np_random.uniform(size=self.model.nv, low=-0.25, high=0.25)
25 | self.set_state(qpos, qvel)
26 | return self._get_obs()
27 |
28 | def _get_obs(self):
29 | return np.concatenate([self.model.data.qpos]).ravel()
30 |
31 | def viewer_setup(self):
32 | pass
33 |
--------------------------------------------------------------------------------
/rlkit/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym.spaces import Box, Discrete, Tuple
4 |
5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
6 |
7 |
8 | def get_asset_full_path(file_name):
9 | return os.path.join(ENV_ASSET_DIR, file_name)
10 |
11 |
12 | def get_dim(space):
13 | if isinstance(space, Box):
14 | return space.low.size
15 | elif isinstance(space, Discrete):
16 | return space.n
17 | elif isinstance(space, Tuple):
18 | return sum(get_dim(subspace) for subspace in space.spaces)
19 | elif hasattr(space, 'flat_dim'):
20 | return space.flat_dim
21 | else:
22 | raise TypeError("Unknown space: {}".format(space))
23 |
24 |
25 | def gym_env(name):
26 | from rllab.envs.gym_env import GymEnv
27 | return GymEnv(name,
28 | record_video=False,
29 | log_dir='/tmp/gym-test', # Ignore gym log.
30 | record_log=False)
31 |
32 | def mode(env, mode_type):
33 | try:
34 | getattr(env, mode_type)()
35 | except AttributeError:
36 | pass
37 |
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/modular_env.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Callable, Any, Dict, List
3 |
4 | import gym.spaces
5 |
6 | Path = Dict
7 | Diagnostics = Dict
8 | Context = Any
9 | ContextualDiagnosticsFn = Callable[
10 | [List[Path], List[Context]],
11 | Diagnostics,
12 | ]
13 |
14 |
15 | class RewardFn(object, metaclass=abc.ABCMeta):
16 | """Some reward function"""
17 | @abc.abstractmethod
18 | def __call__(
19 | self,
20 | action,
21 | next_state: dict,
22 | ):
23 | pass
24 |
25 |
26 | class ModularEnv(gym.Wrapper):
27 | """An env where you can separately specify the reward and transition."""
28 | def __init__(
29 | self,
30 | env: gym.Env,
31 | reward_fn: RewardFn,
32 | ):
33 | super().__init__(env)
34 | self.reward_fn = reward_fn
35 |
36 | def step(self, action):
37 | obs, reward, done, info = super().step(action)
38 | if self.reward_fn:
39 | reward = self.reward_fn(action, obs)
40 | return obs, reward, done, info
41 |
--------------------------------------------------------------------------------
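Sketch of plugging a custom RewardFn into ModularEnv: the wrapped env keeps its dynamics while the reward is replaced. The goal-distance reward, the env id, and treating the observation as a flat array (despite the dict annotation on RewardFn) are all illustrative assumptions; ModularEnv itself relies on the older 4-tuple gym step API.

import gym
import numpy as np

from rlkit.envs.wrappers.modular_env import ModularEnv, RewardFn


class NegativeDistanceReward(RewardFn):
    """Sketch: reward is the negative Euclidean distance to a fixed goal."""

    def __init__(self, goal):
        self.goal = np.asarray(goal, dtype=np.float32)

    def __call__(self, action, next_state):
        return -float(np.linalg.norm(np.asarray(next_state) - self.goal))


env = ModularEnv(
    env=gym.make("MountainCarContinuous-v0"),
    reward_fn=NegativeDistanceReward(goal=[0.45, 0.0]),
)
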
/rlkit/envs/images/text_renderer.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from PIL import ImageDraw
3 | from PIL import ImageFont
4 | import numpy as np
5 |
6 |
7 | from rlkit.envs.images import Renderer
8 |
9 |
10 | class TextRenderer(Renderer):
11 | """I gave up! See plot_renderer.TextRenderer"""
12 |
13 | def __init__(self, text, *args,
14 | text_color='white',
15 | background_color='black',
16 | **kwargs):
17 | super().__init__(*args,
18 | create_image_format='HWC',
19 | **kwargs)
20 |
21 | font = ImageFont.truetype(
22 | '/usr/share/fonts/truetype/ubuntu-font-family/Ubuntu-R.ttf', 100)
23 | _, h, w = self.image_chw
24 | self._img = Image.new('RGB', (w, h), background_color)
25 | self._draw_interface = ImageDraw.Draw(self._img)
26 | self._draw_interface.text((0, 0), text, fill=text_color, font=font)
27 | self._np_img = np.array(self._img).copy()
28 |
29 | def _create_image(self, *args, **kwargs):
30 | return self._np_img
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Chongyi Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/rlkit/torch/networks/ae_tanh_policy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from rlkit.torch import pytorch_util as ptu
4 | from rlkit.torch.networks import MlpPolicy
5 |
6 |
7 | class AETanhPolicy(MlpPolicy):
8 | """
9 | A helper class since most policies have a tanh output activation.
10 | """
11 |
12 | def __init__(
13 | self,
14 | ae,
15 | env,
16 | history_length,
17 | *args,
18 | **kwargs
19 | ):
20 | super().__init__(*args, **kwargs, output_activation=torch.tanh)
21 | self.ae = ae
22 | self.history_length = history_length
23 | self.env = env
24 |
25 | def get_action(self, obs_np):
26 | obs = obs_np
27 | obs = ptu.from_numpy(obs)
28 | image_obs, fc_obs = self.env.split_obs(obs)
29 | latent_obs = self.ae.history_encoder(image_obs, self.history_length)
30 | if fc_obs is not None:
31 | latent_obs = torch.cat((latent_obs, fc_obs), dim=1)
32 | obs_np = ptu.get_numpy(latent_obs)[0]
33 | actions = self.get_actions(obs_np[None])
34 | return actions[0, :], {}
35 |
36 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/half_cheetah.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_
3 |
4 | class HalfCheetahEnv(HalfCheetahEnv_):
5 | def _get_obs(self):
6 | return np.concatenate([
7 | self.sim.data.qpos.flat[1:],
8 | self.sim.data.qvel.flat,
9 | self.get_body_com("torso").flat,
10 | ]).astype(np.float32).flatten()
11 |
12 | def viewer_setup(self):
13 | camera_id = self.model.camera_name2id('track')
14 | self.viewer.cam.type = 2
15 | self.viewer.cam.fixedcamid = camera_id
16 | self.viewer.cam.distance = self.model.stat.extent * 0.35
17 | # Hide the overlay
18 | self.viewer._hide_overlay = True
19 |
20 | def render(self, mode='human', width=500, height=500, **kwargs):
21 | if mode == 'rgb_array':
22 | self._get_viewer(mode).render(width=width, height=height)
23 | data = self._get_viewer(mode).read_pixels(width, height, depth=False)[::-1, :, :]
24 | return data
25 | elif mode == 'human':
26 | self._get_viewer(mode).render()
27 |
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/reacher_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from gym.envs.mujoco.reacher import ReacherEnv as GymReacherEnv
4 |
5 | from rlkit.core import logger as default_logger
6 | from rlkit.core.eval_util import get_stat_in_paths, create_stats_ordered_dict
7 |
8 |
9 | class ReacherEnv(GymReacherEnv):
10 | def log_diagnostics(self, paths, logger=default_logger):
11 | statistics = OrderedDict()
12 | for name_in_env_infos, name_to_log in [
13 | ('reward_dist', 'Distance Reward'),
14 | ('reward_ctrl', 'Action Reward'),
15 | ]:
16 | stat = get_stat_in_paths(paths, 'env_infos', name_in_env_infos)
17 | statistics.update(create_stats_ordered_dict(
18 | name_to_log,
19 | stat,
20 | ))
21 | distances = get_stat_in_paths(paths, 'env_infos', 'reward_dist')
22 | statistics.update(create_stats_ordered_dict(
23 | "Final Distance Reward",
24 | [ds[-1] for ds in distances],
25 | ))
26 | for key, value in statistics.items():
27 | logger.record_tabular(key, value)
28 |
29 |
30 |
--------------------------------------------------------------------------------
/rlkit/core/distribution.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Dict
3 | from gym import Space
4 |
5 |
6 | class DictDistribution(object, metaclass=abc.ABCMeta):
7 |
8 | @abc.abstractmethod
9 | def sample(self, batch_size: int):
10 | pass
11 |
12 | @property
13 | @abc.abstractmethod
14 | def spaces(self) -> Dict[str, Space]:
15 | pass
16 |
17 | def __call__(self, *args, **kwargs):
18 | """For backward compatibility with DictDistributionGenerator"""
19 | return self
20 |
21 |
22 | class DictDistributionGenerator(DictDistribution, metaclass=abc.ABCMeta):
23 |
24 | def __call__(self, *args, **kwargs) -> DictDistribution:
25 | raise NotImplementedError
26 |
27 |
28 | class DictDistributionClosure(DictDistributionGenerator):
29 | """Fills in args to a DictDistribution"""
30 |
31 | def __init__(self, clz, *args, **kwargs):
32 | self.clz = clz
33 | self.args = args
34 | self.kwargs = kwargs
35 |
36 | def __call__(self, **extra_kwargs) -> DictDistribution:
37 | self.kwargs.update(**extra_kwargs)
38 | return self.clz(
39 |             *self.args,
40 |             **self.kwargs,
41 | )
42 |
--------------------------------------------------------------------------------
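For illustration, a concrete DictDistribution that samples fixed-size goals uniformly from a Box; the "desired_goal" key and the dimensions are assumptions for the sketch, not repo conventions.

import numpy as np
from gym.spaces import Box

from rlkit.core.distribution import DictDistribution


class UniformGoalDistribution(DictDistribution):
    """Sketch: sample batches of 2-D goals uniformly from a Box."""

    def __init__(self, low=-1.0, high=1.0, dim=2):
        self._space = Box(low=low, high=high, shape=(dim,), dtype=np.float32)

    def sample(self, batch_size: int):
        goals = np.random.uniform(
            self._space.low, self._space.high,
            size=(batch_size,) + self._space.shape,
        )
        return {"desired_goal": goals}

    @property
    def spaces(self):
        return {"desired_goal": self._space}
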
/rlkit/envs/mujoco/twod_point.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.envs.mujoco.mujoco_env import MujocoEnv
4 |
5 | TARGET = np.array([0.2, 0])
6 |
7 |
8 | class TwoDPoint(MujocoEnv):
9 | def __init__(self):
10 | self.init_serialization(locals())
11 | super().__init__('twod_point.xml')
12 |
13 | def _step(self, a):
14 | self.do_simulation(a, self.frame_skip)
15 | ob = self._get_obs()
16 | pos = ob[0:2]
17 | dist = np.linalg.norm(pos - TARGET)
18 | reward = - (dist + 1e-2 * np.linalg.norm(a))
19 | done = False
20 | return ob, reward, done, {}
21 |
22 | def reset_model(self):
23 | qpos = self.init_qpos + np.random.uniform(size=self.model.nq, low=-0.01,
24 | high=0.01)
25 | qvel = self.init_qvel + np.random.uniform(size=self.model.nv, low=-0.01,
26 | high=0.01)
27 | self.set_state(qpos, qvel)
28 | return self._get_obs()
29 |
30 | def _get_obs(self):
31 | return np.concatenate([self.model.data.qpos]).ravel()
32 |
33 | def viewer_setup(self):
34 | pass
35 |
36 |
37 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/gaussian_strategy.py:
--------------------------------------------------------------------------------
1 | from rlkit.exploration_strategies.base import RawExplorationStrategy
2 | import numpy as np
3 |
4 |
5 | class GaussianStrategy(RawExplorationStrategy):
6 | """
7 | This strategy adds Gaussian noise to the action taken by the deterministic policy.
8 |
9 | Based on the rllab implementation.
10 | """
11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None,
12 | decay_period=1000000):
13 | assert len(action_space.shape) == 1
14 | self._max_sigma = max_sigma
15 | if min_sigma is None:
16 | min_sigma = max_sigma
17 | self._min_sigma = min_sigma
18 | self._decay_period = decay_period
19 | self._action_space = action_space
20 |
21 | def get_action_from_raw_action(self, action, t=None, **kwargs):
22 | sigma = (
23 | self._max_sigma - (self._max_sigma - self._min_sigma) *
24 | min(1.0, t * 1.0 / self._decay_period)
25 | )
26 | return np.clip(
27 | action + np.random.normal(size=len(action)) * sigma,
28 | self._action_space.low,
29 | self._action_space.high,
30 | )
31 |
--------------------------------------------------------------------------------
/rlkit/envs/time_limited_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from rlkit.envs.proxy_env import ProxyEnv
5 |
6 |
7 | class TimeLimitedEnv(ProxyEnv):
8 | def __init__(self, wrapped_env, horizon):
9 | self._wrapped_env = wrapped_env
10 | self._horizon = horizon
11 |
12 | self._observation_space = Box(
13 | np.hstack((self._wrapped_env.observation_space.low, [0])),
14 | np.hstack((self._wrapped_env.observation_space.high, [1])),
15 | )
16 | self._t = 0
17 |
18 | @property
19 | def observation_space(self):
20 | return self._observation_space
21 |
22 | @property
23 | def horizon(self):
24 | return self._horizon
25 |
26 | def step(self, action):
27 | obs, reward, done, info = self._wrapped_env.step(action)
28 | self._t += 1
29 | done = done or self._t == self.horizon
30 | new_obs = np.hstack((obs, float(self._t) / self.horizon))
31 | return new_obs, reward, done, info
32 |
33 | def reset(self, **kwargs):
34 | obs = self._wrapped_env.reset(**kwargs)
35 | self._t = 0
36 | new_obs = np.hstack((obs, self._t))
37 | return new_obs
38 |
--------------------------------------------------------------------------------
/rlkit/envs/gridcraft/custom_test.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.gridcraft import REW_ARENA_64
2 | from rlkit.envs.gridcraft.grid_env import GridEnv
3 | from rlkit.envs.gridcraft.grid_spec import *
4 | from rlkit.envs.gridcraft.mazes import MAZE_ANY_START1
5 | import gym.spaces.prng as prng
6 | import numpy as np
7 |
8 | if __name__ == "__main__":
9 | prng.seed(2)
10 |
11 | maze_spec = \
12 | spec_from_string("SOOOO#R#OO\\"+
13 | "OSOOO#2##O\\" +
14 | "###OO#3O#O\\" +
15 | "OOOOO#OO#O\\" +
16 | "OOOOOOOOOO\\"
17 | )
18 |
19 | #maze_spec = spec_from_sparse_locations(50, 50, {START: [(25,25)], REWARD: [(45,45)]})
20 | # maze_spec = REW_ARENA_64
21 | maze_spec = MAZE_ANY_START1
22 |
23 | env = GridEnv(maze_spec, one_hot=True, add_eyes=True, coordinate_wise=True)
24 |
25 | s = env.reset()
26 | #env.render()
27 |
28 | obses = []
29 | for t in range(10):
30 | a = env.action_space.sample()
31 | obs, r, done, infos = env.step(a, verbose=True)
32 | obses.append(obs)
33 | obses = np.array(obses)
34 |
35 | paths = [{'observations': obses}]
36 | env.plot_trajs(paths)
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/twod_maze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.envs.mujoco.mujoco_env import MujocoEnv
4 |
5 | INIT_POS = np.array([0.2,0.15])
6 | TARGET = np.array([0.2, -0.15]) + INIT_POS
7 | DIST_THRESH = 0.05
8 |
9 |
10 | class TwoDMaze(MujocoEnv):
11 | def __init__(self):
12 | self.init_serialization(locals())
13 | super().__init__('twod_maze.xml')
14 |
15 | def _step(self, a):
16 | self.do_simulation(a, self.frame_skip)
17 | ob = self._get_obs()
18 | pos = ob[0:2]
19 | dist = np.linalg.norm(pos - TARGET)
20 | dist_cost = 1 if dist>DIST_THRESH else dist/DIST_THRESH
21 | reward = - (dist_cost)# + 1e-2*np.linalg.norm(a))
22 | #print(reward, dist)
23 | done = False
24 | return ob, reward, done, {}
25 |
26 | def reset_model(self):
27 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
28 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
29 | self.set_state(qpos, qvel)
30 | return self._get_obs()
31 |
32 | def _get_obs(self):
33 | return np.concatenate([self.model.data.qpos]).ravel()
34 |
35 | def viewer_setup(self):
36 | pass
37 |
--------------------------------------------------------------------------------
/rlkit/envs/supervised_learning_env.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 |
5 | class RecurrentSupervisedLearningEnv(metaclass=abc.ABCMeta):
6 | """
7 | An environment that's really just a supervised learning task.
8 | """
9 |
10 | @abc.abstractmethod
11 | def get_batch(self, batch_size):
12 | """
13 |
14 | :param batch_size: Number of samples in the batch.
15 | :return: tuple (X, Y) where
16 | X is a numpy array of size (
17 | batch_size, self.sequence_length, self.feature_dim
18 | )
19 | Y is a numpy array of size (
20 | batch_size, self.sequence_length, self.target_dim
21 | )
22 | """
23 | pass
24 |
25 | @property
26 | @abc.abstractmethod
27 | def feature_dim(self):
28 | """
29 | :return: Integer. Dimension of the features.
30 | """
31 | pass
32 |
33 | @property
34 | @abc.abstractmethod
35 | def target_dim(self):
36 | """
37 | :return: Integer. Dimension of the target.
38 | """
39 | pass
40 |
41 | @property
42 | @abc.abstractmethod
43 | def sequence_length(self):
44 | """
45 | :return: Integer. Length of each sequence.
46 | """
47 | pass
48 |
--------------------------------------------------------------------------------
/rlkit/core/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | LOG_DIR = os.getcwd()
3 |
4 |
5 | class Wrapper(object):
6 | """
7 | Mixin for deferring attributes to a wrapped, inner object.
8 | """
9 |
10 | def __init__(self, inner):
11 | self.inner = inner
12 |
13 | def __getattr__(self, attr):
14 | """
15 | Dispatch attributes by their status as magic, members, or missing.
16 | - magic is handled by the standard getattr
17 | - existing attributes are returned
18 | - missing attributes are deferred to the inner object.
19 | """
20 | # don't make magic any more magical
21 | is_magic = attr.startswith('__') and attr.endswith('__')
22 | if is_magic:
23 | raise AttributeError(attr)
24 | try:
25 | # try to return the attribute...
26 | return self.__dict__[attr]
27 | except KeyError:
28 | # ...and defer to the inner object if it's not here
29 | return getattr(self.inner, attr)
30 |
31 |
32 | class SimpleWrapper(object):
33 | """
34 | Mixin for deferring attributes to a wrapped, inner object.
35 | """
36 |
37 | def __init__(self, inner):
38 | self._inner = inner
39 |
40 | def __getattr__(self, attr):
41 | if attr == '_inner':
42 | raise AttributeError()
43 | return getattr(self._inner, attr)
44 |
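A minimal sketch of the attribute deferral that `SimpleWrapper` provides; the `Inner` and `Outer` classes below are made up for illustration and assume the module is importable as `rlkit.core.util`.
```
from rlkit.core.util import SimpleWrapper


class Inner:
    # toy inner object with an attribute and a method
    def __init__(self):
        self.value = 42

    def greet(self):
        return 'hello'


class Outer(SimpleWrapper):
    # adds nothing of its own; lookups fall through to the inner object
    pass


outer = Outer(Inner())
print(outer.value)    # 42: not found on Outer, so __getattr__ defers to inner
print(outer.greet())  # 'hello'
```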
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/oned_point.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 |
5 | from rlkit.envs.env_utils import get_asset_full_path
6 |
7 |
8 | class OneDPoint(mujoco_env.MujocoEnv, utils.EzPickle):
9 | def __init__(self):
10 | utils.EzPickle.__init__(self)
11 | mujoco_env.MujocoEnv.__init__(self, get_asset_full_path('oned_point.xml'), 2)
12 |
13 | def _step(self, a):
14 | self.do_simulation(a, self.frame_skip)
15 | ob = self._get_obs()
16 | pos = ob[0]
17 | reward = 1 if pos > 0.5 else 0 #pos #(pos)**2 / 10.
18 | #reward = a[0]
19 | #notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= .2)
20 | #done = not notdone
21 | done = False
22 | return ob, reward, done, {}
23 |
24 | def reset_model(self):
25 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
26 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
27 | self.set_state(qpos, qvel)
28 | return self._get_obs()
29 |
30 | def _get_obs(self):
31 | return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel()
32 |
33 | def viewer_setup(self):
34 | v = self.viewer
35 | v.cam.trackbodyid=0
36 | v.cam.distance = v.model.stat.extent
37 |
--------------------------------------------------------------------------------
/rlkit/utils/timer.py:
--------------------------------------------------------------------------------
1 | """TensorFlow Task Generators API."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import time
8 |
9 |
10 | class Timer(object):
11 |
12 | def __init__(self, keys):
13 | assert isinstance(keys, list)
14 | self._keys = keys
15 | # timing accumulators are created in reset()
16 |
17 | self.reset()
18 |
19 | def reset(self):
20 | self._start_time = {
21 | key: None for key in self._keys
22 | }
23 | self._time_acc = {
24 | key: 0.0 for key in self._keys
25 | }
26 | self._count = {
27 | key: 0 for key in self._keys
28 | }
29 |
30 | @property
31 | def time_acc(self):
32 | return self._time_acc
33 |
34 | def tic(self, key):
35 | assert key in self._keys
36 | self._start_time[key] = time.time()
37 |
38 | def toc(self, key):
39 | assert self._start_time[key] is not None
40 | self._time_acc[key] += time.time() - self._start_time[key]
41 | self._count[key] += 1
42 | self._start_time[key] = None
43 |
44 | def accumulated_time(self, key):
45 | return self._time_acc[key]
46 |
47 | def average_time(self, key):
48 | assert self._count[key] > 0
49 | return self._time_acc[key] / self._count[key]
50 |
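A short usage sketch for the keyed `Timer` above; the key names and the `time.sleep` stand-in are illustrative, and the import path assumes the package layout shown in this dump.
```
import time

from rlkit.utils.timer import Timer

timer = Timer(keys=['forward', 'backward'])

for _ in range(2):
    timer.tic('forward')
    time.sleep(0.01)  # stand-in for real work
    timer.toc('forward')

print(timer.accumulated_time('forward'))  # total seconds over both tic/toc pairs
print(timer.average_time('forward'))      # accumulated time divided by call count
```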
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/flat_to_dict.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from gym.spaces import Dict
4 |
5 | from rlkit.core.util import SimpleWrapper
6 | from rlkit.policies.base import Policy
7 |
8 |
9 | class FlatToDictEnv(gym.Wrapper):
10 | """Wrap an environment that returns a flat obs to return a dict."""
11 | def __init__(self, env, observation_key):
12 | super().__init__(env)
13 | self.observation_key = observation_key
14 | new_ob_space = {
15 | self.observation_key: self.observation_space
16 | }
17 | self.observation_space = Dict(new_ob_space)
18 |
19 | def step(self, action):
20 | obs, reward, done, info = super().step(action)
21 | return {self.observation_key: obs}, reward, done, info
22 |
23 | def reset(self):
24 | obs = super().reset()
25 | return {self.observation_key: obs}
26 |
27 |
28 | class FlatToDictPolicy(SimpleWrapper, Policy):
29 | """Wrap a policy that expect a flat obs so expects a dict obs."""
30 |
31 | def __init__(self, policy, observation_key):
32 | super().__init__(policy)
33 | self.policy = policy
34 | self.observation_key = observation_key
35 |
36 | def get_action(self, observation, *args, **kwargs):
37 | flat_ob = observation[self.observation_key]
38 | return self.policy.get_action(flat_ob, *args, **kwargs)
39 |
--------------------------------------------------------------------------------
/rlkit/demos/her_bc.py:
--------------------------------------------------------------------------------
1 | from rlkit.data_management.obs_dict_replay_buffer import \
2 | ObsDictRelabelingBuffer
3 | from rlkit.torch.her.her import HER
4 | from rlkit.demos.behavior_clone import BehaviorClone
5 |
6 |
7 | class HerBC(HER, BehaviorClone):
8 | def __init__(
9 | self,
10 | *args,
11 | td3_kwargs,
12 | her_kwargs,
13 | base_kwargs,
14 | **kwargs
15 | ):
16 | HER.__init__(
17 | self,
18 | **her_kwargs,
19 | )
20 | BehaviorClone.__init__(self, *args, **kwargs, **td3_kwargs, **base_kwargs)
21 | assert isinstance(
22 | self.replay_buffer, ObsDictRelabelingBuffer
23 | )
24 |
25 | def _handle_rollout_ending(self):
26 | """Don't add anything to rollout buffer"""
27 | self._n_rollouts_total += 1
28 | # if len(self._current_path_builder) > 0:
29 | # path = self._current_path_builder.get_all_stacked()
30 | # self.replay_buffer.add_path(path)
31 | # self._exploration_paths.append(path)
32 | # self._current_path_builder = PathBuilder()
33 |
34 | def _handle_path(self, path):
35 | """Don't add anything to rollout buffer"""
36 | self._n_rollouts_total += 1
37 | # self.replay_buffer.add_path(path)
38 | self._exploration_paths.append(path)
39 |
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/hopper_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from gym.envs.mujoco import HopperEnv as GymHopperEnv
4 |
5 | from rlkit.core import logger as default_logger
6 | from rlkit.core.eval_util import get_stat_in_paths, create_stats_ordered_dict
7 |
8 |
9 | class HopperEnv(GymHopperEnv):
10 | def _step(self, a):
11 | ob, reward, done, _ = super()._step(a)
12 | posafter, height, ang = self.model.data.qpos[0:3, 0]
13 | return ob, reward, done, {
14 | 'posafter': posafter,
15 | 'height': height,
16 | 'angle': ang,
17 | }
18 |
19 | def log_diagnostics(self, paths, logger=default_logger):
20 | statistics = OrderedDict()
21 | for name_in_env_infos, name_to_log in [
22 | ('posafter', 'Position'),
23 | ('height', 'Height'),
24 | ('angle', 'Angle'),
25 | ]:
26 | stats = get_stat_in_paths(paths, 'env_infos', name_in_env_infos)
27 | statistics.update(create_stats_ordered_dict(
28 | name_to_log,
29 | stats,
30 | ))
31 | statistics.update(create_stats_ordered_dict(
32 | "Final " + name_to_log,
33 | [s[-1] for s in stats],
34 | ))
35 | for key, value in statistics.items():
36 | logger.record_tabular(key, value)
37 |
38 |
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/stack_observation_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from rlkit.envs.proxy_env import ProxyEnv
5 |
6 |
7 | class StackObservationEnv(ProxyEnv):
8 | """
9 | Env wrapper for passing history of observations as the new observation
10 | """
11 |
12 | def __init__(
13 | self,
14 | env,
15 | stack_obs=1,
16 | ):
17 | ProxyEnv.__init__(self, env)
18 | self.stack_obs = stack_obs
19 | low = env.observation_space.low
20 | high = env.observation_space.high
21 | self.obs_dim = low.size
22 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim))
23 | self.observation_space = Box(
24 | low=np.repeat(low, stack_obs),
25 | high=np.repeat(high, stack_obs),
26 | )
27 |
28 | def reset(self):
29 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim))
30 | next_obs = self._wrapped_env.reset()
31 | self._last_obs[-1, :] = next_obs
32 | return self._last_obs.copy().flatten()
33 |
34 | def step(self, action):
35 | next_obs, reward, done, info = self._wrapped_env.step(action)
36 | self._last_obs = np.vstack((
37 | self._last_obs[1:, :],
38 | next_obs
39 | ))
40 | return self._last_obs.copy().flatten(), reward, done, info
41 |
42 |
43 |
--------------------------------------------------------------------------------
/rlkit/samplers/in_place.py:
--------------------------------------------------------------------------------
1 | from rlkit.samplers.util import rollout
2 |
3 |
4 | class InPlacePathSampler(object):
5 | """
6 | A sampler that does not use serialization for sampling. Instead, it just uses
7 | the current policy and environment as-is.
8 |
9 | WARNING: This will affect the environment! So
10 | ```
11 | sampler = InPlacePathSampler(env, ...)
12 | sampler.obtain_samples()  # this has side-effects: env will change!
13 | ```
14 | """
15 | def __init__(self, env, policy, max_samples, max_path_length, render=False):
16 | self.env = env
17 | self.policy = policy
18 | self.max_path_length = max_path_length
19 | self.max_samples = max_samples
20 | self.render = render
21 | assert max_samples >= max_path_length, "Need max_samples >= max_path_length"
22 |
23 | def start_worker(self):
24 | pass
25 |
26 | def shutdown_worker(self):
27 | pass
28 |
29 | def obtain_samples(self):
30 | paths = []
31 | n_steps_total = 0
32 | while n_steps_total + self.max_path_length <= self.max_samples:
33 | path = rollout(
34 | self.env, self.policy, max_path_length=self.max_path_length,
35 | animated=self.render
36 | )
37 | paths.append(path)
38 | n_steps_total += len(path['observations'])
39 | return paths
40 |
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/gaussian_and_epsilon.py:
--------------------------------------------------------------------------------
1 | import random
2 | from rlkit.exploration_strategies.base import RawExplorationStrategy
3 | import numpy as np
4 |
5 |
6 | class GaussianAndEpsilonStrategy(RawExplorationStrategy):
7 | """
8 | With probability epsilon, take a completely random action.
9 | With probability 1-epsilon, add Gaussian noise to the action taken by a
10 | deterministic policy.
11 | """
12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None,
13 | decay_period=1000000):
14 | assert len(action_space.shape) == 1
15 | if min_sigma is None:
16 | min_sigma = max_sigma
17 | self._max_sigma = max_sigma
18 | self._epsilon = epsilon
19 | self._min_sigma = min_sigma
20 | self._decay_period = decay_period
21 | self._action_space = action_space
22 |
23 | def get_action_from_raw_action(self, action, t=None, **kwargs):
24 | if random.random() < self._epsilon:
25 | return self._action_space.sample()
26 | else:
27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period)
28 | return np.clip(
29 | action + np.random.normal(size=len(action)) * sigma,
30 | self._action_space.low,
31 | self._action_space.high,
32 | )
33 |
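A small sketch of how the strategy above might wrap a raw policy action; the `Box` bounds, noise schedule, and `raw_action` are illustrative values, not taken from the source.
```
import numpy as np
from gym.spaces import Box

from rlkit.exploration_strategies.gaussian_and_epsilon import (
    GaussianAndEpsilonStrategy,
)

action_space = Box(low=-np.ones(2), high=np.ones(2))
strategy = GaussianAndEpsilonStrategy(
    action_space,
    epsilon=0.1,       # 10% of actions are sampled uniformly from the space
    max_sigma=0.5,     # Gaussian noise scale at t = 0
    min_sigma=0.05,    # noise scale once t reaches decay_period
    decay_period=10000,
)

raw_action = np.zeros(2)  # e.g. the output of a deterministic policy
noisy_action = strategy.get_action_from_raw_action(raw_action, t=0)
```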
--------------------------------------------------------------------------------
/rlkit/core/ray_csv_logger.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import os
3 |
4 | from ray.tune.logger import CSVLogger
5 |
6 |
7 | class SequentialCSVLogger(CSVLogger):
8 | """CSVLogger to be used with SequentialRayExperiment
9 |
10 | On receiving a result with a new 'log_fname', a new csv progress file will be
11 | used.
12 | """
13 | def _init(self):
14 | self.created_logger = False
15 | # We need to create an initial self._file to make Ray happy...
16 | self.setup_new_logger('temp_progress.csv')
17 |
18 | def setup_new_logger(self, csv_fname):
19 | self.csv_fname = csv_fname
20 | if self.created_logger:
21 | self._file.close()
22 | self.created_logger = True
23 | progress_file = os.path.join(self.logdir, csv_fname)
24 | self._continuing = os.path.exists(progress_file)
25 | self._file = open(progress_file, "a")
26 | self._csv_out = None
27 |
28 | def on_result(self, result):
29 | if 'log_fname' in result and result['log_fname'] != self.csv_fname:
30 | self.setup_new_logger(result['log_fname'])
31 | """
32 | Sort the keys to enforce deterministic ordering during resuming.
33 | """
34 | sorted_result = OrderedDict()
35 | for key in sorted(result.keys()):
36 | sorted_result[key] = result[key]
37 | super().on_result(sorted_result)
38 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/ant_multitask_base.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.pearl_envs.ant import AntEnv
2 |
3 |
4 | class MultitaskAntEnv(AntEnv):
5 | def __init__(self, task={}, n_tasks=2, **kwargs):
6 | self._task = task
7 | self.tasks = self.sample_tasks(n_tasks)
8 | self._goal = self.tasks[0]['goal']
9 | super(MultitaskAntEnv, self).__init__(**kwargs)
10 |
11 | """
12 | def step(self, action):
13 | xposbefore = self.sim.data.qpos[0]
14 | self.do_simulation(action, self.frame_skip)
15 | xposafter = self.sim.data.qpos[0]
16 |
17 | forward_vel = (xposafter - xposbefore) / self.dt
18 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel)
19 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action))
20 |
21 | observation = self._get_obs()
22 | reward = forward_reward - ctrl_cost
23 | done = False
24 | infos = dict(reward_forward=forward_reward,
25 | reward_ctrl=-ctrl_cost, task=self._task)
26 | return (observation, reward, done, infos)
27 | """
28 |
29 |
30 | def get_all_task_idx(self):
31 | return range(len(self.tasks))
32 |
33 | def reset_task(self, idx):
34 | try:
35 | self._task = self.tasks[idx]
36 | except IndexError:
37 | raise IndexError('Task index {} is out of range.'.format(idx))
38 | self._goal = self._task['goal'] # assume parameterization of task by single vector
39 | self.reset()
40 |
--------------------------------------------------------------------------------
/rlkit/data_management/path_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class PathBuilder(dict):
5 | """
6 | Usage:
7 | ```
8 | path_builder = PathBuilder()
9 | path_builder.add_all(
10 | observations=1,
11 | actions=2,
12 | next_observations=3,
13 | ...
14 | )
15 | path_builder.add_all(
16 | observations=4,
17 | actions=5,
18 | next_observations=6,
19 | ...
20 | )
21 |
22 | path = path_builder.get_all_stacked()
23 |
24 | path['observations']
25 | # output: [1, 4]
26 | path['actions']
27 | # output: [2, 5]
28 | ```
29 |
30 | Note that the key should be "actions" and not "action" since the
31 | resulting dictionary will have those keys.
32 | """
33 |
34 | def __init__(self):
35 | super().__init__()
36 | self._path_length = 0
37 |
38 | def add_all(self, **key_to_value):
39 | for k, v in key_to_value.items():
40 | if k not in self:
41 | self[k] = [v]
42 | else:
43 | self[k].append(v)
44 | self._path_length += 1
45 |
46 | def get_all_stacked(self):
47 | output_dict = dict()
48 | for k, v in self.items():
49 | output_dict[k] = stack_list(v)
50 | return output_dict
51 |
52 | def __len__(self):
53 | return self._path_length
54 |
55 |
56 | def stack_list(lst):
57 | if isinstance(lst[0], dict):
58 | return lst
59 | else:
60 | return np.array(lst)
61 |
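A runnable version of the docstring example above, using the actual `add_all` API; the observation and action values are arbitrary.
```
import numpy as np

from rlkit.data_management.path_builder import PathBuilder

path_builder = PathBuilder()
path_builder.add_all(
    observations=np.array([0.0, 0.0]),
    actions=np.array([1.0]),
    rewards=0.0,
    terminals=False,
)
path_builder.add_all(
    observations=np.array([0.1, 0.0]),
    actions=np.array([-1.0]),
    rewards=1.0,
    terminals=True,
)

path = path_builder.get_all_stacked()
print(path['observations'].shape)  # (2, 2): each key is stacked into an array
print(len(path_builder))           # 2 samples were added
```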
--------------------------------------------------------------------------------
/rlkit/utils/path_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | global paths
4 |
5 |
6 | class PathBuilder(dict):
7 | """
8 | Usage:
9 | ```
10 | path_builder = PathBuilder()
11 | path_builder.add_all(
12 | observations=1,
13 | actions=2,
14 | next_observations=3,
15 | ...
16 | )
17 | path_builder.add_all(
18 | observations=4,
19 | actions=5,
20 | next_observations=6,
21 | ...
22 | )
23 |
24 | path = path_builder.get_all_stacked()
25 |
26 | path['observations']
27 | # output: [1, 4]
28 | path['actions']
29 | # output: [2, 5]
30 | ```
31 |
32 | Note that the key should be "actions" and not "action" since the
33 | resulting dictionary will have those keys.
34 | """
35 |
36 | def __init__(self):
37 | super().__init__()
38 | self._path_length = 0
39 |
40 | def add_all(self, **key_to_value):
41 | for k, v in key_to_value.items():
42 | if k not in self:
43 | self[k] = [v]
44 | else:
45 | self[k].append(v)
46 | self._path_length += 1
47 |
48 | def get_all_stacked(self):
49 | output_dict = dict()
50 | for k, v in self.items():
51 | output_dict[k] = stack_list(v)
52 | return output_dict
53 |
54 | def __len__(self):
55 | return self._path_length
56 |
57 |
58 | def stack_list(lst):
59 | if isinstance(lst[0], dict):
60 | return lst
61 | else:
62 | return np.array(lst)
63 |
--------------------------------------------------------------------------------
/rlkit/demos/source/hand_demo_source.py:
--------------------------------------------------------------------------------
1 | from rlkit.demos.source.demo_source import DemoSource
2 | import pickle
3 |
4 | from rlkit.data_management.path_builder import PathBuilder
5 |
6 | from rlkit.util.io import load_local_or_remote_file
7 |
8 | class HandDemoSource(DemoSource):
9 | def __init__(self, filename):
10 | self.data = load_local_or_remote_file(filename)
11 |
12 | def load_paths(self):
13 | paths = []
14 | for i in range(len(self.data)):
15 | p = self.data[i]
16 | H = len(p["observations"]) - 1
17 |
18 | path_builder = PathBuilder()
19 |
20 | for t in range(H):
21 | # read the t-th transition out of the i-th demonstration
22 |
23 | ob = p["observations"][t, :]
24 | action = p["actions"][t, :]
25 | reward = p["rewards"][t]
26 | next_ob = p["observations"][t + 1, :]
27 | terminal = 0
28 | agent_info = {} # todo (need to unwrap each key)
29 | env_info = {} # todo (need to unwrap each key)
30 |
31 | path_builder.add_all(
32 | observations=ob,
33 | actions=action,
34 | rewards=reward,
35 | next_observations=next_ob,
36 | terminals=terminal,
37 | agent_infos=agent_info,
38 | env_infos=env_info,
39 | )
40 |
41 | path = path_builder.get_all_stacked()
42 | paths.append(path)
43 | return paths
44 |
--------------------------------------------------------------------------------
/rlkit/envs/simple/point.py:
--------------------------------------------------------------------------------
1 | import os
2 | from gym import spaces
3 | import numpy as np
4 | import gym
5 |
6 |
7 | class Point(gym.Env):
8 | """Superclass for all MuJoCo environments.
9 | """
10 |
11 | def __init__(self, n=2, action_scale=0.2, fixed_goal=None):
12 | self.fixed_goal = fixed_goal
13 | self.n = n
14 | self.action_scale = action_scale
15 | self.goal = np.zeros((n,))
16 | self.state = np.zeros((n,))
17 |
18 | @property
19 | def action_space(self):
20 | return spaces.Box(
21 | low=-1*np.ones((self.n,)),
22 | high=1*np.ones((self.n,))
23 | )
24 |
25 | @property
26 | def observation_space(self):
27 | return spaces.Box(
28 | low=-5*np.ones((2 * self.n,)),
29 | high=5*np.ones((2 * self.n,))
30 | )
31 |
32 | def reset(self):
33 | self.state = np.zeros((self.n,))
34 | if self.fixed_goal is None:
35 | self.goal = np.random.uniform(-5, 5, size=(self.n,))
36 | else:
37 | self.goal = np.array(self.fixed_goal)
38 | return self._get_obs()
39 |
40 | def step(self, action):
41 | action = np.clip(action, -1, 1) * self.action_scale
42 | new_state = self.state + action
43 | new_state = np.clip(new_state, -5, 5)
44 | self.state = new_state
45 | reward = -np.linalg.norm(new_state - self.goal)
46 |
47 | return self._get_obs(), reward, False, {}
48 |
49 | def _get_obs(self):
50 | return np.concatenate([self.state, self.goal])
51 |
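A short random-rollout sketch for the `Point` environment above; the fixed goal and five-step horizon are arbitrary choices for illustration.
```
import numpy as np

from rlkit.envs.simple.point import Point

env = Point(n=2, action_scale=0.2, fixed_goal=[1.0, 1.0])
obs = env.reset()  # concatenation of state and goal, shape (4,)
for _ in range(5):
    action = np.random.uniform(-1, 1, size=2)
    obs, reward, done, info = env.step(action)
    # reward is the negative distance from the new state to the goal
```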
--------------------------------------------------------------------------------
/rlkit/data_management/wrappers/proxy_buffer.py:
--------------------------------------------------------------------------------
1 | from rlkit.data_management.replay_buffer import ReplayBuffer
2 |
3 | class ProxyBuffer(ReplayBuffer):
4 | def __init__(self, replay_buffer):
5 | self._wrapped_buffer = replay_buffer
6 |
7 | def add_sample(self, *args, **kwargs):
8 | self._wrapped_buffer.add_sample(*args, **kwargs)
9 |
10 | def terminate_episode(self, *args, **kwargs):
11 | self._wrapped_buffer.terminate_episode(*args, **kwargs)
12 |
13 | def num_steps_can_sample(self, *args, **kwargs):
14 | return self._wrapped_buffer.num_steps_can_sample(*args, **kwargs)
15 |
16 | def add_path(self, *args, **kwargs):
17 | self._wrapped_buffer.add_path(*args, **kwargs)
18 |
19 | def add_paths(self, *args, **kwargs):
20 | self._wrapped_buffer.add_paths(*args, **kwargs)
21 |
22 | def random_batch(self, *args, **kwargs):
23 | return self._wrapped_buffer.random_batch(*args, **kwargs)
24 |
25 | def get_diagnostics(self, *args, **kwargs):
26 | return self._wrapped_buffer.get_diagnostics(*args, **kwargs)
27 |
28 | def get_snapshot(self, *args, **kwargs):
29 | return self._wrapped_buffer.get_snapshot(*args, **kwargs)
30 |
31 | def end_epoch(self, *args, **kwargs):
32 | return self._wrapped_buffer.end_epoch(*args, **kwargs)
33 |
34 | @property
35 | def wrapped_buffer(self):
36 | return self._wrapped_buffer
37 |
38 | def __getattr__(self, attr):
39 | if attr == '_wrapped_buffer':
40 | raise AttributeError()
41 | return getattr(self._wrapped_buffer, attr)
42 |
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/history_env.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import numpy as np
4 | from gym import Env
5 | from gym.spaces import Box
6 |
7 | from rlkit.envs.proxy_env import ProxyEnv
8 |
9 |
10 | class HistoryEnv(ProxyEnv, Env):
11 | def __init__(self, wrapped_env, history_len):
12 | super().__init__(wrapped_env)
13 | self.history_len = history_len
14 |
15 | high = np.inf * np.ones(
16 | self.history_len * self.observation_space.low.size)
17 | low = -high
18 | self.observation_space = Box(low=low,
19 | high=high,
20 | )
21 | self.history = deque(maxlen=self.history_len)
22 |
23 | def step(self, action):
24 | state, reward, done, info = super().step(action)
25 | self.history.append(state)
26 | flattened_history = self._get_history().flatten()
27 | return flattened_history, reward, done, info
28 |
29 | def reset(self, **kwargs):
30 | state = super().reset(**kwargs)
31 | self.history = deque(maxlen=self.history_len)
32 | self.history.append(state)
33 | flattened_history = self._get_history().flatten()
34 | return flattened_history
35 |
36 | def _get_history(self):
37 | observations = list(self.history)
38 |
39 | obs_count = len(observations)
40 | for _ in range(self.history_len - obs_count):
41 | dummy = np.zeros(self._wrapped_env.observation_space.low.size)
42 | observations.append(dummy)
43 | return np.c_[observations]
44 |
45 |
46 |
--------------------------------------------------------------------------------
/rlkit/core/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from collections import defaultdict
4 |
5 |
6 | class Timer:
7 | def __init__(self, return_global_times=False):
8 | self.stamps = None
9 | self.epoch_start_time = None
10 | self.global_start_time = time.time()
11 | self._return_global_times = return_global_times
12 |
13 | self.reset()
14 |
15 | def reset(self):
16 | self.stamps = defaultdict(lambda: 0)
17 | self.start_times = {}
18 | self.epoch_start_time = time.time()
19 |
20 | def start_timer(self, name, unique=True):
21 | if unique:
22 | assert name not in self.start_times.keys()
23 | self.start_times[name] = time.time()
24 |
25 | def stop_timer(self, name):
26 | assert name in self.start_times.keys()
27 | start_time = self.start_times[name]
28 | end_time = time.time()
29 | self.stamps[name] += (end_time - start_time)
30 |
31 | def get_times(self):
32 | global_times = {}
33 | cur_time = time.time()
34 | global_times['epoch_time'] = (cur_time - self.epoch_start_time)
35 | if self._return_global_times:
36 | global_times['global_time'] = (cur_time - self.global_start_time)
37 | return {
38 | **self.stamps.copy(),
39 | **global_times,
40 | }
41 |
42 | @property
43 | def return_global_times(self):
44 | return self._return_global_times
45 |
46 | @return_global_times.setter
47 | def return_global_times(self, value):
48 | self._return_global_times = value
49 |
50 |
51 | timer = Timer()
52 |
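A sketch of using the module-level `timer` instance defined above; the block name 'sampling' and the `time.sleep` stand-in are illustrative.
```
import time

from rlkit.core.timer import timer

timer.reset()  # start a fresh epoch
timer.start_timer('sampling')
time.sleep(0.01)  # stand-in for collecting data
timer.stop_timer('sampling')

times = timer.get_times()
# times has a 'sampling' entry plus 'epoch_time' (and 'global_time' if enabled)
```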
--------------------------------------------------------------------------------
/rlkit/data_management/dataset_logger_fn.py:
--------------------------------------------------------------------------------
1 | """Creates a callable that can be passed as a post_epoch_fn to an algorithm
2 | for evaluating on specific data."""
3 |
4 | from rlkit.core.logging import add_prefix
5 |
6 | from rlkit.torch.core import np_to_pytorch_batch
7 | import rlkit.torch.pytorch_util as ptu
8 |
9 | import numpy as np
10 |
11 | class DatasetLoggerFn:
12 | def __init__(self, dataset, fn, prefix="", batch_size=64, *args, **kwargs):
13 | self.dataset = dataset
14 | self.fn = fn
15 | self.prefix = prefix
16 | self.batch_size = batch_size
17 | self.args = args
18 | self.kwargs = kwargs
19 |
20 | def __call__(self, algo):
21 | batch = self.dataset.random_batch(self.batch_size)
22 | batch = np_to_pytorch_batch(batch)
23 | log_dict = self.fn(batch, *self.args, **self.kwargs)
24 | return add_prefix(log_dict, self.prefix)
25 |
26 | def run_bc_batch(batch, policy):
27 | o = batch["observations"]
28 | u = batch["actions"]
29 | # g = batch["resampled_goals"]
30 | # og = torch.cat((o, g), dim=1)
31 | og = o
32 | # pred_u, *_ = self.policy(og)
33 | dist = policy(og)
34 | pred_u, log_pi = dist.rsample_and_logprob()
35 | stats = dist.get_diagnostics()
36 |
37 | mse = (pred_u - u) ** 2
38 | mse_loss = np.mean(ptu.get_numpy(mse.mean()))
39 |
40 | policy_logpp = dist.log_prob(u, )
41 | logp_loss = -policy_logpp.mean()
42 | policy_loss = np.mean(ptu.get_numpy(logp_loss))
43 |
44 | return dict(
45 | bc_loss=policy_loss,
46 | mse_loss=mse_loss,
47 | **stats
48 | )
--------------------------------------------------------------------------------
/rlkit/data_management/wrappers/replay_buffer_wrapper.py:
--------------------------------------------------------------------------------
1 | from rlkit.data_management.replay_buffer import ReplayBuffer
2 |
3 | class ProxyBuffer(ReplayBuffer):
4 | def __init__(self, replay_buffer):
5 | self._wrapped_buffer = replay_buffer
6 |
7 | def add_sample(self, *args, **kwargs):
8 | self._wrapped_buffer.add_sample(*args, **kwargs)
9 |
10 | def terminate_episode(self, *args, **kwargs):
11 | self._wrapped_buffer.terminate_episode(*args, **kwargs)
12 |
13 | def num_steps_can_sample(self, *args, **kwargs):
14 | return self._wrapped_buffer.num_steps_can_sample(*args, **kwargs)
15 |
16 | def add_path(self, *args, **kwargs):
17 | self._wrapped_buffer.add_path(*args, **kwargs)
18 |
19 | def add_paths(self, *args, **kwargs):
20 | self._wrapped_buffer.add_paths(*args, **kwargs)
21 |
22 | def random_batch(self, *args, **kwargs):
23 | # delegate to the wrapped buffer and return its batch
24 | return self._wrapped_buffer.random_batch(*args, **kwargs)
25 |
26 | def get_diagnostics(self, *args, **kwargs):
27 | return self._wrapped_buffer.get_diagnostics(*args, **kwargs)
28 |
29 | def get_snapshot(self, *args, **kwargs):
30 | return self._wrapped_buffer.get_snapshot(*args, **kwargs)
31 |
32 | def end_epoch(self, *args, **kwargs):
33 | return self._wrapped_buffer.end_epoch(*args, **kwargs)
34 |
35 | @property
36 | def wrapped_buffer(self):
37 | return self._wrapped_buffer
38 |
39 | def __getattr__(self, attr):
40 | if attr == '_wrapped_buffer':
41 | raise AttributeError()
42 | return getattr(self._wrapped_buffer, attr)
43 |
--------------------------------------------------------------------------------
/rlkit/samplers/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rlkit.samplers.rollout_functions import deprecated_rollout as normal_rollout
3 |
4 | def rollout(*args, **kwargs):
5 | # TODO Steven: remove pointer
6 | return normal_rollout(*args, **kwargs)
7 |
8 | def split_paths(paths):
9 | """
10 | Stack multiples obs/actions/etc. from different paths
11 | :param paths: List of paths, where one path is something returned from
12 | the rollout function above.
13 | :return: Tuple. Every element will have shape batch_size X DIM, including
14 | the rewards and terminal flags.
15 | """
16 | rewards = [path["rewards"].reshape(-1, 1) for path in paths]
17 | terminals = [path["terminals"].reshape(-1, 1) for path in paths]
18 | actions = [path["actions"] for path in paths]
19 | obs = [path["observations"] for path in paths]
20 | next_obs = [path["next_observations"] for path in paths]
21 | rewards = np.vstack(rewards)
22 | terminals = np.vstack(terminals)
23 | obs = np.vstack(obs)
24 | actions = np.vstack(actions)
25 | next_obs = np.vstack(next_obs)
26 | assert len(rewards.shape) == 2
27 | assert len(terminals.shape) == 2
28 | assert len(obs.shape) == 2
29 | assert len(actions.shape) == 2
30 | assert len(next_obs.shape) == 2
31 | return rewards, terminals, obs, actions, next_obs
32 |
33 |
34 | def split_paths_to_dict(paths):
35 | rewards, terminals, obs, actions, next_obs = split_paths(paths)
36 | return dict(
37 | rewards=rewards,
38 | terminals=terminals,
39 | observations=obs,
40 | actions=actions,
41 | next_observations=next_obs,
42 | )
43 |
--------------------------------------------------------------------------------
/rlkit/demos/spacemouse/README.md:
--------------------------------------------------------------------------------
1 | # Spacemouse Demonstrations
2 |
3 | This code allows an agent (including physical robots) to be controlled by a 7 degree-of-freedom 3Dconnexion Spacemouse device.
4 |
5 | ## Setup Instructions
6 |
7 | ### Spacemouse (on a Mac)
8 |
9 | 1. Clone robosuite ([https://github.com/anair13/robosuite](https://github.com/anair13/robosuite)) and add it to the python path
10 | 2. Run `pip install hidapi` and install Spacemouse drivers
11 | 3. Ensure you can run the following file (and
12 | see input values from the spacemouse): `robosuite/devices/spacemouse.py`
13 | 4. Follow the example in `railrl/demos/collect_demo.py` to collect demonstrations
14 |
15 | ### Server
16 | We haven't been able to install Spacemouse drivers for Linux but instead we use a Spacemouse on a Mac ("client") and send messages over a network to a Linux machine ("server").
17 |
18 | #### Setup
19 | On the client, run the setup above. On the server, run:
20 | 1. Run `pip install Pyro4`
21 | 2. Make sure the hostname in `railrl/demos/spacemouse/config.py` is correct (in the example I use gauss1.banatao.berkeley.edu). This hostname needs to be visible (e.g. you can ping it) from both the client and the server.
22 |
23 | #### Run
24 |
25 | 1. On the server, start the nameserver:
26 | ```
27 | export PYRO_SERIALIZERS_ACCEPTED=serpent,json,marshal,pickle
28 | python -m Pyro4.naming -n euler1.dyn.berkeley.edu
29 | ```
30 | 2. On the server, run a script that uses the `SpaceMouseExpert` imported from `railrl/demos/spacemouse/input_server.py`, such as `python experiments/ashvin/iros2019/collect_demos_spacemouse.py`
31 | 3. On the client, run `python railrl/demos/spacemouse/input_client.py`
32 |
--------------------------------------------------------------------------------
/rlkit/launchers/contextual/util.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 |
3 | from gym.wrappers import TimeLimit
4 |
5 | from rlkit.core import logger
6 | from rlkit.visualization.video import dump_video
7 |
8 |
9 | def get_save_video_function(
10 | rollout_function,
11 | env,
12 | policy,
13 | save_video_period=10,
14 | imsize=48,
15 | tag="",
16 | **dump_video_kwargs
17 | ):
18 | logdir = logger.get_snapshot_dir()
19 |
20 | def save_video(algo, epoch):
21 | if epoch % save_video_period == 0 or epoch >= algo.num_epochs - 1:
22 | if tag is not None and len(tag) > 0:
23 | filename = 'video_{}_{epoch}_env.mp4'.format(tag, epoch=epoch)
24 | else:
25 | filename = 'video_{epoch}_env.mp4'.format(epoch=epoch)
26 | filename = osp.join(
27 | logdir,
28 | filename,
29 | )
30 | dump_video(env, policy, filename, rollout_function,
31 | imsize=imsize, **dump_video_kwargs)
32 | return save_video
33 |
34 |
35 | def get_gym_env(
36 | env_id,
37 | env_class=None,
38 | env_kwargs=None,
39 | unwrap_timed_envs=False,
40 | ):
41 | if env_kwargs is None:
42 | env_kwargs = {}
43 |
44 | assert env_id or env_class
45 |
46 | if env_id:
47 | import gym
48 | import multiworld
49 | multiworld.register_all_envs()
50 | env = gym.make(env_id)
51 | else:
52 | env = env_class(**env_kwargs)
53 |
54 | if isinstance(env, TimeLimit) and unwrap_timed_envs:
55 | env = env.env
56 |
57 | return env
58 |
--------------------------------------------------------------------------------
/rlkit/data_management/images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class _ImageNumpyArr:
5 | """
6 | Wrapper for a numpy array. This code automatically normalizes/unormalizes
7 | the image internally. This process should be completely hidden from the
8 | user of this class.
9 |
10 | im_arr = image_numpy_wrapper.zeros(10)
11 | # img is normalized. ImageNumpyArr automatically stores as np.uint8
12 | im_arr[2] = img
13 | # ImageNumpyArr automatically normalizes the np.uint8 and returns as np.float
14 | img_2 = im_arr[2]
15 | """
16 |
17 | def __init__(self, np_array):
18 | assert np_array.dtype == np.uint8
19 | self.np_array = np_array
20 | self.shape = self.np_array.shape
21 | self.size = self.np_array.size
22 | self.dtype = np.uint8
23 |
24 | def __getitem__(self, idxs):
25 | return normalize_image(self.np_array[idxs], dtype=np.float32)
26 |
27 | def __setitem__(self, idxs, value):
28 | if value.dtype != np.uint8:
29 | self.np_array[idxs] = unnormalize_image(value)
30 | else:
31 | self.np_array[idxs] = value
32 |
33 |
34 | def zeros(shape, *args, **kwargs):
35 | arr = np.zeros(shape, dtype=np.uint8)
36 | return _ImageNumpyArr(arr)
37 |
38 |
39 | def ones(shape, *args, **kwargs):
40 | arr = np.ones(shape, dtype=np.uint8)
41 | return _ImageNumpyArr(arr)
42 |
43 |
44 | def from_np(np_arr):
45 | return _ImageNumpyArr(np_arr)
46 |
47 |
48 | def normalize_image(image, dtype=np.float64):
49 | assert image.dtype == np.uint8
50 | return dtype(image) / 255.0
51 |
52 |
53 | def unnormalize_image(image):
54 | assert image.dtype != np.uint8
55 | return np.uint8(image * 255.0)
56 |
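A small sketch of the normalize/unnormalize round trip implemented by the wrapper above; the buffer shape is illustrative.
```
import numpy as np

from rlkit.data_management import images

# buffer for 10 images of shape 48x48x3, stored internally as uint8
im_arr = images.zeros((10, 48, 48, 3))

img = np.random.uniform(0, 1, size=(48, 48, 3))  # normalized float image
im_arr[2] = img       # converted to uint8 on the way in
img_back = im_arr[2]  # returned as float32 in [0, 1]
assert img_back.dtype == np.float32
```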
--------------------------------------------------------------------------------
/rlkit/torch/networks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General networks for pytorch.
3 |
4 | Algorithm-specific networks should go else-where.
5 | """
6 | from rlkit.torch.networks.basic import (
7 | Sigmoid, Clamp, SigmoidClamp, ConcatTuple, Detach, Flatten, FlattenEach, Split, Reshape,
8 | )
9 | from rlkit.torch.networks.cnn import BasicCNN, CNN, MergedCNN, CNNPolicy, TwoChannelCNN, ConcatTwoChannelCNN
10 | from rlkit.torch.networks.dcnn import DCNN, TwoHeadDCNN
11 | from rlkit.torch.networks.deprecated_feedforward import (
12 | FeedForwardPolicy, FeedForwardQFunction
13 | )
14 | from rlkit.torch.networks.feat_point_mlp import FeatPointMlp
15 | from rlkit.torch.networks.image_state import ImageStatePolicy, ImageStateQ
16 | from rlkit.torch.networks.linear_transform import LinearTransform
17 | from rlkit.torch.networks.mlp import (
18 | Mlp, ConcatMlp, MlpPolicy, TanhMlpPolicy,
19 | MlpQf,
20 | MlpQfWithObsProcessor,
21 | ConcatMultiHeadedMlp,
22 | )
23 | from rlkit.torch.networks.pretrained_cnn import PretrainedCNN
24 | from rlkit.torch.networks.two_headed_mlp import TwoHeadMlp
25 |
26 | __all__ = [
27 | 'Sigmoid',
28 | 'Clamp',
29 | 'SigmoidClamp',
30 | 'ConcatMlp',
31 | 'ConcatMultiHeadedMlp',
32 | 'ConcatTuple',
33 | 'BasicCNN',
34 | 'CNN',
35 | 'TwoChannelCNN',
36 | "ConcatTwoChannelCNN",
37 | 'CNNPolicy',
38 | 'DCNN',
39 | 'Detach',
40 | 'FeedForwardPolicy',
41 | 'FeedForwardQFunction',
42 | 'FeatPointMlp',
43 | 'Flatten',
44 | 'FlattenEach',
45 | 'LinearTransform',
46 | 'ImageStatePolicy',
47 | 'ImageStateQ',
48 | 'MergedCNN',
49 | 'Mlp',
50 | 'PretrainedCNN',
51 | 'Reshape',
52 | 'Split',
53 | 'TwoHeadDCNN',
54 | 'TwoHeadMlp',
55 | ]
56 |
57 |
--------------------------------------------------------------------------------
/rlkit/envs/assets/twod_point.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rlkit/envs/assets/twod_point_random_init.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/contextual_path_collector.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from functools import partial
3 |
4 | from rlkit.envs.contextual import ContextualEnv
5 | from rlkit.policies.base import Policy
6 | from rlkit.samplers.data_collector import MdpPathCollector
7 | from rlkit.samplers.rollout_functions import contextual_rollout
8 |
9 |
10 | class ContextualPathCollector(MdpPathCollector):
11 | def __init__(
12 | self,
13 | env: ContextualEnv,
14 | policy: Policy,
15 | max_num_epoch_paths_saved=None,
16 | observation_keys=('observation',),
17 | context_keys_for_policy='context',
18 | render=False,
19 | render_kwargs=None,
20 | obs_processor=None,
21 | rollout=contextual_rollout,
22 | **kwargs
23 | ):
24 | rollout_fn = partial(
25 | rollout,
26 | context_keys_for_policy=context_keys_for_policy,
27 | observation_keys=observation_keys,
28 | obs_processor=obs_processor,
29 | )
30 | super().__init__(
31 | env, policy,
32 | max_num_epoch_paths_saved=max_num_epoch_paths_saved,
33 | render=render,
34 | render_kwargs=render_kwargs,
35 | rollout_fn=rollout_fn,
36 | **kwargs
37 | )
38 | self._observation_keys = observation_keys
39 | self._context_keys_for_policy = context_keys_for_policy
40 |
41 | def end_epoch(self, epoch):
42 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
43 |
44 | def get_snapshot(self):
45 | snapshot = super().get_snapshot()
46 | snapshot.update(
47 | observation_keys=self._observation_keys,
48 | context_keys_for_policy=self._context_keys_for_policy,
49 | )
50 | return snapshot
51 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/ant_goal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv
4 |
5 |
6 | # Copy task structure from https://github.com/jonasrothfuss/ProMP/blob/master/meta_policy_search/envs/mujoco_envs/ant_rand_goal.py
7 | class AntGoalEnv(MultitaskAntEnv):
8 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True, **kwargs):
9 | self.quick_init(locals())
10 | super(AntGoalEnv, self).__init__(task, n_tasks, **kwargs)
11 |
12 | def step(self, action):
13 | self.do_simulation(action, self.frame_skip)
14 | xposafter = np.array(self.get_body_com("torso"))
15 |
16 | goal_reward = -np.sum(np.abs(xposafter[:2] - self._goal)) # make it happy, not suicidal
17 |
18 | ctrl_cost = .1 * np.square(action).sum()
19 | contact_cost = 0.5 * 1e-3 * np.sum(
20 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
21 | survive_reward = 0.0
22 | reward = goal_reward - ctrl_cost - contact_cost + survive_reward
23 | state = self.state_vector()
24 | done = False
25 | ob = self._get_obs()
26 | return ob, reward, done, dict(
27 | goal_forward=goal_reward,
28 | reward_ctrl=-ctrl_cost,
29 | reward_contact=-contact_cost,
30 | reward_survive=survive_reward,
31 | )
32 |
33 | def sample_tasks(self, num_tasks):
34 | a = np.random.random(num_tasks) * 2 * np.pi
35 | r = 3 * np.random.random(num_tasks) ** 0.5
36 | goals = np.stack((r * np.cos(a), r * np.sin(a)), axis=-1)
37 | tasks = [{'goal': goal} for goal in goals]
38 | return tasks
39 |
40 | def _get_obs(self):
41 | return np.concatenate([
42 | self.sim.data.qpos.flat,
43 | self.sim.data.qvel.flat,
44 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
45 | ])
46 |
--------------------------------------------------------------------------------
/rlkit/envs/robosuite_wrapper.py:
--------------------------------------------------------------------------------
1 | from gym import Env
2 | from gym.spaces import Box
3 | import numpy as np
4 | import robosuite as suite
5 | from rlkit.core.serializeable import Serializable
6 |
7 |
8 | class RobosuiteStateWrapperEnv(Serializable, Env):
9 | def __init__(self, wrapped_env_id, observation_keys=('robot-state', 'object-state'), **wrapped_env_kwargs):
10 | Serializable.quick_init(self, locals())
11 | self._wrapped_env = suite.make(
12 | wrapped_env_id,
13 | **wrapped_env_kwargs
14 | )
15 | self.action_space = Box(self._wrapped_env.action_spec[0], self._wrapped_env.action_spec[1], dtype=np.float32)
16 | observation_dim = self._wrapped_env.observation_spec()['robot-state'].shape[0] \
17 | + self._wrapped_env.observation_spec()['object-state'].shape[0]
18 | self.observation_space = Box(
19 | -np.inf * np.ones(observation_dim),
20 | np.inf * np.ones(observation_dim),
21 | dtype=np.float32,
22 | )
23 | self._observation_keys = observation_keys
24 |
25 | def step(self, action):
26 | obs, reward, done, info = self._wrapped_env.step(action)
27 | obs = self.flatten_dict_obs(obs)
28 | return obs, reward, done, info
29 |
30 | def flatten_dict_obs(self, obs):
31 | obs = np.concatenate(tuple(obs[k] for k in self._observation_keys))
32 | return obs
33 |
34 | def reset(self):
35 | obs = self._wrapped_env.reset()
36 | obs = self.flatten_dict_obs(obs)
37 | return obs
38 |
39 | def render(self):
40 | self._wrapped_env.render()
41 |
42 | def __getattr__(self, attr):
43 | if attr == '_wrapped_env':
44 | raise AttributeError()
45 | return getattr(self._wrapped_env, attr)
46 |
47 | def __str__(self):
48 | return '{}({})'.format(type(self).__name__, self._wrapped_env)
49 |
50 |
--------------------------------------------------------------------------------
/rlkit/envs/assets/twod_maze.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rlkit/envs/gripper_state_wrapper.py:
--------------------------------------------------------------------------------
1 | # import Dict
2 |
3 | import numpy as np
4 | from gym.spaces import Box, Dict
5 |
6 | from rlkit.envs.wrappers import ProxyEnv
7 | from roboverse.bullet.misc import quat_to_deg
8 |
9 |
10 | def process_gripper_state(gripper_state):
11 | gripper_pos = gripper_state[:3]
12 | gripper_ori = quat_to_deg(gripper_state[3:7]) / 360.0
13 | gripper_tips_distance = gripper_state[7:8]
14 | return np.concatenate([gripper_pos, gripper_ori, gripper_tips_distance], axis=0)
15 |
16 |
17 | class GripperStateWrappedEnv(ProxyEnv):
18 | def __init__(self,
19 | wrapped_env,
20 | state_key,
21 | step_keys_map=None,
22 | ):
23 | super().__init__(wrapped_env)
24 | self.state_key = state_key
25 | self.gripper_state_size = 7
26 | gripper_state_space = Box(
27 | -1 * np.ones(self.gripper_state_size),
28 | 1 * np.ones(self.gripper_state_size),
29 | dtype=np.float32,
30 | )
31 |
32 | if step_keys_map is None:
33 | step_keys_map = {}
34 | self.step_keys_map = step_keys_map
35 | spaces = self.wrapped_env.observation_space.spaces
36 | for value in self.step_keys_map.values():
37 | spaces[value] = gripper_state_space
38 | self.observation_space = Dict(spaces)
39 |
40 | def step(self, action):
41 | obs, reward, done, info = self.wrapped_env.step(action)
42 | new_obs = self._update_obs(obs)
43 | return new_obs, reward, done, info
44 |
45 | def _update_obs(self, obs):
46 | for key in self.step_keys_map:
47 | value = self.step_keys_map[key]
48 | gripper_state = obs[self.state_key]
49 | # convert raw gripper state into position, orientation, and tip distance
50 | obs[value] = process_gripper_state(gripper_state)
51 | obs = {**obs, **self.reset_obs}
52 | return obs
53 |
--------------------------------------------------------------------------------
/rlkit/envs/proxy_env.py:
--------------------------------------------------------------------------------
1 | from gym import Env
2 |
3 |
4 | class ProxyEnv(Env):
5 | def __init__(self, wrapped_env):
6 | self._wrapped_env = wrapped_env
7 | self.action_space = self._wrapped_env.action_space
8 | self.observation_space = self._wrapped_env.observation_space
9 |
10 | @property
11 | def wrapped_env(self):
12 | return self._wrapped_env
13 |
14 | def reset(self, **kwargs):
15 | return self._wrapped_env.reset(**kwargs)
16 |
17 | def step(self, action):
18 | return self._wrapped_env.step(action)
19 |
20 | def render(self, *args, **kwargs):
21 | return self._wrapped_env.render(*args, **kwargs)
22 |
23 | @property
24 | def horizon(self):
25 | return self._wrapped_env.horizon
26 |
27 | def terminate(self):
28 | if hasattr(self.wrapped_env, "terminate"):
29 | self.wrapped_env.terminate()
30 |
31 | def seed(self, _seed):
32 | return self.wrapped_env.seed(_seed)
33 |
34 | def __getattr__(self, attr):
35 | if attr == '_wrapped_env':
36 | raise AttributeError()
37 | if attr == 'planner':
38 | return self._planner
39 | if attr == 'set_vf':
40 | return self.set_vf
41 | return getattr(self._wrapped_env, attr)
42 | # try:
43 | # getattr(self, attr)
44 | # except Exception:
45 | # return getattr(self._wrapped_env, attr)
46 |
47 | def __getstate__(self):
48 | """
49 | This is useful to override in case the wrapped env has some funky
50 | __getstate__ that doesn't play well with overriding __getattr__.
51 |
52 | The main problematic case is/was gym's EzPickle serialization scheme.
53 | :return:
54 | """
55 | return self.__dict__
56 |
57 | def __setstate__(self, state):
58 | self.__dict__.update(state)
59 |
60 | def __str__(self):
61 | return '{}({})'.format(type(self).__name__, self.wrapped_env)
62 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/ant_dir.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv
4 |
5 |
6 | class AntDirEnv(MultitaskAntEnv):
7 |
8 | def __init__(self, task={}, n_tasks=2, forward_backward=False, randomize_tasks=True, **kwargs):
9 | self.quick_init(locals())
10 | self.forward_backward = forward_backward
11 | super(AntDirEnv, self).__init__(task, n_tasks, **kwargs)
12 |
13 | def step(self, action):
14 | torso_xyz_before = np.array(self.get_body_com("torso"))
15 |
16 | direct = (np.cos(self._goal), np.sin(self._goal))
17 |
18 | self.do_simulation(action, self.frame_skip)
19 | torso_xyz_after = np.array(self.get_body_com("torso"))
20 | torso_velocity = torso_xyz_after - torso_xyz_before
21 | forward_reward = np.dot((torso_velocity[:2]/self.dt), direct)
22 |
23 | ctrl_cost = .5 * np.square(action).sum()
24 | contact_cost = 0.5 * 1e-3 * np.sum(
25 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
26 | survive_reward = 1.0
27 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
28 | state = self.state_vector()
29 | notdone = np.isfinite(state).all() \
30 | and state[2] >= 0.2 and state[2] <= 1.0
31 | done = not notdone
32 | ob = self._get_obs()
33 | return ob, reward, done, dict(
34 | reward_forward=forward_reward,
35 | reward_ctrl=-ctrl_cost,
36 | reward_contact=-contact_cost,
37 | reward_survive=survive_reward,
38 | torso_velocity=torso_velocity,
39 | )
40 |
41 | def sample_tasks(self, num_tasks):
42 | if self.forward_backward:
43 | assert num_tasks == 2
44 | velocities = np.array([0., np.pi])
45 | else:
46 | velocities = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,))
47 | tasks = [{'goal': velocity} for velocity in velocities]
48 | return tasks
49 |
--------------------------------------------------------------------------------
/rlkit/envs/images/env_renderer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | from rlkit.envs.images import Renderer
5 |
6 |
7 | class EnvRenderer(Renderer):
8 | # TODO: switch to env.render interface
9 | def __init__(
10 | self,
11 | init_camera=None,
12 | normalize_image=True, # most gym envs output uint8
13 | create_image_format='HWC',
14 | **kwargs
15 | ):
16 | """Render an image."""
17 | super().__init__(
18 | normalize_image=normalize_image,
19 | create_image_format=create_image_format,
20 | **kwargs)
21 | self._init_camera = init_camera
22 | self._camera_is_initialized = False
23 |
24 | def _create_image(self, env):
25 | if not self._camera_is_initialized and self._init_camera is not None:
26 | env.initialize_camera(self._init_camera)
27 | self._camera_is_initialized = True
28 |
29 | return env.get_image(
30 | width=self.width,
31 | height=self.height,
32 | )
33 |
34 |
35 | class GymEnvRenderer(EnvRenderer):
36 | def _create_image(self, env):
37 | if not self._camera_is_initialized and self._init_camera is not None:
38 | env.initialize_camera(self._init_camera)
39 | self._camera_is_initialized = True
40 |
41 | return env.render(
42 | mode='rgb_array', width=self.width, height=self.height
43 | )
44 |
45 |
46 | class GymSimRenderer(EnvRenderer):
47 | def _create_image(self, env):
48 | if not self._camera_is_initialized and self._init_camera is not None:
49 | env.initialize_camera(self._init_camera)
50 | self._camera_is_initialized = True
51 |
52 | return self._get_image(env, self.width, self.height, camera_name=None)
53 |
54 | def _get_image(self, env, width=84, height=84, camera_name=None):
55 | return env.sim.render(
56 | width=width,
57 | height=height,
58 | camera_name=camera_name,
59 | )[::-1,:,:]
60 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv
4 |
5 |
6 | class HopperRandParamsEnv(RandomEnv, utils.EzPickle):
7 | def __init__(self, log_scale_limit=3.0):
8 | RandomEnv.__init__(self, log_scale_limit, 'hopper.xml', 4)
9 | utils.EzPickle.__init__(self)
10 |
11 | def _step(self, a):
12 | posbefore = self.sim.data.qpos[0]
13 | self.do_simulation(a, self.frame_skip)
14 | posafter, height, ang = self.sim.data.qpos[0:3]
15 | alive_bonus = 1.0
16 | reward = (posafter - posbefore) / self.dt
17 | reward += alive_bonus
18 | reward -= 1e-3 * np.square(a).sum()
19 | s = self.state_vector()
20 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and
21 | (height > .7) and (abs(ang) < .2))
22 | ob = self._get_obs()
23 | return ob, reward, done, {}
24 |
25 | def _get_obs(self):
26 | return np.concatenate([
27 | self.sim.data.qpos.flat[1:],
28 | np.clip(self.sim.data.qvel.flat, -10, 10)
29 | ])
30 |
31 | def reset_model(self):
32 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq)
33 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
34 | self.set_state(qpos, qvel)
35 | return self._get_obs()
36 |
37 | def viewer_setup(self):
38 | self.viewer.cam.trackbodyid = 2
39 | self.viewer.cam.distance = self.model.stat.extent * 0.75
40 | self.viewer.cam.lookat[2] += .8
41 | self.viewer.cam.elevation = -20
42 |
43 | if __name__ == "__main__":
44 |
45 | env = HopperRandParamsEnv()
46 | tasks = env.sample_tasks(40)
47 | while True:
48 | env.reset()
49 | env.set_task(np.random.choice(tasks))
50 | print(env.model.body_mass)
51 | for _ in range(100):
52 | env.render()
53 | env.step(env.action_space.sample()) # take a random action
54 |
55 |
--------------------------------------------------------------------------------
/rlkit/envs/assets/water_maze.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rlkit/torch/networks/image_state.py:
--------------------------------------------------------------------------------
1 | from rlkit.policies.base import Policy
2 | from rlkit.torch.core import PyTorchModule, eval_np
3 |
4 |
5 | class ImageStatePolicy(PyTorchModule, Policy):
6 | """Switches between image or state inputs"""
7 |
8 | def __init__(
9 | self,
10 | image_conv_net,
11 | state_fc_net,
12 | ):
13 | super().__init__()
14 |
15 | assert image_conv_net is None or state_fc_net is None
16 | self.image_conv_net = image_conv_net
17 | self.state_fc_net = state_fc_net
18 |
19 | def forward(self, input, return_preactivations=False):
20 | if self.image_conv_net is not None:
21 | image = input[:, :21168]
22 | return self.image_conv_net(image)
23 | if self.state_fc_net is not None:
24 | state = input[:, 21168:]
25 | return self.state_fc_net(state)
26 |
27 | def get_action(self, obs_np):
28 | actions = self.get_actions(obs_np[None])
29 | return actions[0, :], {}
30 |
31 | def get_actions(self, obs):
32 | return eval_np(self, obs)
33 |
34 |
35 | class ImageStateQ(PyTorchModule):
36 | """Switches between image or state inputs"""
37 |
38 | def __init__(
39 | self,
40 | # obs_dim,
41 | # action_dim,
42 | # goal_dim,
43 | image_conv_net, # assumed to be a MergedCNN
44 | state_fc_net,
45 | ):
46 | super().__init__()
47 |
48 | assert image_conv_net is None or state_fc_net is None
49 | # self.obs_dim = obs_dim
50 | # self.action_dim = action_dim
51 | # self.goal_dim = goal_dim
52 | self.image_conv_net = image_conv_net
53 | self.state_fc_net = state_fc_net
54 |
55 | def forward(self, input, action, return_preactivations=False):
56 | if self.image_conv_net is not None:
57 | image = input[:, :21168]
58 | return self.image_conv_net(image, action)
59 | if self.state_fc_net is not None:
60 | state = input[:, 21168:] # action + state
61 | return self.state_fc_net(state, action)
62 |
63 |
64 |
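A quick sanity check of the slicing convention above, as a sketch that is not part of the repo: the hard-coded 21168 offset appears to correspond to a flattened 84x84x3 image, and exactly one of the two sub-networks may be active. A plain `nn.Linear` stands in here for the usual rlkit Mlp.

```python
import torch
from torch import nn

from rlkit.torch.networks.image_state import ImageStatePolicy

state_dim, action_dim = 7, 3
policy = ImageStatePolicy(
    image_conv_net=None,                            # exactly one branch may be set
    state_fc_net=nn.Linear(state_dim, action_dim),  # stand-in for an rlkit Mlp
)
obs = torch.zeros(4, 21168 + state_dim)  # [flattened image | state]
actions = policy(obs)                    # only obs[:, 21168:] is used
assert actions.shape == (4, action_dim)
```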
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/mujoco_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import mujoco_py
5 | import numpy as np
6 | from gym.envs.mujoco import mujoco_env
7 |
8 | from rlkit.core.serializeable import Serializable
9 |
10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
11 |
12 |
13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable):
14 | """
15 | My own wrapper around MujocoEnv.
16 |
17 | The caller needs to declare
18 | """
19 | def __init__(
20 | self,
21 | model_path,
22 | frame_skip=1,
23 | model_path_is_local=True,
24 | automatically_set_obs_and_action_space=False,
25 | ):
26 | if model_path_is_local:
27 | model_path = get_asset_xml(model_path)
28 | if automatically_set_obs_and_action_space:
29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip)
30 | else:
31 | """
32 | Code below is copy/pasted from MujocoEnv's __init__ function.
33 | """
34 | if model_path.startswith("/"):
35 | fullpath = model_path
36 | else:
37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
38 | if not path.exists(fullpath):
39 | raise IOError("File %s does not exist" % fullpath)
40 | self.frame_skip = frame_skip
41 | self.model = mujoco_py.MjModel(fullpath)
42 | self.data = self.model.data
43 | self.viewer = None
44 |
45 | self.metadata = {
46 | 'render.modes': ['human', 'rgb_array'],
47 | 'video.frames_per_second': int(np.round(1.0 / self.dt))
48 | }
49 |
50 | self.init_qpos = self.model.data.qpos.ravel().copy()
51 | self.init_qvel = self.model.data.qvel.ravel().copy()
52 | self._seed()
53 |
54 | def init_serialization(self, locals):
55 | Serializable.quick_init(self, locals)
56 |
57 | def log_diagnostics(self, *args, **kwargs):
58 | pass
59 |
60 |
61 | def get_asset_xml(xml_name):
62 | return os.path.join(ENV_ASSET_DIR, xml_name)
63 |
--------------------------------------------------------------------------------
/rlkit/data_management/tau_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.data_management.her_replay_buffer import HerReplayBuffer
4 | from rlkit.util.np_util import truncated_geometric
5 |
6 |
7 | class TauReplayBuffer(HerReplayBuffer):
8 | def random_batch_random_tau(self, batch_size, max_tau):
9 | indices = np.random.randint(0, self._size, batch_size)
10 | next_obs_idxs = []
11 | for i in indices:
12 | possible_next_obs_idxs = self._idx_to_future_obs_idx[i]
13 | # This is generally faster than random.choice. Makes you wonder what
14 | # random.choice is doing
15 | num_options = len(possible_next_obs_idxs)
16 | tau = np.random.randint(0, min(max_tau+1, num_options))
17 | if num_options == 1:
18 | next_obs_i = 0
19 | else:
20 | if self.resampling_strategy == 'uniform':
21 | next_obs_i = int(np.random.randint(0, tau+1))
22 | elif self.resampling_strategy == 'truncated_geometric':
23 | next_obs_i = int(truncated_geometric(
24 | p=self.truncated_geom_factor/tau,
25 | truncate_threshold=num_options-1,
26 | size=1,
27 | new_value=0
28 | ))
29 | else:
30 | raise ValueError("Invalid resampling strategy: {}".format(
31 | self.resampling_strategy
32 | ))
33 | next_obs_idxs.append(possible_next_obs_idxs[next_obs_i])
34 | next_obs_idxs = np.array(next_obs_idxs)
35 | training_goals = self.env.convert_obs_to_goals(
36 | self._next_obs[next_obs_idxs]
37 | )
38 | return dict(
39 | observations=self._observations[indices],
40 | actions=self._actions[indices],
41 | rewards=self._rewards[indices],
42 | terminals=self._terminals[indices],
43 | next_observations=self._next_obs[indices],
44 | training_goals=training_goals,
45 | num_steps_left=self._num_steps_left[indices],
46 | )
47 |
48 |
--------------------------------------------------------------------------------
/rlkit/data_management/ocm_subtraj_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rlkit.data_management.subtraj_replay_buffer import SubtrajReplayBuffer
4 | from rlkit.data_management.updatable_subtraj_replay_buffer import \
5 | UpdatableSubtrajReplayBuffer
6 | from rlkit.util.np_util import subsequences
7 |
8 |
9 | class OcmSubtrajReplayBuffer(UpdatableSubtrajReplayBuffer):
10 | """
11 | A replay buffer designed specifically for OneCharMem
12 | sub-trajectories
13 | """
14 |
15 | def __init__(
16 | self,
17 | max_replay_buffer_size,
18 | env,
19 | subtraj_length,
20 | *args,
21 | **kwargs
22 | ):
23 | # TODO(vitchyr): Move this logic to environment
24 | self._target_numbers = np.zeros(max_replay_buffer_size, dtype='uint8')
25 | self._times = np.zeros(max_replay_buffer_size, dtype='uint8')
26 | super().__init__(
27 | max_replay_buffer_size,
28 | env,
29 | subtraj_length,
30 | *args,
31 | only_sample_at_start_of_episode=False,
32 | **kwargs
33 | )
34 |
35 | def _add_sample(self, observation, action, reward, terminal,
36 | final_state, agent_info=None, env_info=None):
37 | if env_info is not None:
38 | if 'target_number' in env_info:
39 | self._target_numbers[self._top] = env_info['target_number']
40 | if 'time' in env_info:
41 | self._times[self._top] = env_info['time']
42 | super()._add_sample(
43 | observation,
44 | action,
45 | reward,
46 | terminal,
47 | final_state
48 | )
49 |
50 | def _get_trajectories(self, start_indices):
51 | trajs = super()._get_trajectories(start_indices)
52 | trajs['target_numbers'] = subsequences(
53 | self._target_numbers,
54 | start_indices,
55 | self._subtraj_length,
56 | )
57 | trajs['times'] = subsequences(
58 | self._times,
59 | start_indices,
60 | self._subtraj_length,
61 | )
62 | return trajs
63 |
--------------------------------------------------------------------------------
/rlkit/torch/sac/policies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn as nn
8 |
9 | import rlkit.torch.pytorch_util as ptu
10 | from rlkit.policies.base import ExplorationPolicy
11 | from rlkit.torch.core import torch_ify, elem_or_tuple_to_numpy
12 | from rlkit.torch.distributions import (
13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull,
14 | )
15 | from rlkit.torch.networks import Mlp, CNN
16 | from rlkit.torch.networks.basic import MultiInputSequential
17 | from rlkit.torch.networks.stochastic.distribution_generator import (
18 | DistributionGenerator
19 | )
20 |
21 |
22 | class TorchStochasticPolicy(
23 | DistributionGenerator,
24 | ExplorationPolicy, metaclass=abc.ABCMeta
25 | ):
26 | def get_action(self, obs_np, ):
27 | actions = self.get_actions(obs_np[None])
28 | return actions[0, :], {}
29 |
30 | def get_actions(self, obs_np, ):
31 | dist = self._get_dist_from_np(obs_np)
32 | actions = dist.sample()
33 | return elem_or_tuple_to_numpy(actions)
34 |
35 | def _get_dist_from_np(self, *args, **kwargs):
36 | torch_args = tuple(torch_ify(x) for x in args)
37 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
38 | dist = self(*torch_args, **torch_kwargs)
39 | return dist
40 |
41 |
42 | class PolicyFromDistributionGenerator(
43 | MultiInputSequential,
44 | TorchStochasticPolicy,
45 | ):
46 | """
47 | Usage:
48 | ```
49 | distribution_generator = FancyGenerativeModel()
50 | policy = PolicyFromDistributionGenerator(distribution_generator)
51 | ```
52 | """
53 | pass
54 |
55 |
56 | class MakeDeterministic(TorchStochasticPolicy):
57 | def __init__(
58 | self,
59 | action_distribution_generator: DistributionGenerator,
60 | ):
61 | super().__init__()
62 | self._action_distribution_generator = action_distribution_generator
63 |
64 | def forward(self, *args, **kwargs):
65 | dist = self._action_distribution_generator.forward(*args, **kwargs)
66 | return Delta(dist.mle_estimate())
67 |
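A minimal sketch of how MakeDeterministic is meant to be used (the toy generator below is hypothetical, not part of the repo): any TorchStochasticPolicy whose forward returns a distribution with an mle_estimate, such as a TanhNormal, can be wrapped so that evaluation always takes the most likely action.

```python
import numpy as np
import torch

from rlkit.torch.distributions import TanhNormal
from rlkit.torch.sac.policies.base import MakeDeterministic, TorchStochasticPolicy


class ToyTanhGaussianPolicy(TorchStochasticPolicy):
    """Ignores the observation and returns a fixed TanhNormal over 2 actions."""
    def forward(self, obs):
        mean = torch.zeros(obs.shape[0], 2)
        std = 0.1 * torch.ones(obs.shape[0], 2)
        return TanhNormal(mean, std)


expl_policy = ToyTanhGaussianPolicy()
eval_policy = MakeDeterministic(expl_policy)
action, _ = eval_policy.get_action(np.zeros(5))  # always tanh(mean), i.e. zeros here
```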
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/ant.py:
--------------------------------------------------------------------------------
1 | """
2 | Exact same as gym env, except that the gear ratio is 30 rather than 150.
3 | """
4 | import numpy as np
5 |
6 | from gym.envs.mujoco import MujocoEnv
7 |
8 | from rlkit.envs.env_utils import get_asset_full_path
9 |
10 |
11 | class AntEnv(MujocoEnv):
12 | def __init__(self, use_low_gear_ratio=True):
13 | if use_low_gear_ratio:
14 | xml_path = 'low_gear_ratio_ant.xml'
15 | else:
16 | xml_path = 'normal_gear_ratio_ant.xml'
17 | super().__init__(
18 | get_asset_full_path(xml_path),
19 | frame_skip=5,
20 | )
21 |
22 | def step(self, a):
23 | torso_xyz_before = self.get_body_com("torso")
24 | self.do_simulation(a, self.frame_skip)
25 | torso_xyz_after = self.get_body_com("torso")
26 | torso_velocity = torso_xyz_after - torso_xyz_before
27 | forward_reward = torso_velocity[0]/self.dt
28 | ctrl_cost = .5 * np.square(a).sum()
29 | contact_cost = 0.5 * 1e-3 * np.sum(
30 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
31 | survive_reward = 1.0
32 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
33 | state = self.state_vector()
34 | notdone = np.isfinite(state).all() \
35 | and state[2] >= 0.2 and state[2] <= 1.0
36 | done = not notdone
37 | ob = self._get_obs()
38 | return ob, reward, done, dict(
39 | reward_forward=forward_reward,
40 | reward_ctrl=-ctrl_cost,
41 | reward_contact=-contact_cost,
42 | reward_survive=survive_reward,
43 | torso_velocity=torso_velocity,
44 | )
45 |
46 | def _get_obs(self):
47 | return np.concatenate([
48 | self.sim.data.qpos.flat[2:],
49 | self.sim.data.qvel.flat,
50 | ])
51 |
52 | def reset_model(self):
53 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
54 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
55 | self.set_state(qpos, qvel)
56 | return self._get_obs()
57 |
58 | def viewer_setup(self):
59 | self.viewer.cam.distance = self.model.stat.extent * 0.5
60 |
--------------------------------------------------------------------------------
/rlkit/envs/assets/small_water_maze.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rlkit/envs/memory/hidden_cartpole.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | from cached_property import cached_property
5 |
6 | from rlkit.core.eval_util import create_stats_ordered_dict
7 | from rlkit.samplers.util import split_paths
8 | from rllab.envs.box2d.cartpole_env import CartpoleEnv
9 | from rlkit.core import logger
10 | from sandbox.rocky.tf.spaces import Box
11 |
12 |
13 | class HiddenCartpoleEnv(CartpoleEnv):
14 | def __init__(self, num_steps=100, position_only=True):
15 | assert position_only, "I only added position_only due to some weird " \
16 | "serialization bug"
17 | CartpoleEnv.__init__(self, position_only=position_only)
18 | self.num_steps = num_steps
19 |
20 | @cached_property
21 | def action_space(self):
22 | return Box(super().action_space.low,
23 | super().action_space.high)
24 |
25 | @cached_property
26 | def observation_space(self):
27 | return Box(super().observation_space.low,
28 | super().observation_space.high)
29 |
30 | @property
31 | def horizon(self):
32 | return self.num_steps
33 |
34 | @staticmethod
35 | def get_extra_info_dict_from_batch(batch):
36 | return dict()
37 |
38 | @staticmethod
39 | def get_flattened_extra_info_dict_from_subsequence_batch(batch):
40 | return dict()
41 |
42 | @staticmethod
43 | def get_last_extra_info_dict_from_subsequence_batch(batch):
44 | return dict()
45 |
46 | def log_diagnostics(self, paths, **kwargs):
47 | list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)
48 |
49 | returns = []
50 | for rewards in list_of_rewards:
51 | returns.append(np.sum(rewards))
52 | last_statistics = OrderedDict()
53 | last_statistics.update(create_stats_ordered_dict(
54 | 'UndiscountedReturns',
55 | returns,
56 | ))
57 | last_statistics.update(create_stats_ordered_dict(
58 | 'Actions',
59 | actions,
60 | ))
61 |
62 | for key, value in last_statistics.items():
63 | logger.record_tabular(key, value)
64 | return returns
65 |
66 | def is_current_done(self):
67 | return False
68 |
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/mujoco_env_murtaza.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym import error, spaces
4 | from gym.utils import seeding
5 | import numpy as np
6 | from os import path
7 | import gym
8 | import six
9 |
10 | try:
11 | import mujoco_py
12 | from mujoco_py.mjlib import mjlib
13 | except ImportError as e:
14 | raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))
15 |
16 | class MujocoEnv(gym.Env):
17 | """Superclass for all MuJoCo environments.
18 | """
19 |
20 | def __init__(self, model_path, frame_skip, init_viewer=False):
21 | if model_path.startswith("/"):
22 | fullpath = model_path
23 | else:
24 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
25 | if not path.exists(fullpath):
26 | raise IOError("File %s does not exist" % fullpath)
27 | self.frame_skip = frame_skip
28 | self.model = mujoco_py.MjModel(fullpath)
29 | self.data = self.model.data
30 | self.viewer = None
31 |
32 | if init_viewer:
33 | self._viewer_bot.set_model(self.model)
34 | self._set_cam_position(self._viewer_bot, self.cam_pos)
35 | if init_viewer > 1:
36 | print("viewer 2 being set up")
37 | self._viewer_bot2.set_model(self.model)
38 | self._set_cam_position(self._viewer_bot2, self.cam_pos2)
39 |
40 | self.metadata = {
41 | 'render.modes': ['human', 'rgb_array'],
42 | 'video.frames_per_second': int(np.round(1.0 / self.dt))
43 | }
44 |
45 | self.init_qpos = self.model.data.qpos.ravel().copy()
46 | self.init_qvel = self.model.data.qvel.ravel().copy()
47 | observation, _reward, done, _info = self._step(np.zeros(self.model.nu))
48 | assert not done
49 | self.obs_dim = observation.size
50 |
51 | bounds = self.model.actuator_ctrlrange.copy()
52 | low = bounds[:, 0]
53 | high = bounds[:, 1]
54 | self.action_space = spaces.Box(low, high)
55 |
56 | high = np.inf*np.ones(self.obs_dim)
57 | low = -high
58 | self.observation_space = spaces.Box(low, high)
59 |
60 | self._seed()
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv
4 |
5 |
6 | class Walker2DRandParamsEnv(RandomEnv, utils.EzPickle):
7 | def __init__(self, log_scale_limit=3.0):
8 | RandomEnv.__init__(self, log_scale_limit, 'walker2d.xml', 5)
9 | utils.EzPickle.__init__(self)
10 |
11 | def _step(self, a):
12 | # import ipdb; ipdb.set_trace()
13 | # posbefore = self.model.data.qpos[0, 0]
14 | posbefore = self.sim.data.qpos[0]
15 | self.do_simulation(a, self.frame_skip)
16 | # posafter, height, ang = self.model.data.qpos[0:3, 0]
17 | posafter, height, ang = self.sim.data.qpos[0:3]
18 | alive_bonus = 1.0
19 | reward = ((posafter - posbefore) / self.dt)
20 | reward += alive_bonus
21 | reward -= 1e-3 * np.square(a).sum()
22 | done = not (height > 0.8 and height < 2.0 and
23 | ang > -1.0 and ang < 1.0)
24 | ob = self._get_obs()
25 | return ob, reward, done, {}
26 |
27 | def _get_obs(self):
28 | # qpos = self.model.data.qpos
29 | # qvel = self.model.data.qvel
30 | qpos = self.sim.data.qpos
31 | qvel = self.sim.data.qvel
32 | return np.concatenate([qpos[1:], np.clip(qvel, -10, 10)]).ravel()
33 |
34 | def reset_model(self):
35 | self.set_state(
36 | self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq),
37 | self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
38 | )
39 | return self._get_obs()
40 |
41 | def viewer_setup(self):
42 | self.viewer.cam.trackbodyid = 2
43 | self.viewer.cam.distance = self.model.stat.extent * 0.5
44 | self.viewer.cam.lookat[2] += .8
45 | self.viewer.cam.elevation = -20
46 |
47 | if __name__ == "__main__":
48 |
49 | env = Walker2DRandParamsEnv()
50 | tasks = env.sample_tasks(40)
51 | while True:
52 | env.reset()
53 | env.set_task(np.random.choice(tasks))
54 | print(env.model.body_mass)
55 | for _ in range(100):
56 | env.render()
57 | env.step(env.action_space.sample()) # take a random action
58 |
59 |
--------------------------------------------------------------------------------
/rlkit/core/serializeable.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on rllab's serializable.py file
3 | https://github.com/rll/rllab
4 | """
5 |
6 | import inspect
7 | import sys
8 |
9 |
10 | class Serializable(object):
11 |
12 | def __init__(self, *args, **kwargs):
13 | self.__args = args
14 | self.__kwargs = kwargs
15 |
16 | def quick_init(self, locals_):
17 | if getattr(self, "_serializable_initialized", False):
18 | return
19 | if sys.version_info >= (3, 0):
20 | spec = inspect.getfullargspec(self.__init__)
21 | # Exclude the first "self" parameter
22 | if spec.varkw:
23 | kwargs = locals_[spec.varkw].copy()
24 | else:
25 | kwargs = dict()
26 | if spec.kwonlyargs:
27 | for key in spec.kwonlyargs:
28 | kwargs[key] = locals_[key]
29 | else:
30 | spec = inspect.getargspec(self.__init__)
31 | if spec.keywords:
32 | kwargs = locals_[spec.keywords]
33 | else:
34 | kwargs = dict()
35 | if spec.varargs:
36 | varargs = locals_[spec.varargs]
37 | else:
38 | varargs = tuple()
39 | in_order_args = [locals_[arg] for arg in spec.args][1:]
40 | self.__args = tuple(in_order_args) + varargs
41 | self.__kwargs = kwargs
42 | setattr(self, "_serializable_initialized", True)
43 |
44 | def __getstate__(self):
45 | return {"__args": self.__args, "__kwargs": self.__kwargs}
46 |
47 | def __setstate__(self, d):
48 | # convert all __args to keyword-based arguments
49 | if sys.version_info >= (3, 0):
50 | spec = inspect.getfullargspec(self.__init__)
51 | else:
52 | spec = inspect.getargspec(self.__init__)
53 | in_order_args = spec.args[1:]
54 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"]))
55 | self.__dict__.update(out.__dict__)
56 |
57 | @classmethod
58 | def clone(cls, obj, **kwargs):
59 | assert isinstance(obj, Serializable)
60 | d = obj.__getstate__()
61 | d["__kwargs"] = dict(d["__kwargs"], **kwargs)
62 | out = type(obj).__new__(type(obj))
63 | out.__setstate__(d)
64 | return out
65 |
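A minimal usage sketch (the Multiplier class is hypothetical): calling quick_init(self, locals()) at the top of __init__ records the constructor arguments, which is what makes __getstate__/__setstate__ and clone work.

```python
from rlkit.core.serializeable import Serializable


class Multiplier(Serializable):
    def __init__(self, factor, bias=0.0):
        Serializable.quick_init(self, locals())  # snapshot constructor args
        self.factor = factor
        self.bias = bias

    def __call__(self, x):
        return self.factor * x + self.bias


m = Multiplier(3.0)
m2 = Serializable.clone(m, bias=1.0)  # rebuilt from stored args, bias overridden
assert m2(2.0) == 7.0
```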
--------------------------------------------------------------------------------
/rlkit/utils/image_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import torch
4 | from PIL import Image
5 |
6 | from rlkit.util.augment_util import create_aug_stack
7 |
8 | class ImageAugment:
9 |
10 | def __init__(self,
11 | image_size=48,
12 | augment_order = ['RandomResizedCrop', 'ColorJitter'],
13 | augment_probability=0.95,
14 | augment_params={
15 | 'RandomResizedCrop': dict(
16 | scale=(0.9, 1.0),
17 | ratio=(0.9, 1.1),
18 | ),
19 | 'ColorJitter': dict(
20 | brightness=(0.75, 1.25),
21 | contrast=(0.9, 1.1),
22 | saturation=(0.9, 1.1),
23 | hue=(-0.1, 0.1),
24 | ),
25 | 'RandomCrop': dict(
26 | padding=4,
27 | padding_mode='edge'
28 | ),
29 | },
30 | ):
31 | self._image_size = image_size
32 |
33 | self.augment_stack = create_aug_stack(
34 | augment_order, augment_params, size=(self._image_size, self._image_size)
35 | )
36 | self.augment_probability = augment_probability
37 |
38 | def set_augment_params(self, img):
39 | if torch.rand(1) < self.augment_probability:
40 | self.augment_stack.set_params(img)
41 | else:
42 | self.augment_stack.set_default_params(img)
43 |
44 | def augment(self, img):
45 | img = self.augment_stack(img)
46 | return img
47 |
48 | def __call__(self, images, already_tranformed=True):
49 | if len(images.shape) == 4:
50 | batched = True
51 | elif len(images.shape) == 3:
52 | batched = False
53 | else:
54 | raise ValueError
55 |
56 | if already_tranformed:
57 | images += .5
58 |
59 | if self.augment_probability > 0:
60 | self.set_augment_params(images)
61 | images = self.augment(images)
62 |
63 | if already_tranformed:
64 | images -= 0.5
65 |
66 | if not batched:
67 | images = images[0]
68 |
69 | return images
70 |
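A minimal usage sketch, assuming torch image tensors in NCHW layout: since set_augment_params is called once per __call__, the same randomly drawn crop/jitter parameters appear to be applied to every image in the batch, and with the default already_tranformed=True the inputs are expected to be centered around zero (the code shifts them by +0.5 before augmenting).

```python
import torch

from rlkit.utils.image_util import ImageAugment

augmenter = ImageAugment(image_size=48, augment_probability=0.95)
batch = torch.rand(32, 3, 48, 48) - 0.5  # centered images, NCHW
augmented = augmenter(batch)             # same shape, randomly cropped/jittered
```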
--------------------------------------------------------------------------------
/rlkit/torch/core.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn as nn
6 |
7 | from rlkit.torch import pytorch_util as ptu
8 |
9 |
10 | class PyTorchModule(nn.Module, metaclass=abc.ABCMeta):
11 | """
12 | Keeping wrapper around to be a bit more future-proof.
13 | """
14 | pass
15 |
16 |
17 | def eval_np(module, *args, **kwargs):
18 | """
19 | Eval this module with a numpy interface
20 |
21 | Same as a call to __call__ except all Variable input/outputs are
22 | replaced with numpy equivalents.
23 |
24 | Assumes the output is either a single object or a tuple of objects.
25 | """
26 | torch_args = tuple(torch_ify(x) for x in args)
27 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
28 | outputs = module(*torch_args, **torch_kwargs)
29 | return elem_or_tuple_to_numpy(outputs)
30 |
31 |
32 | def torch_ify(np_array_or_other):
33 | if isinstance(np_array_or_other, np.ndarray):
34 | return ptu.from_numpy(np_array_or_other)
35 | else:
36 | return np_array_or_other
37 |
38 |
39 | def np_ify(tensor_or_other):
40 | if isinstance(tensor_or_other, torch.autograd.Variable):
41 | return ptu.get_numpy(tensor_or_other)
42 | else:
43 | return tensor_or_other
44 |
45 |
46 | def _elem_or_tuple_to_variable(elem_or_tuple):
47 | if isinstance(elem_or_tuple, tuple):
48 | return tuple(
49 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple
50 | )
51 | return ptu.from_numpy(elem_or_tuple).float()
52 |
53 |
54 | def elem_or_tuple_to_numpy(elem_or_tuple):
55 | if isinstance(elem_or_tuple, tuple):
56 | return tuple(np_ify(x) for x in elem_or_tuple)
57 | else:
58 | return np_ify(elem_or_tuple)
59 |
60 |
61 | def _filter_batch(np_batch):
62 | for k, v in np_batch.items():
63 | if v.dtype == np.bool_:
64 | yield k, v.astype(int)
65 | else:
66 | yield k, v
67 |
68 |
69 | def np_to_pytorch_batch(np_batch):
70 | if isinstance(np_batch, dict):
71 | return {
72 | k: _elem_or_tuple_to_variable(x)
73 | for k, x in _filter_batch(np_batch)
74 | if x.dtype != np.dtype('O') # ignore object (e.g. dictionaries)
75 | }
76 | else:
77 | return _elem_or_tuple_to_variable(np_batch)
78 |
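Two small usage sketches, not part of this file: eval_np is the numpy-in/numpy-out bridge used by the policies' get_action methods, and np_to_pytorch_batch is how sampled replay-buffer batches become float tensors.

```python
import numpy as np
from torch import nn

from rlkit.torch.core import eval_np, np_to_pytorch_batch

mlp = nn.Linear(4, 2)
out = eval_np(mlp, np.zeros((10, 4), dtype=np.float32))  # numpy in, numpy out

batch = np_to_pytorch_batch({
    'observations': np.zeros((10, 4)),
    'terminals': np.zeros((10, 1), dtype=bool),  # bools -> ints -> float tensors
})
```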
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/ou_strategy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.random as nr
3 |
4 | from rlkit.exploration_strategies.base import RawExplorationStrategy
5 |
6 |
7 | class OUStrategy(RawExplorationStrategy):
8 | """
9 | This strategy implements the Ornstein-Uhlenbeck process, which adds
10 | time-correlated noise to the actions taken by the deterministic policy.
11 | The OU process satisfies the following stochastic differential equation:
12 | dxt = theta*(mu - xt)*dt + sigma*dWt
13 | where Wt denotes the Wiener process
14 |
15 | Based on the rllab implementation.
16 | """
17 |
18 | def __init__(
19 | self,
20 | action_space,
21 | mu=0,
22 | theta=0.15,
23 | max_sigma=0.3,
24 | min_sigma=None,
25 | decay_period=100000,
26 | ):
27 | if min_sigma is None:
28 | min_sigma = max_sigma
29 | self.mu = mu
30 | self.theta = theta
31 | self.sigma = max_sigma
32 | self._max_sigma = max_sigma
33 | if min_sigma is None:
34 | min_sigma = max_sigma
35 | self._min_sigma = min_sigma
36 | self._decay_period = decay_period
37 | self.dim = np.prod(action_space.low.shape)
38 | self.low = action_space.low
39 | self.high = action_space.high
40 | self.state = np.ones(self.dim) * self.mu
41 | self.reset()
42 |
43 | def reset(self):
44 | self.state = np.ones(self.dim) * self.mu
45 |
46 | def evolve_state(self):
47 | x = self.state
48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x))
49 | self.state = x + dx
50 | return self.state
51 |
52 | def get_action_from_raw_action(self, action, t=0, **kwargs):
53 | ou_state = self.evolve_state()
54 | self.sigma = (
55 | self._max_sigma
56 | - (self._max_sigma - self._min_sigma)
57 | * min(1.0, t * 1.0 / self._decay_period)
58 | )
59 | return np.clip(action + ou_state, self.low, self.high)
60 |
61 | def get_actions_from_raw_actions(self, actions, t=0, **kwargs):
62 | noise = (
63 | self.state + self.theta * (self.mu - self.state)
64 | + self.sigma * nr.randn(*actions.shape)
65 | )
66 | return np.clip(actions + noise, self.low, self.high)
67 |
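A minimal usage sketch: the strategy keeps an internal OU state, adds it to the raw action, and anneals sigma from max_sigma to min_sigma over decay_period steps.

```python
import numpy as np
from gym.spaces import Box

from rlkit.exploration_strategies.ou_strategy import OUStrategy

es = OUStrategy(
    Box(low=-1.0, high=1.0, shape=(2,)),
    theta=0.15, max_sigma=0.3, min_sigma=0.05, decay_period=1000,
)
raw_action = np.zeros(2)
for t in range(5):
    noisy_action = es.get_action_from_raw_action(raw_action, t=t)
    print(t, es.sigma, noisy_action)  # time-correlated noise, sigma slowly decaying
```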
--------------------------------------------------------------------------------
/rlkit/envs/wrappers/normalized_box_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.spaces import Box
3 |
4 | from rlkit.envs.proxy_env import ProxyEnv
5 |
6 |
7 | class NormalizedBoxEnv(ProxyEnv):
8 | """
9 | Normalize action to in [-1, 1].
10 |
11 | Optionally normalize observations and scale reward.
12 | """
13 |
14 | def __init__(
15 | self,
16 | env,
17 | reward_scale=1.,
18 | obs_mean=None,
19 | obs_std=None,
20 | ):
21 | ProxyEnv.__init__(self, env)
22 | self._should_normalize = not (obs_mean is None and obs_std is None)
23 | if self._should_normalize:
24 | if obs_mean is None:
25 | obs_mean = np.zeros_like(env.observation_space.low)
26 | else:
27 | obs_mean = np.array(obs_mean)
28 | if obs_std is None:
29 | obs_std = np.ones_like(env.observation_space.low)
30 | else:
31 | obs_std = np.array(obs_std)
32 | self._reward_scale = reward_scale
33 | self._obs_mean = obs_mean
34 | self._obs_std = obs_std
35 | ub = np.ones(self._wrapped_env.action_space.shape)
36 | self.action_space = Box(-1 * ub, ub)
37 |
38 | def estimate_obs_stats(self, obs_batch, override_values=False):
39 | if self._obs_mean is not None and not override_values:
40 | raise Exception("Observation mean and std already set. To "
41 | "override, set override_values to True.")
42 | self._obs_mean = np.mean(obs_batch, axis=0)
43 | self._obs_std = np.std(obs_batch, axis=0)
44 |
45 | def _apply_normalize_obs(self, obs):
46 | return (obs - self._obs_mean) / (self._obs_std + 1e-8)
47 |
48 | def step(self, action):
49 | lb = self._wrapped_env.action_space.low
50 | ub = self._wrapped_env.action_space.high
51 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
52 | scaled_action = np.clip(scaled_action, lb, ub)
53 |
54 | wrapped_step = self._wrapped_env.step(scaled_action)
55 | next_obs, reward, done, info = wrapped_step
56 | if self._should_normalize:
57 | next_obs = self._apply_normalize_obs(next_obs)
58 | return next_obs, reward * self._reward_scale, done, info
59 |
60 | def __str__(self):
61 | return "Normalized: %s" % self._wrapped_env
62 |
63 |
64 |
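A minimal usage sketch, assuming an old-style gym env with a Box action space: the wrapper maps agent actions from [-1, 1] to the wrapped env's own bounds and optionally rescales the reward.

```python
import gym
import numpy as np

from rlkit.envs.wrappers.normalized_box_env import NormalizedBoxEnv

env = NormalizedBoxEnv(gym.make('Pendulum-v0'), reward_scale=5.0)  # old gym API
obs = env.reset()
next_obs, reward, done, info = env.step(np.array([1.0]))  # 1.0 -> max torque
```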
--------------------------------------------------------------------------------
/rlkit/exploration_strategies/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from rlkit.policies.base import ExplorationPolicy
4 |
5 |
6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta):
7 | @abc.abstractmethod
8 | def get_action(self, t, observation, policy, **kwargs):
9 | pass
10 |
11 | @abc.abstractmethod
12 | def get_actions(self, t, observation, policy, **kwargs):
13 | pass
14 |
15 | def reset(self):
16 | pass
17 |
18 |
19 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta):
20 | @abc.abstractmethod
21 | def get_action_from_raw_action(self, action, **kwargs):
22 | pass
23 |
24 | def get_actions_from_raw_actions(self, actions, **kwargs):
25 | raise NotImplementedError()
26 |
27 | def get_action(self, t, policy, *args, **kwargs):
28 | action, agent_info = policy.get_action(*args, **kwargs)
29 | return self.get_action_from_raw_action(action, t=t), agent_info
30 |
31 | def get_actions(self, t, observation, policy, **kwargs):
32 | actions = policy.get_actions(observation)
33 | return self.get_actions_from_raw_actions(actions, **kwargs)
34 |
35 | def reset(self):
36 | pass
37 |
38 |
39 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy):
40 | def __init__(
41 | self,
42 | exploration_strategy: ExplorationStrategy,
43 | policy,
44 | ):
45 | self.es = exploration_strategy
46 | self.policy = policy
47 | self.t = 0
48 |
49 | def set_num_steps_total(self, t):
50 | self.t = t
51 |
52 | def get_action(self, *args, **kwargs):
53 | return self.es.get_action(self.t, self.policy, *args, **kwargs)
54 |
55 | def get_actions(self, *args, **kwargs):
56 | return self.es.get_actions(self.t, self.policy, *args, **kwargs)
57 |
58 | def reset(self):
59 | self.es.reset()
60 | self.policy.reset()
61 |
62 | def get_param_values(self):
63 | return self.policy.get_param_values()
64 |
65 | def set_param_values(self, param_values):
66 | self.policy.set_param_values(param_values)
67 |
68 | def get_param_values_np(self):
69 | return self.policy.get_param_values_np()
70 |
71 | def set_param_values_np(self, param_values):
72 | self.policy.set_param_values_np(param_values)
73 |
74 | def to(self, device):
75 | self.policy.to(device)
76 |
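A minimal assembly sketch (ZeroPolicy is a hypothetical stand-in for a trained deterministic policy): the training loop sets the global step count, and every get_action call routes the policy's raw action through the exploration strategy.

```python
import numpy as np
from gym.spaces import Box

from rlkit.exploration_strategies.base import PolicyWrappedWithExplorationStrategy
from rlkit.exploration_strategies.ou_strategy import OUStrategy


class ZeroPolicy:
    def get_action(self, obs):
        return np.zeros(2), {}

    def reset(self):
        pass


expl_policy = PolicyWrappedWithExplorationStrategy(
    exploration_strategy=OUStrategy(Box(low=-1.0, high=1.0, shape=(2,))),
    policy=ZeroPolicy(),
)
expl_policy.set_num_steps_total(500)  # usually done by the training loop
action, agent_info = expl_policy.get_action(np.zeros(3))
```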
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/mujoco_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import mujoco_py
5 | import numpy as np
6 | from gym import spaces
7 | from gym.envs.mujoco import mujoco_env
8 | from rlkit.envs.env_utils import get_asset_full_path
9 |
10 |
11 | class MujocoEnv(mujoco_env.MujocoEnv):
12 | """
13 | My own wrapper around MujocoEnv.
14 |
15 | The caller needs to declare
16 | """
17 | def __init__(
18 | self,
19 | model_path,
20 | frame_skip=1,
21 | model_path_is_local=True,
22 | automatically_set_obs_and_action_space=False,
23 | ):
24 | if model_path_is_local:
25 | model_path = get_asset_full_path(model_path)
26 | if automatically_set_obs_and_action_space:
27 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip)
28 | else:
29 | """
30 | Code below is copy/pasted from MujocoEnv's __init__ function.
31 | """
32 | if model_path.startswith("/"):
33 | fullpath = model_path
34 | else:
35 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
36 | if not path.exists(fullpath):
37 | raise IOError("File %s does not exist" % fullpath)
38 | self.frame_skip = frame_skip
39 | self.model = mujoco_py.load_model_from_path(fullpath)
40 | self.sim = mujoco_py.MjSim(self.model)
41 | self.data = self.sim.data
42 | self.viewer = None
43 |
44 | self.metadata = {
45 | 'render.modes': ['human', 'rgb_array'],
46 | 'video.frames_per_second': int(np.round(1.0 / self.dt))
47 | }
48 |
49 | self.init_qpos = self.sim.data.qpos.ravel().copy()
50 | self.init_qvel = self.sim.data.qvel.ravel().copy()
51 | observation, _reward, done, _info = self.step(np.zeros(self.model.nu))
52 | assert not done
53 | self.obs_dim = observation.size
54 |
55 | bounds = self.model.actuator_ctrlrange.copy()
56 | low = bounds[:, 0]
57 | high = bounds[:, 1]
58 | self.action_space = spaces.Box(low=low, high=high)
59 |
60 | high = np.inf*np.ones(self.obs_dim)
61 | low = -high
62 | self.observation_space = spaces.Box(low, high)
63 |
64 | self.seed()
65 |
66 | def log_diagnostics(self, paths):
67 | pass
68 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/__init__.py:
--------------------------------------------------------------------------------
1 | from rlkit.envs.pearl_envs.ant_normal import AntNormal
2 | from rlkit.envs.pearl_envs.ant_dir import AntDirEnv
3 | from rlkit.envs.pearl_envs.ant_goal import AntGoalEnv
4 | from rlkit.envs.pearl_envs.half_cheetah_dir import HalfCheetahDirEnv
5 | from rlkit.envs.pearl_envs.half_cheetah_vel import HalfCheetahVelEnv
6 | from rlkit.envs.pearl_envs.hopper_rand_params_wrapper import \
7 | HopperRandParamsWrappedEnv
8 | from rlkit.envs.pearl_envs.humanoid_dir import HumanoidDirEnv
9 | from rlkit.envs.pearl_envs.point_robot import PointEnv, SparsePointEnv
10 | from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import \
11 | Walker2DRandParamsEnv
12 | from rlkit.envs.pearl_envs.walker_rand_params_wrapper import \
13 | WalkerRandParamsWrappedEnv
14 |
15 | ENVS = {}
16 |
17 |
18 | def register_env(name):
19 | """Registers a env by name for instantiation in rlkit."""
20 |
21 | def register_env_fn(fn):
22 | if name in ENVS:
23 | raise ValueError("Cannot register duplicate env {}".format(name))
24 | if not callable(fn):
25 | raise TypeError("env {} must be callable".format(name))
26 | ENVS[name] = fn
27 | return fn
28 |
29 | return register_env_fn
30 |
31 | def _register_env(name, fn):
32 | """Registers a env by name for instantiation in rlkit."""
33 | if name in ENVS:
34 | raise ValueError("Cannot register duplicate env {}".format(name))
35 | if not callable(fn):
36 | raise TypeError("env {} must be callable".format(name))
37 | ENVS[name] = fn
38 |
39 |
40 | def register_pearl_envs():
41 | _register_env('sparse-point-robot', SparsePointEnv)
42 | _register_env('ant-normal', AntNormal)
43 | _register_env('ant-dir', AntDirEnv)
44 | _register_env('ant-goal', AntGoalEnv)
45 | _register_env('cheetah-dir', HalfCheetahDirEnv)
46 | _register_env('cheetah-vel', HalfCheetahVelEnv)
47 | _register_env('humanoid-dir', HumanoidDirEnv)
48 | _register_env('point-robot', PointEnv)
49 | _register_env('walker-rand-params', WalkerRandParamsWrappedEnv)
50 | _register_env('hopper-rand-params', HopperRandParamsWrappedEnv)
51 |
52 | # automatically import any envs in the envs/ directory
53 | # for file in os.listdir(os.path.dirname(__file__)):
54 | # if file.endswith('.py') and not file.startswith('_'):
55 | # module = file[:file.find('.py')]
56 | # importlib.import_module('rlkit.envs.pearl_envs.' + module)
57 |
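A minimal usage sketch (MyTaskEnv is hypothetical): both the decorator and the private helper add a constructor to ENVS so environments can later be instantiated by name.

```python
from rlkit.envs.pearl_envs import ENVS, register_env, register_pearl_envs


@register_env('my-task')
class MyTaskEnv:
    def __init__(self, n_tasks=2):
        self.n_tasks = n_tasks


register_pearl_envs()              # registers the built-in PEARL envs (once)
env = ENVS['my-task'](n_tasks=8)   # look up and instantiate by name
```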
--------------------------------------------------------------------------------
/rlkit/envs/make_env.py:
--------------------------------------------------------------------------------
1 | """
2 | This file provides a more uniform interface to gym.make(env_id) that handles
3 | imports and normalization
4 | """
5 |
6 | import gym
7 |
8 | from rlkit.envs.wrappers.normalized_box_env import NormalizedBoxEnv
9 |
10 | DAPG_ENVS = [
11 | 'pen-v0', 'pen-sparse-v0', 'pen-notermination-v0', 'pen-binary-v0', 'pen-binary-old-v0',
12 | 'door-v0', 'door-sparse-v0', 'door-binary-v0', 'door-binary-old-v0',
13 | 'relocate-v0', 'relocate-sparse-v0', 'relocate-binary-v0', 'relocate-binary-old-v0',
14 | 'hammer-v0', 'hammer-sparse-v0', 'hammer-binary-v0',
15 | ]
16 |
17 | D4RL_ENVS = [
18 | "maze2d-open-v0", "maze2d-umaze-v0", "maze2d-medium-v0", "maze2d-large-v0",
19 | "maze2d-open-dense-v0", "maze2d-umaze-dense-v0", "maze2d-medium-dense-v0", "maze2d-large-dense-v0",
20 | "antmaze-umaze-v0", "antmaze-umaze-diverse-v0", "antmaze-medium-diverse-v0",
21 | "antmaze-medium-play-v0", "antmaze-large-diverse-v0", "antmaze-large-play-v0",
22 | "pen-human-v0", "pen-cloned-v0", "pen-expert-v0", "hammer-human-v0", "hammer-cloned-v0", "hammer-expert-v0",
23 | "door-human-v0", "door-cloned-v0", "door-expert-v0", "relocate-human-v0", "relocate-cloned-v0", "relocate-expert-v0",
24 | "halfcheetah-random-v0", "halfcheetah-medium-v0", "halfcheetah-expert-v0", "halfcheetah-mixed-v0", "halfcheetah-medium-expert-v0",
25 | "walker2d-random-v0", "walker2d-medium-v0", "walker2d-expert-v0", "walker2d-mixed-v0", "walker2d-medium-expert-v0",
26 | "hopper-random-v0", "hopper-medium-v0", "hopper-expert-v0", "hopper-mixed-v0", "hopper-medium-expert-v0",
27 | # 'halfcheetah-expert-v2',
28 | 'halfcheetah-medium-v2',
29 | 'halfcheetah-medium-replay-v2',
30 | 'halfcheetah-medium-expert-v2',
31 | # 'hopper-expert-v2',
32 | 'hopper-medium-v2',
33 | 'hopper-medium-replay-v2',
34 | 'hopper-medium-expert-v2',
35 | # 'walker2d-expert-v2',
36 | 'walker2d-medium-v2',
37 | 'walker2d-medium-replay-v2',
38 | 'walker2d-medium-expert-v2',
39 | ]
40 |
41 | def make(env_id=None, env_class=None, env_kwargs=None, normalize_env=True):
42 | assert env_id or env_class
43 |
44 | if env_class:
45 | env = env_class(**env_kwargs)
46 | elif env_id in DAPG_ENVS:
47 | import mj_envs
48 | assert normalize_env == False
49 | env = gym.make(env_id)
50 | elif env_id in D4RL_ENVS:
51 | import d4rl
52 | assert normalize_env == False
53 | env = gym.make(env_id)
54 | elif env_id:
55 | env = gym.make(env_id)
56 |
57 | if normalize_env:
58 | env = NormalizedBoxEnv(env)
59 |
60 | return env
61 |
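A minimal usage sketch: make() either builds env_class(**env_kwargs) directly or dispatches on env_id, importing mj_envs/d4rl lazily; DAPG and D4RL envs must be created with normalization disabled.

```python
from rlkit.envs.make_env import make

env = make(env_id='HalfCheetah-v2')  # plain gym env, wrapped in NormalizedBoxEnv

offline_env = make(env_id='halfcheetah-medium-v2', normalize_env=False)  # D4RL
```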
--------------------------------------------------------------------------------
/rlkit/learning/online_offline_split_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 | import rlkit.data_management.images as image_np
6 | from rlkit import pythonplusplus as ppp
7 |
8 | # import time
9 |
10 | from rlkit.data_management.online_offline_split_replay_buffer import (
11 | OnlineOfflineSplitReplayBuffer,
12 | )
13 |
14 |
15 | def concat(*x):
16 | return np.concatenate(x, axis=0)
17 |
18 |
19 | class OnlineOfflineSplitReplayBuffer(OnlineOfflineSplitReplayBuffer):
20 | def __init__(
21 | self,
22 | offline_replay_buffer,
23 | online_replay_buffer,
24 | sample_online_fraction=None,
25 | online_mode=False,
26 | **kwargs
27 | ):
28 | super().__init__(offline_replay_buffer=offline_replay_buffer,
29 | online_replay_buffer=online_replay_buffer,
30 | sample_online_fraction=sample_online_fraction,
31 | online_mode=online_mode,
32 | **kwargs)
33 |
34 | def random_batch(self, batch_size):
35 | online_batch_size = min(int(self.sample_online_fraction * batch_size),
36 | self.online_replay_buffer.num_steps_can_sample())
37 | offline_batch_size = batch_size - online_batch_size
38 | online_batch = self.online_replay_buffer.random_batch(online_batch_size)
39 | offline_batch = self.offline_replay_buffer.random_batch(offline_batch_size)
40 |
41 | # torch.cuda.synchronize()
42 | # start_time = time.time()
43 | # batch = dict()
44 | # for (key, online_batch_value) in online_batch.items():
45 | # assert key in offline_batch
46 | # offline_batch_value = offline_batch[key]
47 | # if key == 'indices':
48 | # batch['online_indices'] = online_batch_value
49 | # batch['offline_indicies'] = offline_batch_value
50 | # else:
51 | # batch[key] = np.concatenate((online_batch_value, offline_batch_value), axis=0)
52 | #
53 | # if batch[key].dtype == np.uint8:
54 | # batch[key] = image_np.normalize_image(batch[key], dtype=np.float32)
55 | batch = ppp.treemap(
56 | concat,
57 | online_batch,
58 | offline_batch,
59 | atomic_type=np.ndarray)
60 |
61 | # torch.cuda.synchronize()
62 | # end_time = time.time()
63 | # print("Time to concatenate offline and online data: {} secs".format(end_time - start_time))
64 |
65 | return batch
66 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/ant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from multiworld.envs.mujoco.cameras import create_camera_init
4 | from .mujoco_env import MujocoEnv
5 |
6 |
7 | class AntEnv(MujocoEnv):
8 | def __init__(self, use_low_gear_ratio=False):
9 | # self.init_serialization(locals())
10 | if use_low_gear_ratio:
11 | xml_path = 'low_gear_ratio_ant.xml'
12 | else:
13 | xml_path = 'ant.xml'
14 | super().__init__(
15 | xml_path,
16 | frame_skip=5,
17 | automatically_set_obs_and_action_space=True,
18 | )
19 |
20 | def step(self, a):
21 | torso_xyz_before = self.get_body_com("torso")
22 | self.do_simulation(a, self.frame_skip)
23 | torso_xyz_after = self.get_body_com("torso")
24 | torso_velocity = torso_xyz_after - torso_xyz_before
25 | forward_reward = torso_velocity[0]/self.dt
26 | ctrl_cost = 0. # .5 * np.square(a).sum()
27 | contact_cost = 0.5 * 1e-3 * np.sum(
28 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
29 | survive_reward = 0. # 1.0
30 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
31 | state = self.state_vector()
32 | notdone = np.isfinite(state).all() \
33 | and state[2] >= 0.2 and state[2] <= 1.0
34 | done = not notdone
35 | ob = self._get_obs()
36 | return ob, reward, done, dict(
37 | reward_forward=forward_reward,
38 | reward_ctrl=-ctrl_cost,
39 | reward_contact=-contact_cost,
40 | reward_survive=survive_reward,
41 | torso_velocity=torso_velocity,
42 | )
43 |
44 | def _get_obs(self):
45 | # this is gym ant obs, should use rllab?
46 | # if position is needed, override this in subclasses
47 | return np.concatenate([
48 | self.sim.data.qpos.flat[2:],
49 | self.sim.data.qvel.flat,
50 | ])
51 |
52 | def reset_model(self):
53 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
54 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
55 | self.set_state(qpos, qvel)
56 | return self._get_obs()
57 |
58 | def viewer_setup(self):
59 | self.camera_init = create_camera_init(
60 | lookat=(0, 0, 0),
61 | distance=10,
62 | elevation=-45,
63 | trackbodyid=self.sim.model.body_name2id('torso'),
64 | )
65 | self.camera_init(self.viewer.cam)
66 | # self.viewer.cam.distance = self.model.stat.extent * 0.5
67 | #
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/humanoid_dir.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym.envs.mujoco import HumanoidEnv as HumanoidEnv
3 |
4 |
5 | def mass_center(model, sim):
6 | mass = np.expand_dims(model.body_mass, 1)
7 | xpos = sim.data.xipos
8 | return (np.sum(mass * xpos, 0) / np.sum(mass))
9 |
10 |
11 | class HumanoidDirEnv(HumanoidEnv):
12 |
13 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True):
14 | self.tasks = self.sample_tasks(n_tasks)
15 | self.reset_task(0)
16 | super(HumanoidDirEnv, self).__init__()
17 |
18 | def step(self, action):
19 | pos_before = np.copy(mass_center(self.model, self.sim)[:2])
20 | self.do_simulation(action, self.frame_skip)
21 | pos_after = mass_center(self.model, self.sim)[:2]
22 |
23 | alive_bonus = 5.0
24 | data = self.sim.data
25 | goal_direction = (np.cos(self._goal), np.sin(self._goal))
26 | lin_vel_cost = 0.25 * np.sum(goal_direction * (pos_after - pos_before)) / self.model.opt.timestep
27 | quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
28 | quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum()
29 | quad_impact_cost = min(quad_impact_cost, 10)
30 | reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus
31 | qpos = self.sim.data.qpos
32 | done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
33 |
34 | return self._get_obs(), reward, done, dict(reward_linvel=lin_vel_cost,
35 | reward_quadctrl=-quad_ctrl_cost,
36 | reward_alive=alive_bonus,
37 | reward_impact=-quad_impact_cost)
38 |
39 | def _get_obs(self):
40 | data = self.sim.data
41 | return np.concatenate([data.qpos.flat[2:],
42 | data.qvel.flat,
43 | data.cinert.flat,
44 | data.cvel.flat,
45 | data.qfrc_actuator.flat,
46 | data.cfrc_ext.flat])
47 |
48 | def get_all_task_idx(self):
49 | return range(len(self.tasks))
50 |
51 | def reset_task(self, idx):
52 | self._task = self.tasks[idx]
53 | self._goal = self._task['goal'] # assume parameterization of task by single vector
54 |
55 | def sample_tasks(self, num_tasks):
56 | # velocities = np.random.uniform(0., 1.0 * np.pi, size=(num_tasks,))
57 | directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,))
58 | tasks = [{'goal': d} for d in directions]
59 | return tasks
60 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/half_cheetah_vel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .half_cheetah import HalfCheetahEnv
4 |
5 |
6 | class HalfCheetahVelEnv(HalfCheetahEnv):
7 | """Half-cheetah environment with target velocity, as described in [1]. The
8 | code is adapted from
9 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand.py
10 |
11 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each
12 | time step a reward composed of a control cost and a penalty equal to the
13 | difference between its current velocity and the target velocity. The tasks
14 | are generated by sampling the target velocities from the uniform
15 | distribution on [0, 2].
16 |
17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic
18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017
19 | (https://arxiv.org/abs/1703.03400)
20 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for
21 | model-based control", 2012
22 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf)
23 | """
24 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True):
25 | self._task = task
26 | self.tasks = self.sample_tasks(n_tasks)
27 | self._goal_vel = self.tasks[0].get('velocity', 0.0)
28 | self._goal = self._goal_vel
29 | super(HalfCheetahVelEnv, self).__init__()
30 |
31 | def step(self, action):
32 | xposbefore = self.sim.data.qpos[0]
33 | self.do_simulation(action, self.frame_skip)
34 | xposafter = self.sim.data.qpos[0]
35 |
36 | forward_vel = (xposafter - xposbefore) / self.dt
37 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel)
38 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action))
39 |
40 | observation = self._get_obs()
41 | reward = forward_reward - ctrl_cost
42 | done = False
43 | infos = dict(reward_forward=forward_reward,
44 | reward_ctrl=-ctrl_cost, task=self._task)
45 | return (observation, reward, done, infos)
46 |
47 | def sample_tasks(self, num_tasks):
48 | np.random.seed(1337)
49 | velocities = np.random.uniform(0.0, 3.0, size=(num_tasks,))
50 | tasks = [{'velocity': velocity} for velocity in velocities]
51 | return tasks
52 |
53 | def get_all_task_idx(self):
54 | return range(len(self.tasks))
55 |
56 | def reset_task(self, idx):
57 | self._task = self.tasks[idx]
58 | self._goal_vel = self._task['velocity']
59 | self._goal = self._goal_vel
60 | self.reset()
61 |
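A minimal usage sketch: a meta-RL loop samples a task index, calls reset_task to fix the goal velocity, and then interacts with the env as usual; the reward is -|forward_vel - goal_vel| minus a control cost.

```python
from rlkit.envs.pearl_envs.half_cheetah_vel import HalfCheetahVelEnv

env = HalfCheetahVelEnv(n_tasks=10)
for idx in env.get_all_task_idx():
    env.reset_task(idx)  # sets self._goal_vel from the sampled task and resets
    obs, reward, done, info = env.step(env.action_space.sample())
```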
--------------------------------------------------------------------------------
/rlkit/util/inspect_q_util.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import rlkit.visualization.visualization_util as vu
3 | import numpy as np
4 | from rlkit.torch.core import np_ify, torch_ify
5 |
6 | fig_v_mean = None
7 | fig_v_std = None
8 | axes_v_mean = None
9 | axes_v_std = None
10 |
11 | def debug_q(ensemble_qs, policy, show_mean=True, show_std=True):
12 | global fig_v_mean, fig_v_std, axes_v_mean, axes_v_std
13 |
14 | if fig_v_mean is None:
15 | if show_mean:
16 | fig_v_mean, axes_v_mean = plt.subplots(3, 3, sharex='all', sharey='all', figsize=(9, 9))
17 | fig_v_mean.canvas.set_window_title('Q Mean')
18 | # plt.suptitle("V Mean")
19 | if show_std:
20 | fig_v_std, axes_v_std = plt.subplots(3, 3, sharex='all', sharey='all', figsize=(9, 9))
21 | fig_v_std.canvas.set_window_title('Q Std')
22 | # plt.suptitle("V Std")
23 |
24 | # obss = [(0, 0), (0, 0.75), (0, 1.75), (0, 1.25), (0, 4), (-4, 4), (4, -4), (4, 4), (-4, 4)]
25 | obss = []
26 | for x in [-3, 0, 3]:
27 | for y in [3, 0, -3]:
28 | obss.append((x, y))
29 |
30 | def create_eval_function(q, obs, ):
31 | def beta_eval(goals):
32 | # goals = np.array([[
33 | # *goal
34 | # ]])
35 | N = len(goals)
36 | observations = np.tile(obs, (N, 1))
37 | new_obs = np.hstack((observations, goals))
38 | actions = torch_ify(policy.get_action(new_obs)[0])
39 | return np_ify(q(torch_ify(new_obs), actions)).flatten()
40 | return beta_eval
41 |
42 | for o in range(9):
43 | i = o % 3
44 | j = o // 3
45 | H = []
46 | for b in range(5):
47 | q = ensemble_qs[b]
48 |
49 | rng = [-4, 4]
50 | resolution = 30
51 |
52 | obs = obss[o]
53 |
54 | heatmap = vu.make_heat_map(create_eval_function(q, obs, ), rng, rng, resolution=resolution, batch=True)
55 | H.append(heatmap.values)
56 | p, x, y, _ = heatmap
57 | if show_mean:
58 | h1 = vu.HeatMap(np.mean(H, axis=0), x, y, _)
59 | vu.plot_heatmap(h1, ax=axes_v_mean[i, j])
60 | axes_v_mean[i, j].set_title("pos " + str(obs))
61 |
62 | if show_std:
63 | h2 = vu.HeatMap(np.std(H, axis=0), x, y, _)
64 | vu.plot_heatmap(h2, ax=axes_v_std[i, j])
65 | axes_v_std[i, j].set_title("pos " + str(obs))
66 |
67 | # axes_v_mean[i, j].plot(range(100), range(100))
68 | # axes_v_std[i, j].plot(range(100), range(100))
69 |
70 | plt.draw()
71 | plt.pause(0.01)
72 |
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/half_cheetah_dir.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .half_cheetah import HalfCheetahEnv
4 |
5 |
6 | class HalfCheetahDirEnv(HalfCheetahEnv):
7 | """Half-cheetah environment with target direction, as described in [1]. The
8 | code is adapted from
9 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand_direc.py
10 |
11 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each
12 | time step a reward composed of a control cost and a reward equal to its
13 | velocity in the target direction. The tasks are generated by sampling the
14 | target directions from a Bernoulli distribution on {-1, 1} with parameter
15 | 0.5 (-1: backward, +1: forward).
16 |
17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic
18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017
19 | (https://arxiv.org/abs/1703.03400)
20 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for
21 | model-based control", 2012
22 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf)
23 | """
24 | def __init__(self, task={}, n_tasks=2, randomize_tasks=False):
25 | directions = [-1, 1]
26 | self.tasks = [{'direction': direction} for direction in directions]
27 | self._task = task
28 | self._goal_dir = task.get('direction', 1)
29 | self._goal = self._goal_dir
30 | super(HalfCheetahDirEnv, self).__init__()
31 |
32 | def step(self, action):
33 | xposbefore = self.sim.data.qpos[0]
34 | self.do_simulation(action, self.frame_skip)
35 | xposafter = self.sim.data.qpos[0]
36 |
37 | forward_vel = (xposafter - xposbefore) / self.dt
38 | forward_reward = self._goal_dir * forward_vel
39 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action))
40 |
41 | observation = self._get_obs()
42 | reward = forward_reward - ctrl_cost
43 | done = False
44 | infos = dict(reward_forward=forward_reward,
45 | reward_ctrl=-ctrl_cost, task=self._task)
46 | return (observation, reward, done, infos)
47 |
48 | def sample_tasks(self, num_tasks):
49 | directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1
50 | tasks = [{'direction': direction} for direction in directions]
51 | return tasks
52 |
53 | def get_all_task_idx(self):
54 | return list(range(len(self.tasks)))
55 |
56 | def reset_task(self, idx):
57 | self._task = self.tasks[idx]
58 | self._goal_dir = self._task['direction']
59 | self._goal = self._goal_dir
60 | self.reset()
61 |
--------------------------------------------------------------------------------
/rlkit/torch/networks/two_headed_mlp.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 | from rlkit.pythonplusplus import identity
5 | from rlkit.torch import pytorch_util as ptu
6 | from rlkit.torch.core import PyTorchModule
7 | from rlkit.torch.networks.experimental import LayerNorm
8 |
9 |
10 | class TwoHeadMlp(PyTorchModule):
11 | def __init__(
12 | self,
13 | hidden_sizes,
14 | first_head_size,
15 | second_head_size,
16 | input_size,
17 | init_w=3e-3,
18 | hidden_activation=F.relu,
19 | output_activation=identity,
20 | hidden_init=ptu.fanin_init,
21 | b_init_value=0.,
22 | layer_norm=False,
23 | layer_norm_kwargs=None,
24 | ):
25 | super().__init__()
26 |
27 | if layer_norm_kwargs is None:
28 | layer_norm_kwargs = dict()
29 |
30 | self.input_size = input_size
31 | self.first_head_size = first_head_size
32 | self.second_head_size = second_head_size
33 | self.hidden_activation = hidden_activation
34 | self.output_activation = output_activation
35 | self.layer_norm = layer_norm
36 | self.fcs = []
37 | self.layer_norms = []
38 | in_size = input_size
39 |
40 | for i, next_size in enumerate(hidden_sizes):
41 | fc = nn.Linear(in_size, next_size)
42 | in_size = next_size
43 | hidden_init(fc.weight)
44 | fc.bias.data.fill_(b_init_value)
45 | self.__setattr__("fc{}".format(i), fc)
46 | self.fcs.append(fc)
47 |
48 | if self.layer_norm:
49 | ln = LayerNorm(next_size)
50 | self.__setattr__("layer_norm{}".format(i), ln)
51 | self.layer_norms.append(ln)
52 |
53 | self.first_head = nn.Linear(in_size, self.first_head_size)
54 | self.first_head.weight.data.uniform_(-init_w, init_w)
55 |
56 | self.second_head = nn.Linear(in_size, self.second_head_size)
57 | self.second_head.weight.data.uniform_(-init_w, init_w)
58 |
59 | def forward(self, input, return_preactivations=False):
60 | h = input
61 | for i, fc in enumerate(self.fcs):
62 | h = fc(h)
63 | if self.layer_norm and i < len(self.fcs) - 1:
64 | h = self.layer_norms[i](h)
65 | h = self.hidden_activation(h)
66 | preactivation = self.first_head(h)
67 | first_output = self.output_activation(preactivation)
68 | preactivation = self.second_head(h)
69 | second_output = self.output_activation(preactivation)
70 |
71 | return first_output, second_output
72 |
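A minimal usage sketch: a shared trunk with two linear heads, as typically used for Gaussian policies that output a mean and a log-std per action dimension.

```python
import torch

from rlkit.torch.networks.two_headed_mlp import TwoHeadMlp

obs_dim, action_dim = 17, 6
net = TwoHeadMlp(
    hidden_sizes=[256, 256],
    first_head_size=action_dim,   # e.g. action mean
    second_head_size=action_dim,  # e.g. action log-std
    input_size=obs_dim,
)
mean, log_std = net(torch.zeros(32, obs_dim))  # two (32, 6) outputs
```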
--------------------------------------------------------------------------------
/rlkit/data_management/split_buffer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from rlkit.data_management.replay_buffer import ReplayBuffer
4 |
5 |
6 | class SplitReplayBuffer(ReplayBuffer):
7 | """
8 | Split the data into a training and validation set.
9 | """
10 | def __init__(
11 | self,
12 | train_replay_buffer: ReplayBuffer,
13 | validation_replay_buffer: ReplayBuffer,
14 | fraction_paths_in_train,
15 | ):
16 | self.train_replay_buffer = train_replay_buffer
17 | self.validation_replay_buffer = validation_replay_buffer
18 | self.fraction_paths_in_train = fraction_paths_in_train
19 | self.replay_buffer = self.train_replay_buffer
20 |
21 | def add_sample(self, *args, **kwargs):
22 | self.replay_buffer.add_sample(*args, **kwargs)
23 |
24 | def add_path(self, path):
25 | self.replay_buffer.add_path(path)
26 | self._randomly_set_replay_buffer()
27 |
28 | def num_steps_can_sample(self):
29 | return min(
30 | self.train_replay_buffer.num_steps_can_sample(),
31 | self.validation_replay_buffer.num_steps_can_sample(),
32 | )
33 |
34 | def terminate_episode(self, *args, **kwargs):
35 | self.replay_buffer.terminate_episode(*args, **kwargs)
36 | self._randomly_set_replay_buffer()
37 |
38 | def _randomly_set_replay_buffer(self):
39 | if random.random() <= self.fraction_paths_in_train:
40 | self.replay_buffer = self.train_replay_buffer
41 | else:
42 | self.replay_buffer = self.validation_replay_buffer
43 |
44 | def get_replay_buffer(self, training=True):
45 | if training:
46 | return self.train_replay_buffer
47 | else:
48 | return self.validation_replay_buffer
49 |
50 | def random_batch(self, batch_size):
51 | return self.train_replay_buffer.random_batch(batch_size)
52 |
53 | def __getattr__(self, attrname):
54 | return getattr(self.replay_buffer, attrname)
55 |
56 | def __getstate__(self):
57 | # Do not save self.replay_buffer since it's a duplicate and seems to
58 | # cause joblib recursion issues.
59 | return dict(
60 | train_replay_buffer=self.train_replay_buffer,
61 | validation_replay_buffer=self.validation_replay_buffer,
62 | fraction_paths_in_train=self.fraction_paths_in_train,
63 | )
64 |
65 | def __setstate__(self, d):
66 | self.train_replay_buffer = d['train_replay_buffer']
67 | self.validation_replay_buffer = d['validation_replay_buffer']
68 | self.fraction_paths_in_train = d['fraction_paths_in_train']
69 | self.replay_buffer = self.train_replay_buffer
70 |
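Usage sketch (not part of the source file): SplitReplayBuffer only routes whole paths between two underlying buffers, so any concrete ReplayBuffer works. The list-backed buffer below is a toy illustration, not an rlkit class.

    import random
    from rlkit.data_management.replay_buffer import ReplayBuffer
    from rlkit.data_management.split_buffer import SplitReplayBuffer

    class _ListBuffer(ReplayBuffer):
        """Toy buffer that just stores whole paths."""
        def __init__(self):
            self.paths = []
        def add_sample(self, *args, **kwargs):
            pass
        def terminate_episode(self):
            pass
        def num_steps_can_sample(self, **kwargs):
            return len(self.paths)
        def add_path(self, path):
            self.paths.append(path)
        def random_batch(self, batch_size):
            return random.sample(self.paths, min(batch_size, len(self.paths)))

    buffer = SplitReplayBuffer(
        _ListBuffer(), _ListBuffer(), fraction_paths_in_train=0.9)
    for _ in range(1000):
        buffer.add_path({'observations': []})
    # Roughly 90% of the paths end up in the training buffer.
    print(buffer.get_replay_buffer(training=True).num_steps_can_sample())
    print(buffer.get_replay_buffer(training=False).num_steps_can_sample())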
--------------------------------------------------------------------------------
/rlkit/visualization/plotter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from rlkit.torch.core import eval_np
5 |
6 |
7 | class QFPolicyPlotter(object):
8 | def __init__(self, qf, policy, obs_lst, default_action, n_samples):
9 | self._qf = qf
10 | self._policy = policy
11 | self._obs_lst = obs_lst
12 | self._default_action = default_action
13 | self._n_samples = n_samples
14 |
15 | self._var_inds = np.where(np.isnan(default_action))[0]
16 | assert len(self._var_inds) == 2
17 |
18 | n_plots = len(obs_lst)
19 |
20 | x_size = 5 * n_plots
21 | y_size = 5
22 |
23 | fig = plt.figure(figsize=(x_size, y_size))
24 | self._ax_lst = []
25 | for i in range(n_plots):
26 | ax = fig.add_subplot(100 + n_plots * 10 + i + 1)
27 | ax.set_xlim((-1, 1))
28 | ax.set_ylim((-1, 1))
29 | ax.grid(True)
30 | self._ax_lst.append(ax)
31 |
32 | self._line_objects = list()
33 |
34 | def draw(self):
35 | # noinspection PyArgumentList
36 | [h.remove() for h in self._line_objects]
37 | self._line_objects = list()
38 |
39 | self._plot_level_curves()
40 | self._plot_action_samples()
41 |
42 | plt.draw()
43 | plt.pause(0.001)
44 |
45 | def _plot_level_curves(self):
46 | # Create mesh grid.
47 | xs = np.linspace(-1, 1, 50)
48 | ys = np.linspace(-1, 1, 50)
49 | xgrid, ygrid = np.meshgrid(xs, ys)
50 | N = len(xs)*len(ys)
51 |
52 | # Copy default values along the first axis and replace nans with
53 | # the mesh grid points.
54 | actions = np.tile(self._default_action, (N, 1))
55 | actions[:, self._var_inds[0]] = xgrid.ravel()
56 | actions[:, self._var_inds[1]] = ygrid.ravel()
57 |
58 | for ax, obs in zip(self._ax_lst, self._obs_lst):
59 | repeated_obs = np.repeat(
60 | obs[None],
61 | actions.shape[0],
62 | axis=0,
63 | )
64 | qs = eval_np(self._qf, repeated_obs, actions)
65 | qs = qs.reshape(xgrid.shape)
66 |
67 | cs = ax.contour(xgrid, ygrid, qs, 20)
68 | self._line_objects += cs.collections
69 | self._line_objects += ax.clabel(
70 | cs, inline=1, fontsize=10, fmt='%.2f')
71 |
72 | def _plot_action_samples(self):
73 | for ax, obs in zip(self._ax_lst, self._obs_lst):
74 | actions = self._policy.get_actions(
75 | np.ones((self._n_samples, 1)) * obs[None, :])
76 |
77 | x, y = actions[:, 0], actions[:, 1]
78 | self._line_objects += ax.plot(x, y, 'b*')
79 |
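Usage sketch (not part of the source file): the NaN entries in default_action mark the two action dimensions the plotter sweeps over. The Q-function and policy below are toy stand-ins, not rlkit classes, and drawing requires a matplotlib backend.

    import numpy as np
    import torch.nn as nn
    from rlkit.visualization.plotter import QFPolicyPlotter

    class _ToyQf(nn.Module):
        """Stand-in Q-function: Q(s, a) = -||a||^2."""
        def forward(self, obs, actions):
            return -(actions ** 2).sum(dim=1, keepdim=True)

    class _ToyPolicy:
        """Stand-in policy exposing the get_actions interface used above."""
        def get_actions(self, obs):
            return np.random.uniform(-1, 1, size=(len(obs), 2))

    plotter = QFPolicyPlotter(
        qf=_ToyQf(),
        policy=_ToyPolicy(),
        obs_lst=[np.zeros(3)],                       # one subplot per observation
        default_action=np.array([np.nan, np.nan]),   # NaNs = dimensions to sweep
        n_samples=100,
    )
    plotter.draw()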
--------------------------------------------------------------------------------
/rlkit/torch/data_management/normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import rlkit.torch.pytorch_util as ptu
3 | import numpy as np
4 |
5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer
6 |
7 |
8 | class TorchNormalizer(Normalizer):
9 | """
10 | Update with np array, but de/normalize pytorch Tensors.
11 | """
12 | def normalize(self, v, clip_range=None):
13 | if not self.synchronized:
14 | self.synchronize()
15 | if clip_range is None:
16 | clip_range = self.default_clip_range
17 | mean = ptu.np_to_var(self.mean, requires_grad=False)
18 | std = ptu.np_to_var(self.std, requires_grad=False)
19 | if v.dim() == 2:
20 |             # Unsqueeze along the batch dimension and let broadcasting do the rest
21 | mean = mean.unsqueeze(0)
22 | std = std.unsqueeze(0)
23 | return torch.clamp((v - mean) / std, -clip_range, clip_range)
24 |
25 | def denormalize(self, v):
26 | if not self.synchronized:
27 | self.synchronize()
28 | mean = ptu.np_to_var(self.mean, requires_grad=False)
29 | std = ptu.np_to_var(self.std, requires_grad=False)
30 | if v.dim() == 2:
31 | mean = mean.unsqueeze(0)
32 | std = std.unsqueeze(0)
33 | return mean + v * std
34 |
35 |
36 | class TorchFixedNormalizer(FixedNormalizer):
37 | def normalize(self, v, clip_range=None):
38 | if clip_range is None:
39 | clip_range = self.default_clip_range
40 | mean = ptu.np_to_var(self.mean, requires_grad=False)
41 | std = ptu.np_to_var(self.std, requires_grad=False)
42 | if v.dim() == 2:
43 | # Unsqueeze along the batch use automatic broadcasting
44 | mean = mean.unsqueeze(0)
45 | std = std.unsqueeze(0)
46 | return torch.clamp((v - mean) / std, -clip_range, clip_range)
47 |
48 | def normalize_scale(self, v):
49 | """
50 | Only normalize the scale. Do not subtract the mean.
51 | """
52 | std = ptu.np_to_var(self.std, requires_grad=False)
53 | if v.dim() == 2:
54 | std = std.unsqueeze(0)
55 | return v / std
56 |
57 | def denormalize(self, v):
58 | mean = ptu.np_to_var(self.mean, requires_grad=False)
59 | std = ptu.np_to_var(self.std, requires_grad=False)
60 | if v.dim() == 2:
61 | mean = mean.unsqueeze(0)
62 | std = std.unsqueeze(0)
63 | return mean + v * std
64 |
65 | def denormalize_scale(self, v):
66 | """
67 | Only denormalize the scale. Do not add the mean.
68 | """
69 | std = ptu.np_to_var(self.std, requires_grad=False)
70 | if v.dim() == 2:
71 | std = std.unsqueeze(0)
72 | return v * std
73 |
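A minimal sketch (not part of the source file) of the normalize/denormalize math these helpers implement, written with plain tensors so it is independent of the Normalizer base classes; the statistics are illustrative.

    import torch

    mean = torch.tensor([0.0, 5.0])
    std = torch.tensor([1.0, 2.0])
    clip_range = 5.0

    v = torch.tensor([[1.0, 9.0], [0.0, 1.0]])               # a batch of two vectors
    normalized = torch.clamp((v - mean) / std, -clip_range, clip_range)
    recovered = mean + normalized * std                       # exact inverse while unclipped
    assert torch.allclose(recovered, v)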
--------------------------------------------------------------------------------
/rlkit/data_management/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class ReplayBuffer(object, metaclass=abc.ABCMeta):
5 | """
6 | A class used to save and replay data.
7 | """
8 |
9 | @abc.abstractmethod
10 | def add_sample(self, observation, action, reward, next_observation,
11 | terminal, **kwargs):
12 | """
13 | Add a transition tuple.
14 | """
15 | pass
16 |
17 | @abc.abstractmethod
18 | def terminate_episode(self):
19 | """
20 | Let the replay buffer know that the episode has terminated in case some
21 | special book-keeping has to happen.
22 | :return:
23 | """
24 | pass
25 |
26 | @abc.abstractmethod
27 | def num_steps_can_sample(self, **kwargs):
28 | """
29 | :return: # of unique items that can be sampled.
30 | """
31 | pass
32 |
33 | def add_path(self, path):
34 | """
35 | Add a path to the replay buffer.
36 |
37 | This default implementation naively goes through every step, but you
38 | may want to optimize this.
39 |
40 | NOTE: You should NOT call "terminate_episode" after calling add_path.
41 | It's assumed that this function handles the episode termination.
42 |
43 |         :param path: Dict like the one returned by rlkit.samplers.util.rollout
44 | """
45 | for i, (
46 | obs,
47 | action,
48 | reward,
49 | next_obs,
50 | terminal,
51 | agent_info,
52 | env_info
53 | ) in enumerate(zip(
54 | path["observations"],
55 | path["actions"],
56 | path["rewards"],
57 | path["next_observations"],
58 | path["terminals"],
59 | path["agent_infos"],
60 | path["env_infos"],
61 | )):
62 | self.add_sample(
63 | observation=obs,
64 | action=action,
65 | reward=reward,
66 | next_observation=next_obs,
67 | terminal=terminal,
68 | agent_info=agent_info,
69 | env_info=env_info,
70 | )
71 | self.terminate_episode()
72 |
73 | def add_paths(self, paths):
74 | for path in paths:
75 | self.add_path(path)
76 |
77 | @abc.abstractmethod
78 | def random_batch(self, batch_size):
79 | """
80 | Return a batch of size `batch_size`.
81 | :param batch_size:
82 | :return:
83 | """
84 | pass
85 |
86 | def get_diagnostics(self):
87 | return {}
88 |
89 | def get_snapshot(self):
90 | return {}
91 |
92 | def end_epoch(self, epoch):
93 | return
94 |
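A minimal concrete subclass (not part of the source file), assuming flat numpy observations and actions; it implements the four abstract methods and inherits the default add_path above.

    import numpy as np
    from rlkit.data_management.replay_buffer import ReplayBuffer

    class ListReplayBuffer(ReplayBuffer):
        """Toy buffer that keeps every transition in Python lists."""
        def __init__(self):
            self._fields = dict(observations=[], actions=[], rewards=[],
                                next_observations=[], terminals=[])

        def add_sample(self, observation, action, reward, next_observation,
                       terminal, **kwargs):
            self._fields['observations'].append(observation)
            self._fields['actions'].append(action)
            self._fields['rewards'].append(reward)
            self._fields['next_observations'].append(next_observation)
            self._fields['terminals'].append(terminal)

        def terminate_episode(self):
            pass  # no special book-keeping needed for flat lists

        def num_steps_can_sample(self, **kwargs):
            return len(self._fields['observations'])

        def random_batch(self, batch_size):
            idxs = np.random.randint(0, self.num_steps_can_sample(), batch_size)
            return {k: np.asarray(v)[idxs] for k, v in self._fields.items()}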
--------------------------------------------------------------------------------
/rlkit/torch/networks/experimental.py:
--------------------------------------------------------------------------------
1 | """
 2 | Contains some self-contained modules. They may depend on pytorch_util.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from rlkit.torch import pytorch_util as ptu
7 |
8 |
9 | class OuterProductLinear(nn.Module):
10 | def __init__(self, in_features1, in_features2, out_features, bias=True):
11 | super().__init__()
12 | self.fc = nn.Linear(
13 | (in_features1 + 1) * (in_features2 + 1),
14 | out_features,
15 | bias=bias,
16 | )
17 |
18 | def forward(self, in1, in2):
19 | out_product_flat = ptu.double_moments(in1, in2)
20 | return self.fc(out_product_flat)
21 |
22 |
23 | class SelfOuterProductLinear(OuterProductLinear):
24 | def __init__(self, in_features, out_features, bias=True):
25 | super().__init__(in_features, in_features, out_features, bias=bias)
26 |
27 | def forward(self, input):
28 | return super().forward(input, input)
29 |
30 |
31 | class BatchSquareDiagonal(nn.Module):
32 | """
33 | Compute x^T diag(`diag_values`) x
34 | """
35 | def __init__(self, vector_size):
36 | super().__init__()
37 | self.vector_size = vector_size
38 | self.diag_mask = ptu.Variable(torch.diag(torch.ones(vector_size)),
39 | requires_grad=False)
40 |
41 | def forward(self, vector, diag_values):
42 | M = ptu.batch_diag(diag_values=diag_values, diag_mask=self.diag_mask)
43 | return ptu.batch_square_vector(vector=vector, M=M)
44 |
45 |
46 | class HuberLoss(nn.Module):
47 | def __init__(self, delta=1):
48 | super().__init__()
49 | self.huber_loss_delta1 = nn.SmoothL1Loss()
50 | self.delta = delta
51 |
52 | def forward(self, x, x_hat):
53 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta)
54 | return loss * self.delta * self.delta
55 |
56 |
57 | class LayerNorm(nn.Module):
58 | """
59 | Simple 1D LayerNorm.
60 | """
61 | def __init__(self, features, center=True, scale=False, eps=1e-6):
62 | super().__init__()
63 | self.center = center
64 | self.scale = scale
65 | self.eps = eps
66 | if self.scale:
67 | self.scale_param = nn.Parameter(torch.ones(features))
68 | else:
69 | self.scale_param = None
70 | if self.center:
71 | self.center_param = nn.Parameter(torch.zeros(features))
72 | else:
73 | self.center_param = None
74 |
75 | def forward(self, x):
76 | mean = x.mean(-1, keepdim=True)
77 | std = x.std(-1, keepdim=True)
78 | output = (x - mean) / (std + self.eps)
79 | if self.scale:
80 | output = output * self.scale_param
81 | if self.center:
82 | output = output + self.center_param
83 | return output
84 |
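A minimal sketch (not part of the source file) exercising LayerNorm and HuberLoss above; the shapes and delta are illustrative.

    import torch
    from rlkit.torch.networks.experimental import HuberLoss, LayerNorm

    ln = LayerNorm(features=4)            # center=True, scale=False by default
    x = torch.randn(8, 4) * 3 + 10
    y = ln(x)
    # At initialization each row has roughly zero mean and unit std;
    # center_param is a learned offset that starts at zero.
    print(y.mean(dim=-1), y.std(dim=-1))

    huber = HuberLoss(delta=2.0)
    loss = huber(torch.randn(5), torch.zeros(5))   # SmoothL1 on x/delta, rescaled by delta**2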
--------------------------------------------------------------------------------
/rlkit/envs/images/insert_image_env.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import gym
4 | import numpy as np
5 | from gym.spaces import Box, Dict
6 |
7 | from rlkit.envs.images.env_renderer import EnvRenderer
8 |
 9 | # Helper for computing the product of an iterable (e.g. a flattened image shape).
10 |
11 |
12 | def prod(val):
13 | res = 1
14 | for ele in val:
15 | res *= ele
16 | return res
17 |
18 |
19 | class InsertImagesEnv(gym.Wrapper):
20 | """
21 |     Add rendered images to the observation, one per renderer. Usage:
22 |
23 | ```
24 | obs = env.reset()
25 | print(obs.keys()) # ['observations']
26 |
27 |     new_env = InsertImagesEnv(
28 | env,
29 | {
30 | 'image_observation': renderer_one,
31 | 'debugging_img': renderer_two,
32 | },
33 | )
34 | obs = new_env.reset()
35 | print(obs.keys()) # ['observations', 'image_observation', 'debugging_img']
36 | ```
37 | """
38 |
39 | def __init__(
40 | self,
41 | wrapped_env: gym.Env,
42 | renderers: typing.Dict[str, EnvRenderer],
43 | ):
44 | super().__init__(wrapped_env)
45 | spaces = self.env.observation_space.spaces.copy()
46 | for image_key, renderer in renderers.items():
47 | if renderer.image_is_normalized:
48 | img_space = Box(
49 | 0, 1, (prod(renderer.image_shape), ), dtype=np.float32)
50 | else:
51 | img_space = Box(
52 | 0, 255, (prod(renderer.image_shape), ), dtype=np.uint8)
53 | spaces[image_key] = img_space
54 | self.renderers = renderers
55 | self.observation_space = Dict(spaces)
56 | self.action_space = self.env.action_space
57 |
58 | def step(self, action):
59 | obs, reward, done, info = self.env.step(action)
60 | self._update_obs(obs)
61 | return obs, reward, done, info
62 |
63 | def reset(self):
64 | obs = self.env.reset()
65 | self._update_obs(obs)
66 | return obs
67 |
68 | def get_observation(self):
69 | obs = self.env.get_observation()
70 | self._update_obs(obs)
71 | return obs
72 |
73 | def _update_obs(self, obs):
74 | for image_key, renderer in self.renderers.items():
75 | obs[image_key] = renderer(self.env)
76 |
77 |
78 | class InsertImageEnv(InsertImagesEnv):
79 | """
80 | Add an image to the observation. Usage:
81 | ```
82 | obs = env.reset()
83 | print(obs.keys()) # ['observations']
84 |
85 | new_env = InsertImageEnv(env, renderer, image_key='pretty_picture')
86 | obs = new_env.reset()
87 | print(obs.keys()) # ['observations', 'pretty_picture']
88 | ```
89 | """
90 |
91 | def __init__(
92 | self,
93 | wrapped_env: gym.Env,
94 | renderer: EnvRenderer,
95 | image_key='image_observation',
96 | ):
97 | super().__init__(wrapped_env, {image_key: renderer})
98 |
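A minimal sketch (not part of the source file) of the renderer interface these wrappers rely on: a renderer is called with the wrapped env and must expose image_shape and image_is_normalized. The toy env and renderer below are illustrative and use the classic 4-tuple gym API, matching the rest of this file.

    import gym
    import numpy as np
    from gym.spaces import Box, Dict
    from rlkit.envs.images.insert_image_env import InsertImageEnv

    class _ToyRenderer:
        """Stand-in for EnvRenderer with the attributes used above."""
        image_shape = (3, 8, 8)
        image_is_normalized = True

        def __call__(self, env):
            return np.zeros(int(np.prod(self.image_shape)), dtype=np.float32)

    class _ToyEnv(gym.Env):
        observation_space = Dict({'observations': Box(-1, 1, (2,), dtype=np.float32)})
        action_space = Box(-1, 1, (1,), dtype=np.float32)

        def reset(self):
            return {'observations': np.zeros(2, dtype=np.float32)}

        def step(self, action):
            return {'observations': np.zeros(2, dtype=np.float32)}, 0.0, False, {}

    env = InsertImageEnv(_ToyEnv(), _ToyRenderer(), image_key='pretty_picture')
    obs = env.reset()
    print(sorted(obs.keys()))   # ['observations', 'pretty_picture']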
--------------------------------------------------------------------------------
/rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv
4 | import os
5 |
6 | class PR2Env(RandomEnv, utils.EzPickle):
7 |
8 | FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets/pr2.xml')
9 |
10 | def __init__(self, log_scale_limit=1.):
11 | self.viewer = None
12 | RandomEnv.__init__(self, log_scale_limit, 'pr2.xml', 4)
13 | utils.EzPickle.__init__(self)
14 |
15 | def _get_obs(self):
16 | return np.concatenate([
17 | self.model.data.qpos.flat[:7],
18 | self.model.data.qvel.flat[:7], # Do not include the velocity of the target (should be 0).
19 | self.get_tip_position().flat,
20 | self.get_vec_tip_to_goal().flat,
21 | ])
22 |
23 | def get_tip_position(self):
24 | return self.model.data.site_xpos[0]
25 |
26 | def get_vec_tip_to_goal(self):
27 | tip_position = self.get_tip_position()
28 | goal_position = self.goal
29 | vec_tip_to_goal = goal_position - tip_position
30 | return vec_tip_to_goal
31 |
32 | @property
33 | def goal(self):
34 | return self.model.data.qpos.flat[-3:]
35 |
36 | def _step(self, action):
37 |
38 | self.do_simulation(action, self.frame_skip)
39 |
40 | vec_tip_to_goal = self.get_vec_tip_to_goal()
41 | distance_tip_to_goal = np.linalg.norm(vec_tip_to_goal)
42 |
43 | reward = - distance_tip_to_goal
44 |
45 | state = self.state_vector()
46 | notdone = np.isfinite(state).all()
47 | done = not notdone
48 |
49 | ob = self._get_obs()
50 |
51 | return ob, reward, done, {}
52 |
53 | def reset_model(self):
54 | qpos = self.init_qpos
55 | qvel = self.init_qvel
56 | goal = np.random.uniform((0.2, -0.4, 0.5), (0.5, 0.4, 1.5))
57 | qpos[-3:] = goal
58 | qpos[:7] += self.np_random.uniform(low=-.005, high=.005, size=7)
59 | qvel[:7] += self.np_random.uniform(low=-.005, high=.005, size=7)
60 | self.set_state(qpos, qvel)
61 | return self._get_obs()
62 |
63 | def viewer_setup(self):
64 | self.viewer.cam.distance = self.model.stat.extent * 2
65 | # self.viewer.cam.lookat[2] += .8
66 | self.viewer.cam.elevation = -50
67 | # self.viewer.cam.lookat[0] = self.model.stat.center[0]
68 | # self.viewer.cam.lookat[1] = self.model.stat.center[1]
69 | # self.viewer.cam.lookat[2] = self.model.stat.center[2]
70 |
71 |
72 | if __name__ == "__main__":
73 |
74 | env = PR2Env()
75 | tasks = env.sample_tasks(40)
76 | while True:
77 | env.reset()
78 | env.set_task(np.random.choice(tasks))
79 | print(env.model.body_mass)
80 | for _ in range(100):
81 | env.render()
82 | env.step(env.action_space.sample())
83 |
--------------------------------------------------------------------------------
/rlkit/envs/mujoco/pusher.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | from gym.envs.mujoco import PusherEnv as GymPusherEnv
5 |
6 | from rlkit.core.eval_util import create_stats_ordered_dict, get_stat_in_paths
7 | from rlkit.core import logger
8 |
9 |
10 | class PusherEnv(GymPusherEnv):
11 | def __init__(self):
12 | self.goal_cylinder_relative_x = 0
13 | self.goal_cylinder_relative_y = 0
14 | super().__init__()
15 |
16 | def reset_model(self):
17 | qpos = self.init_qpos
18 | goal_xy = np.array([
19 | self.goal_cylinder_relative_x,
20 | self.goal_cylinder_relative_y,
21 | ])
22 |
23 | while True:
24 | self.cylinder_pos = np.concatenate([
25 | self.np_random.uniform(low=-0.3, high=0, size=1),
26 | self.np_random.uniform(low=-0.2, high=0.2, size=1)])
27 | if np.linalg.norm(self.cylinder_pos - goal_xy) > 0.17:
28 | break
29 |
30 | qpos[-4:-2] = self.cylinder_pos
31 | # y-axis comes first in the xml
32 | qpos[-2] = self.goal_cylinder_relative_y
33 | qpos[-1] = self.goal_cylinder_relative_x
34 | qvel = self.init_qvel + self.np_random.uniform(low=-0.005,
35 | high=0.005,
36 | size=self.model.nv)
37 | qvel[-4:] = 0
38 | self.set_state(qpos, qvel)
39 | return self._get_obs()
40 |
41 | def _step(self, a):
42 | arm_to_object = (
43 | self.get_body_com("tips_arm") - self.get_body_com("object")
44 | )
45 | object_to_goal = (
46 | self.get_body_com("object") - self.get_body_com("goal")
47 | )
48 | arm_to_goal = (
49 | self.get_body_com("tips_arm") - self.get_body_com("goal")
50 | )
51 | obs, reward, done, info_dict = super()._step(a)
52 | info_dict['arm to object distance'] = np.linalg.norm(arm_to_object)
53 | info_dict['object to goal distance'] = np.linalg.norm(object_to_goal)
54 | info_dict['arm to goal distance'] = np.linalg.norm(arm_to_goal)
55 | return obs, reward, done, info_dict
56 |
57 | def log_diagnostics(self, paths):
58 | statistics = OrderedDict()
59 |
60 | for stat_name in [
61 | 'arm to object distance',
62 | 'object to goal distance',
63 | 'arm to goal distance',
64 | ]:
65 | stat = get_stat_in_paths(
66 | paths, 'env_infos', stat_name
67 | )
68 | statistics.update(create_stats_ordered_dict(
69 | stat_name, stat
70 | ))
71 |
72 | for key, value in statistics.items():
73 | logger.record_tabular(key, value)
74 |
75 | def _set_goal_xy(self, xy):
76 | # Based on XML
77 | self.goal_cylinder_relative_x = xy[0] - 0.45
78 | self.goal_cylinder_relative_y = xy[1] + 0.05
79 |
--------------------------------------------------------------------------------
/rlkit/envs/contextual/task_conditioned.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import numpy as np
4 |
5 | from rlkit.envs.contextual.goal_conditioned import (
6 | GoalDictDistributionFromMultitaskEnv,
7 | )
8 | from rlkit.samplers.data_collector.contextual_path_collector import (
9 | ContextualPathCollector
10 | )
11 |
12 | from gym.spaces import Box
13 | from rlkit.samplers.rollout_functions import contextual_rollout
14 |
15 | class TaskGoalDictDistributionFromMultitaskEnv(
16 | GoalDictDistributionFromMultitaskEnv):
17 | def __init__(
18 | self,
19 | *args,
20 | task_key='task_id',
21 | task_ids=None,
22 | **kwargs
23 | ):
24 | super().__init__(*args, **kwargs)
25 | self.task_key = task_key
26 | self._spaces[task_key] = Box(
27 | low=np.zeros(1),
28 | high=np.ones(1))
29 | self.task_ids = np.array(task_ids)
30 |
31 | def sample(self, batch_size: int):
32 | goals = super().sample(batch_size)
33 | idxs = np.random.choice(len(self.task_ids), batch_size)
34 | goals[self.task_key] = self.task_ids[idxs].reshape(-1, 1)
35 | return goals
36 |
37 | class TaskPathCollector(ContextualPathCollector):
38 | def __init__(
39 | self,
40 | *args,
41 | task_key=None,
42 | max_path_length=100,
43 | task_ids=None,
44 | rotate_freq=0.0,
45 | **kwargs
46 | ):
47 | super().__init__(*args, **kwargs)
48 | self.rotate_freq = rotate_freq
49 | self.rollout_tasks = []
50 |
51 | def obs_processor(o):
52 | if len(self.rollout_tasks) > 0:
53 | task = self.rollout_tasks[0]
54 | self.rollout_tasks = self.rollout_tasks[1:]
55 | o[task_key] = task
56 | self._env._rollout_context_batch[task_key] = task[None]
57 |
58 | combined_obs = [o[self._observation_key]]
59 | for k in self._context_keys_for_policy:
60 | combined_obs.append(o[k])
61 | return np.concatenate(combined_obs, axis=0)
62 |
63 | def reset_postprocess_func():
64 | rotate = (np.random.uniform() < self.rotate_freq)
65 | self.rollout_tasks = []
66 | if rotate:
67 | num_steps_per_task = max_path_length // len(task_ids)
68 | self.rollout_tasks = np.ones((max_path_length, 1)) * (len(task_ids) - 1)
69 | for (idx, id) in enumerate(task_ids):
70 | start = idx * num_steps_per_task
71 | end = start + num_steps_per_task
72 | self.rollout_tasks[start:end] = id
73 |
74 | self._rollout_fn = partial(
75 | contextual_rollout,
76 | context_keys_for_policy=self._context_keys_for_policy,
77 | observation_key=self._observation_key,
78 | obs_processor=obs_processor,
79 | reset_postprocess_func=reset_postprocess_func,
80 | )
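A minimal sketch (not part of the source file) of the task-rotation schedule built in reset_postprocess_func above: the rollout is divided into contiguous segments, one per task id, with any remainder pre-filled with the last index. The numbers are illustrative.

    import numpy as np

    max_path_length = 10
    task_ids = [0, 1, 2]

    num_steps_per_task = max_path_length // len(task_ids)            # 3
    rollout_tasks = np.ones((max_path_length, 1)) * (len(task_ids) - 1)
    for idx, task_id in enumerate(task_ids):
        start = idx * num_steps_per_task
        rollout_tasks[start:start + num_steps_per_task] = task_id

    print(rollout_tasks.ravel())   # [0. 0. 0. 1. 1. 1. 2. 2. 2. 2.]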
--------------------------------------------------------------------------------
/rlkit/networks/gaussian_policy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 |
5 | import rlkit.torch.pytorch_util as ptu
6 | from rlkit.torch.sac.policies.base import TorchStochasticPolicy
7 | from rlkit.torch.sac.policies.gaussian_policy import (
8 | LOG_SIG_MAX,
9 | LOG_SIG_MIN,
10 | )
11 | from rlkit.torch.distributions import MultivariateDiagonalNormal
12 | from rlkit.torch.networks import CNN
13 |
14 |
15 | class GaussianCNNPolicy(CNN, TorchStochasticPolicy):
16 | def __init__(
17 | self,
18 | hidden_sizes,
19 | obs_dim,
20 | action_dim,
21 | std=None,
22 | init_w=1e-3,
23 | min_log_std=None,
24 | max_log_std=None,
25 | std_architecture="shared",
26 | output_activation=None,
27 | **kwargs
28 | ):
29 | super().__init__(
30 | hidden_sizes=hidden_sizes,
31 | output_size=action_dim,
32 | init_w=init_w,
33 | output_activation=output_activation,
34 | **kwargs
35 | )
36 | self.min_log_std = min_log_std
37 | self.max_log_std = max_log_std
38 | self.log_std = None
39 | self.std = std
40 | self.std_architecture = std_architecture
41 | if std is None:
42 | if self.std_architecture == "shared":
43 | last_hidden_size = obs_dim
44 | if len(hidden_sizes) > 0:
45 | last_hidden_size = hidden_sizes[-1]
46 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
47 | self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
48 | self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
49 | elif self.std_architecture == "values":
50 | self.log_std_logits = nn.Parameter(
51 | ptu.zeros(action_dim, requires_grad=True))
52 | else:
53 | raise ValueError(self.std_architecture)
54 | else:
55 | self.log_std = np.log(std)
56 | assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX
57 |
58 | def forward(self, obs):
59 | h = super().forward(obs, return_last_activations=True)
60 | mean = self.last_fc(h)
61 | if self.output_activation is not None:
62 | mean = self.output_activation(mean)
63 | if self.std is None:
64 | if self.std_architecture == "shared":
65 | log_std = torch.sigmoid(self.last_fc_log_std(h))
66 | elif self.std_architecture == "values":
67 | log_std = torch.sigmoid(self.log_std_logits)
68 | else:
69 | raise ValueError(self.std_architecture)
70 | log_std = self.min_log_std + log_std * (
71 | self.max_log_std - self.min_log_std)
72 | std = torch.exp(log_std)
73 | else:
74 | std = torch.from_numpy(np.array([self.std, ])).float().to(
75 | ptu.device)
76 |
77 | return MultivariateDiagonalNormal(mean, std)
78 |
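A minimal sketch (not part of the source file) of the log-std parameterization used in forward above: a sigmoid squashes an unconstrained value into [min_log_std, max_log_std] before exponentiating. The bounds are illustrative.

    import torch

    min_log_std, max_log_std = -6.0, 0.0        # illustrative bounds
    raw = torch.randn(4)                        # stands in for last_fc_log_std(h) or log_std_logits
    log_std = min_log_std + torch.sigmoid(raw) * (max_log_std - min_log_std)
    std = torch.exp(log_std)
    assert ((std > 0) & (std <= 1.0)).all()     # sigmoid keeps log_std inside the bounds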
--------------------------------------------------------------------------------
/rlkit/envs/dual_encoder_wrapper.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import numpy as np
3 | import rlkit.torch.pytorch_util as ptu
4 | from gym.spaces import Box, Dict
5 | from rlkit.envs.vae_wrappers import VAEWrappedEnv
6 | from rlkit.envs.wrappers import ProxyEnv
7 |
8 | from rlkit.envs.encoder_wrappers import Encoder
9 |
10 | class DualEncoderWrappedEnv(ProxyEnv):
11 | def __init__(self,
12 | wrapped_env,
13 | model: Encoder,
14 | input_model,
15 | step_keys_map=None,
16 | reset_keys_map=None,
17 | conditional_input_model=False,
18 | ):
19 | super().__init__(wrapped_env)
20 | self.model = model
21 | self.input_model = input_model
22 | self.conditional_input_model = conditional_input_model
23 | self.representation_size = self.model.representation_size
24 | self.input_representation_size = self.input_model.representation_size
25 | latent_space = Box(
26 | -10 * np.ones(self.representation_size),
27 | 10 * np.ones(self.representation_size),
28 | dtype=np.float32,
29 | )
30 | input_latent_space = Box(
31 | -10 * np.ones(self.input_representation_size),
32 | 10 * np.ones(self.input_representation_size),
33 | dtype=np.float32,
34 | )
35 |
36 | if step_keys_map is None:
37 | step_keys_map = {}
38 | if reset_keys_map is None:
39 | reset_keys_map = {}
40 | self.step_keys_map = step_keys_map
41 | self.reset_keys_map = reset_keys_map
42 | spaces = self.wrapped_env.observation_space.spaces
43 | for value in self.step_keys_map.values():
44 | spaces[value] = latent_space
45 | for value in self.reset_keys_map.values():
46 | spaces[value] = latent_space
47 |
48 | spaces['input_latent'] = input_latent_space
49 | self.observation_space = Dict(spaces)
50 | self.reset_obs = {}
51 |
52 | def step(self, action):
53 | self.model.eval()
54 | obs, reward, done, info = self.wrapped_env.step(action)
55 | new_obs = self._update_obs(obs)
56 | return new_obs, reward, done, info
57 |
58 | def _update_obs(self, obs):
59 | self.model.eval()
60 | for key in self.step_keys_map:
61 | value = self.step_keys_map[key]
62 | obs[value] = self.model.encode_one_np(obs[key])
63 | obs = {**obs, **self.reset_obs}
64 |
65 | if self.conditional_input_model:
66 | obs['input_latent'] = self.input_model.encode_one_np(obs['image_observation'], self._initial_img)
67 | else:
68 | obs['input_latent'] = self.input_model.encode_one_np(obs['image_observation'])
69 |
70 | return obs
71 |
72 | def reset(self):
73 | self.model.eval()
74 | obs = self.wrapped_env.reset()
75 | self._initial_img = obs["image_observation"]
76 | for key in self.reset_keys_map:
77 | value = self.reset_keys_map[key]
78 | self.reset_obs[value] = self.model.encode_one_np(obs[key])
79 | obs = self._update_obs(obs)
80 | return obs
81 |
--------------------------------------------------------------------------------
/rlkit/launchers/doodad_wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import NamedTuple
4 | import random
5 |
6 | import __main__ as main
7 | import numpy as np
8 |
9 | import rlkit.torch.pytorch_util as ptu
10 | from rlkit.core import logger, setup_logger
11 | from rlkit.launchers import config
12 |
13 | from doodad.wrappers.easy_launch import save_doodad_config, sweep_function, DoodadConfig
14 | from rlkit.launchers.launcher_util import set_seed
15 |
16 |
17 | class AutoSetup:
18 | """
19 | Automatically set up:
20 | 1. the logger
21 | 2. the GPU mode
22 | 3. the seed
23 | :param exp_function: some function that should not depend on `logger_config`
24 |         or `seed`.
25 | :param unpack_variant: do you call exp_function with `**variant`?
26 | :return: function output
27 | """
28 | def __init__(self, exp_function, unpack_variant=True):
29 | self.exp_function = exp_function
30 | self.unpack_variant = unpack_variant
31 |
32 | def __call__(self, doodad_config: DoodadConfig, variant):
33 | save_doodad_config(doodad_config)
34 | variant_to_save = variant.copy()
35 | variant_to_save['doodad_info'] = doodad_config.extra_launch_info
36 | exp_name = doodad_config.extra_launch_info['exp_name']
37 | seed = variant.pop('seed', 0)
38 | set_seed(seed)
39 | ptu.set_gpu_mode(doodad_config.use_gpu)
40 |         # Reopening the files is necessary because blobfuse only syncs files
41 | # when they're closed. For details, see
42 | # https://github.com/Azure/azure-storage-fuse#if-your-workload-is-not-read-only
43 | reopen_files_on_flush = True
44 |         # Might as well always leave it on; to restrict it to Azure you could instead do:
45 | # mode = doodad_config.extra_launch_info['mode']
46 | # reopen_files_on_flush = mode == 'azure'
47 | setup_logger(
48 | logger,
49 | exp_name=exp_name,
50 | base_log_dir=None,
51 | log_dir=doodad_config.output_directory,
52 | seed=seed,
53 | variant=variant,
54 | reopen_files_on_flush=reopen_files_on_flush,
55 | )
56 | variant.pop('logger_config', None)
57 | variant.pop('exp_id', None)
58 | variant.pop('run_id', None)
59 | if self.unpack_variant:
60 | self.exp_function(**variant)
61 | else:
62 | self.exp_function(variant)
63 |
64 |
65 | def run_experiment(
66 | method_call,
67 | params,
68 | default_params,
69 | exp_name='default',
70 | mode='local',
71 | wrap_fn_with_auto_setup=True,
72 | unpack_variant=True,
73 | **kwargs
74 | ):
75 | if wrap_fn_with_auto_setup:
76 | method_call = AutoSetup(
77 | method_call,
78 | unpack_variant=unpack_variant,
79 | )
80 | sweep_function(
81 | method_call,
82 | params,
83 | default_params=default_params,
84 | mode=mode,
85 | log_path=exp_name,
86 | add_time_to_run_id='in_front',
87 | extra_launch_info={'exp_name': exp_name, 'mode': mode},
88 | **kwargs
89 | )
90 |
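A hedged usage sketch (not part of the source file): run_experiment wraps an ordinary experiment function in AutoSetup and hands it to doodad's sweep_function. The experiment body and parameter names are illustrative; the params dict is assumed to map argument names to lists of values to sweep, which is what the sweep_function call above suggests.

    from rlkit.launchers.doodad_wrapper import run_experiment

    def experiment(lr=1e-3, batch_size=256):
        # Build networks, buffers, and a trainer here; the logger, seed,
        # and GPU mode have already been set up by AutoSetup.
        print('training with lr={}, batch_size={}'.format(lr, batch_size))

    run_experiment(
        experiment,
        params={'lr': [1e-3, 3e-4]},
        default_params={'batch_size': 256},
        exp_name='toy-sweep',
        mode='local',
    )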
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/joint_path_collector.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Dict
3 |
4 | from rlkit.core.logging import add_prefix
5 | from rlkit.samplers.data_collector import PathCollector
6 |
7 |
8 | class JointPathCollector(PathCollector):
9 | EVENLY = 'evenly'
10 |
11 | def __init__(
12 | self,
13 | path_collectors: Dict[str, PathCollector],
14 | divide_num_steps_strategy=EVENLY,
15 | ):
16 | """
17 | :param path_collectors: Dictionary of path collectors
18 | :param divide_num_steps_strategy: How the steps are divided among the
19 | path collectors.
20 | Valid values:
21 |             - 'evenly': divide `num_steps` evenly among the collectors
22 | """
23 | sorted_collectors = OrderedDict()
24 | # Sort the path collectors to have a canonical ordering
25 | for k in sorted(path_collectors):
26 | sorted_collectors[k] = path_collectors[k]
27 | self.path_collectors = sorted_collectors
28 | self.divide_num_steps_strategy = divide_num_steps_strategy
29 | if divide_num_steps_strategy not in {self.EVENLY}:
30 | raise ValueError(divide_num_steps_strategy)
31 |
32 | def collect_new_paths(self, max_path_length, num_steps,
33 | discard_incomplete_paths,
34 | **kwargs):
35 | paths = []
36 | if self.divide_num_steps_strategy == self.EVENLY:
37 | num_steps_per_collector = num_steps // len(self.path_collectors)
38 | else:
39 | raise ValueError(self.divide_num_steps_strategy)
40 | for name, collector in self.path_collectors.items():
41 | paths += collector.collect_new_paths(
42 | max_path_length=max_path_length,
43 | num_steps=num_steps_per_collector,
44 | discard_incomplete_paths=discard_incomplete_paths,
45 | **kwargs
46 | )
47 | return paths
48 |
49 | def end_epoch(self, epoch):
50 | for collector in self.path_collectors.values():
51 | collector.end_epoch(epoch)
52 |
53 | def get_diagnostics(self):
54 | diagnostics = OrderedDict()
55 | num_steps = 0
56 | num_paths = 0
57 | for name, collector in self.path_collectors.items():
58 | stats = collector.get_diagnostics()
59 | num_steps += stats['num steps total']
60 | num_paths += stats['num paths total']
61 | diagnostics.update(
62 | add_prefix(stats, name, divider='/'),
63 | )
64 | diagnostics['num steps total'] = num_steps
65 | diagnostics['num paths total'] = num_paths
66 | return diagnostics
67 |
68 | def get_snapshot(self):
69 | snapshot = {}
70 | for name, collector in self.path_collectors.items():
71 | snapshot.update(
72 | add_prefix(collector.get_snapshot(), name, divider='/'),
73 | )
74 | return snapshot
75 |
76 | def get_epoch_paths(self):
77 | paths = {}
78 | for name, collector in self.path_collectors.items():
79 | paths[name] = collector.get_epoch_paths()
80 | return paths
81 |
--------------------------------------------------------------------------------
/rlkit/util/tensorboard_logger.py:
--------------------------------------------------------------------------------
1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
2 | import tensorflow as tf
3 | import numpy as np
4 | import scipy.misc
5 | try:
6 | from StringIO import StringIO # Python 2.7
7 | except ImportError:
8 | from io import BytesIO # Python 3.x
9 |
10 |
11 | class TensorboardLogger(object):
12 | def __init__(self, log_dir):
13 | """Create a summary writer logging to log_dir."""
14 | self.writer = tf.summary.FileWriter(log_dir)
15 |
16 | def scalar_summary(self, tag, value, step):
17 | """Log a scalar variable."""
18 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
19 | self.writer.add_summary(summary, step)
20 |
21 | def image_summary(self, tag, images, step):
22 | """Log a list of images."""
23 |
24 | img_summaries = []
25 | for i, img in enumerate(images):
26 | # Write the image to a string
27 | try:
28 | s = StringIO()
29 |             except NameError:  # StringIO is undefined on Python 3
30 | s = BytesIO()
31 | scipy.misc.toimage(img).save(s, format="png")
32 |
33 | # Create an Image object
34 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
35 | height=img.shape[0],
36 | width=img.shape[1])
37 | # Create a Summary value
38 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
39 |
40 | # Create and write Summary
41 | summary = tf.Summary(value=img_summaries)
42 | self.writer.add_summary(summary, step)
43 |
44 |
45 | def histo_summary(self, tag, values, step, bins=1000):
46 | """Log a histogram of the tensor of values."""
47 |
48 | # Create a histogram using numpy
49 | # Original code:
50 | # counts, bin_edges = np.histogram(values, bins=bins)
51 | #
52 | # Now I use the following to get more detailed data
53 | # https://stackoverflow.com/questions/39418380/histogram-with-equal-number-of-points-in-each-bin
54 | counts, bin_edges = np.histogram(
55 | values,
56 | bins=histedges_equalN(values.flatten(), bins)
57 | )
58 |
59 | # Fill the fields of the histogram proto
60 | hist = tf.HistogramProto()
61 | hist.min = float(np.min(values))
62 | hist.max = float(np.max(values))
63 | hist.num = int(np.prod(values.shape))
64 | hist.sum = float(np.sum(values))
65 | hist.sum_squares = float(np.sum(values**2))
66 |
67 | # Drop the start of the first bin
68 | bin_edges = bin_edges[1:]
69 |
70 | # Add bin edges and counts
71 | for edge in bin_edges:
72 | hist.bucket_limit.append(edge)
73 | for c in counts:
74 | hist.bucket.append(c)
75 |
76 | # Create and write Summary
77 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
78 | self.writer.add_summary(summary, step)
79 | self.writer.flush()
80 |
81 |
82 | def histedges_equalN(x, nbin):
83 | npt = len(x)
84 | return np.interp(np.linspace(0, npt, nbin + 1),
85 | np.arange(npt),
86 | np.sort(x))
87 |
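A hedged usage sketch (not part of the source file): this logger targets the TensorFlow 1.x summary API (tf.summary.FileWriter, tf.Summary), and image_summary additionally relies on the removed scipy.misc.toimage helper, so it assumes older tensorflow/scipy installs. The log directory and values are illustrative.

    import numpy as np
    from rlkit.util.tensorboard_logger import TensorboardLogger

    logger = TensorboardLogger('/tmp/tb_demo')
    for step in range(100):
        logger.scalar_summary('loss', 1.0 / (step + 1), step)
    logger.histo_summary('weights', np.random.randn(1000), step=99, bins=30)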
--------------------------------------------------------------------------------