├── rlkit ├── __init__.py ├── demos │ ├── __init__.py │ ├── source │ │ ├── __init__.py │ │ ├── demo_source.py │ │ ├── path_loader.py │ │ └── hand_demo_source.py │ ├── spacemouse │ │ ├── __init__.py │ │ ├── input_client.py │ │ └── README.md │ ├── her_td3bc.py │ ├── play_demo.py │ └── her_bc.py ├── torch │ ├── __init__.py │ ├── sac │ │ ├── __init__.py │ │ └── policies │ │ │ └── base.py │ ├── data_management │ │ ├── __init__.py │ │ └── normalizer.py │ ├── networks │ │ ├── stochastic │ │ │ └── __init__.py │ │ ├── custom.py │ │ ├── linear_transform.py │ │ ├── ae_tanh_policy.py │ │ ├── __init__.py │ │ ├── image_state.py │ │ ├── two_headed_mlp.py │ │ └── experimental.py │ ├── transforms │ │ ├── __init__.py │ │ └── _pil_constants.py │ ├── her │ │ └── her.py │ └── core.py ├── utils │ ├── __init__.py │ ├── val_util.py │ ├── experiment_util.py │ ├── device_util.py │ ├── logging.config │ ├── logging.py │ ├── timer.py │ ├── path_builder.py │ └── image_util.py ├── learning │ ├── __init__.py │ └── online_offline_split_replay_buffer.py ├── networks │ ├── __init__.py │ ├── utils.py │ └── gaussian_policy.py ├── samplers │ ├── __init__.py │ ├── data_collector │ │ ├── __init__.py │ │ ├── base.py │ │ ├── contextual_path_collector.py │ │ └── joint_path_collector.py │ ├── in_place.py │ └── util.py ├── data_management │ ├── __init__.py │ ├── external │ │ ├── __init__.py │ │ └── bair_dataset │ │ │ └── __init__.py │ ├── wrappers │ │ ├── concat_to_obs_wrapper.py │ │ ├── proxy_buffer.py │ │ └── replay_buffer_wrapper.py │ ├── path_builder.py │ ├── dataset_logger_fn.py │ ├── images.py │ ├── tau_replay_buffer.py │ ├── ocm_subtraj_replay_buffer.py │ ├── split_buffer.py │ └── replay_buffer.py ├── envs │ ├── memory │ │ ├── __init__.py │ │ └── hidden_cartpole.py │ ├── mujoco │ │ ├── __init__.py │ │ ├── point_env.py │ │ ├── twod_point_random_init.py │ │ ├── reacher_env.py │ │ ├── twod_point.py │ │ ├── twod_maze.py │ │ ├── oned_point.py │ │ ├── hopper_env.py │ │ ├── ant.py │ │ ├── mujoco_env_murtaza.py │ │ ├── mujoco_env.py │ │ └── pusher.py │ ├── pygame │ │ └── __init__.py │ ├── simple │ │ ├── __init__.py │ │ └── point.py │ ├── pearl_envs │ │ ├── rand_param_envs │ │ │ ├── __init__.py │ │ │ ├── hopper_rand_params.py │ │ │ ├── walker2d_rand_params.py │ │ │ └── pr2_env_reach.py │ │ ├── hopper_rand_params_wrapper.py │ │ ├── walker_rand_params_wrapper.py │ │ ├── ant_normal.py │ │ ├── half_cheetah.py │ │ ├── ant_multitask_base.py │ │ ├── ant_goal.py │ │ ├── ant_dir.py │ │ ├── mujoco_env.py │ │ ├── __init__.py │ │ ├── ant.py │ │ ├── humanoid_dir.py │ │ ├── half_cheetah_vel.py │ │ └── half_cheetah_dir.py │ ├── base.py │ ├── contextual │ │ ├── __init__.py │ │ └── task_conditioned.py │ ├── sawyer_parallel_hack.py │ ├── images │ │ ├── __init__.py │ │ ├── text_renderer.py │ │ ├── env_renderer.py │ │ └── insert_image_env.py │ ├── __init__.py │ ├── wrappers │ │ ├── image_mujoco_env_with_obs.py │ │ ├── reward_wrapper_env.py │ │ ├── discretize_env.py │ │ ├── __init__.py │ │ ├── modular_env.py │ │ ├── flat_to_dict.py │ │ ├── stack_observation_env.py │ │ ├── history_env.py │ │ └── normalized_box_env.py │ ├── assets │ │ ├── oned_point.xml │ │ ├── twod_point.xml │ │ ├── twod_point_random_init.xml │ │ ├── twod_maze.xml │ │ ├── water_maze.xml │ │ └── small_water_maze.xml │ ├── gridcraft │ │ ├── utils.py │ │ └── custom_test.py │ ├── env_utils.py │ ├── time_limited_env.py │ ├── supervised_learning_env.py │ ├── robosuite_wrapper.py │ ├── gripper_state_wrapper.py │ ├── proxy_env.py │ ├── make_env.py │ └── dual_encoder_wrapper.py ├── visualization │ 
├── __init__.py │ └── plotter.py ├── core │ ├── __init__.py │ ├── trainer.py │ ├── loss.py │ ├── distribution.py │ ├── util.py │ ├── ray_csv_logger.py │ ├── timer.py │ └── serializeable.py ├── util │ ├── gym_util.py │ ├── random_util.py │ ├── __init__.py │ ├── ray_util.py │ ├── inspect_q_util.py │ └── tensorboard_logger.py ├── launchers │ ├── __init__.py │ ├── contextual │ │ └── util.py │ └── doodad_wrapper.py ├── policies │ ├── base.py │ └── action_repeat.py └── exploration_strategies │ ├── noop.py │ ├── epsilon_greedy.py │ ├── gaussian_strategy.py │ ├── gaussian_and_epsilon.py │ ├── ou_strategy.py │ └── base.py ├── stable_contrastive_rl.png └── LICENSE /rlkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/demos/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/sac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/demos/source/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/envs/memory/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/envs/pygame/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/envs/simple/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /rlkit/demos/spacemouse/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/data_management/external/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/networks/stochastic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/data_management/external/bair_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * -------------------------------------------------------------------------------- /rlkit/torch/networks/custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Random networks 3 | """ 4 | -------------------------------------------------------------------------------- /stable_contrastive_rl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chongyi-zheng/stable_contrastive_rl/HEAD/stable_contrastive_rl.png -------------------------------------------------------------------------------- /rlkit/demos/source/demo_source.py: -------------------------------------------------------------------------------- 1 | class DemoSource: 2 | def load_paths(self): 3 | """Should return a list of paths in PathBuilder format""" 4 | return [{}, ] 5 | -------------------------------------------------------------------------------- /rlkit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout railrl. 
3 | """ 4 | from rlkit.core.logging import logger, setup_logger 5 | 6 | __all__ = ["logger", "setup_logger"] 7 | -------------------------------------------------------------------------------- /rlkit/envs/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class RolloutEnv(object): 5 | """ Environment that supports full rollouts.""" 6 | 7 | @abc.abstractmethod 8 | def rollout(self, *args, **kwargs): 9 | pass 10 | -------------------------------------------------------------------------------- /rlkit/util/gym_util.py: -------------------------------------------------------------------------------- 1 | from gym.envs import registration 2 | 3 | 4 | def get_class_and_kwargs(spec_or_id): 5 | if isinstance(spec_or_id, registration.EnvSpec): 6 | spec = spec_or_id 7 | else: 8 | spec = registration.spec(spec_or_id) 9 | return registration.load(spec._entry_point), spec._kwargs -------------------------------------------------------------------------------- /rlkit/envs/contextual/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.contextual.contextual_env import ( 2 | ContextualEnv, 3 | ContextualRewardFn, 4 | delete_info, 5 | insert_reward, 6 | ) 7 | 8 | __all__ = [ 9 | 'ContextualEnv', 10 | 'ContextualRewardFn', 11 | 'delete_info', 12 | 'insert_reward', 13 | ] -------------------------------------------------------------------------------- /rlkit/torch/networks/linear_transform.py: -------------------------------------------------------------------------------- 1 | from rlkit.torch.core import PyTorchModule 2 | 3 | 4 | class LinearTransform(PyTorchModule): 5 | def __init__(self, m, b): 6 | super().__init__() 7 | self.m = m 8 | self.b = b 9 | 10 | def __call__(self, t): 11 | return self.m * t + self.b 12 | -------------------------------------------------------------------------------- /rlkit/utils/val_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def preprocess_image(data): 5 | _shape = list(data.shape[:-3]) 6 | data = np.reshape(data, [-1, 3, 48, 48]) 7 | data = np.transpose(data, [0, 1, 3, 2]) 8 | data = np.reshape(data, _shape + [3, 48, 48]) 9 | data = data - 0.5 10 | return data 11 | -------------------------------------------------------------------------------- /rlkit/envs/sawyer_parallel_hack.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | 4 | idx = 0 5 | while idx < len(sys.path): 6 | if 'sawyer_control' in sys.path[idx]: 7 | warnings.warn("Undoing ros generated __init__ for parallel until a better fix is found") 8 | del sys.path[idx] 9 | else: 10 | idx += 1 11 | 12 | 13 | -------------------------------------------------------------------------------- /rlkit/core/trainer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Trainer(object, metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def train(self, data): 7 | pass 8 | 9 | def end_epoch(self, epoch): 10 | pass 11 | 12 | def get_snapshot(self): 13 | return {} 14 | 15 | def get_diagnostics(self): 16 | return {} 17 | -------------------------------------------------------------------------------- /rlkit/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains 'launchers', which are self-contained functions that take 
3 | one dictionary and run a full experiment. The dictionary configures the 4 | experiment. 5 | 6 | It is important that the functions are completely self-contained (i.e. they 7 | import their own modules) so that they can be serialized. 8 | """ 9 | -------------------------------------------------------------------------------- /rlkit/util/random_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def random_point_in_circle(angle_range=(0, 2*np.pi), radius=(0, 25)): 5 | angle = np.random.uniform(*angle_range) 6 | radius = radius if np.isscalar(radius) else np.random.uniform(*radius) 7 | x, y = np.cos(angle) * radius, np.sin(angle) * radius 8 | point = np.array([x, y]) 9 | return point 10 | -------------------------------------------------------------------------------- /rlkit/utils/experiment_util.py: -------------------------------------------------------------------------------- 1 | # import sys 2 | # import os 3 | # import stat 4 | 5 | 6 | from rlkit.launchers import launcher_util as lu 7 | 8 | 9 | def run_variant(experiment, variant): 10 | launcher_config = variant.get("launcher_config") 11 | lu.run_experiment( 12 | experiment, 13 | variant=variant, 14 | **launcher_config, 15 | ) 16 | -------------------------------------------------------------------------------- /rlkit/envs/images/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.images.renderer import Renderer 2 | from rlkit.envs.images.env_renderer import EnvRenderer, GymEnvRenderer 3 | from rlkit.envs.images.insert_image_env import InsertImageEnv, InsertImagesEnv 4 | 5 | __all__ = [ 6 | 'Renderer', 7 | 'EnvRenderer', 8 | 'GymEnvRenderer', 9 | 'InsertImageEnv', 10 | 'InsertImagesEnv', 11 | ] 12 | 13 | 14 | -------------------------------------------------------------------------------- /rlkit/util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modules in here are self-contained modules, that depend only core libraries 3 | like numpy and pythonplusplus. However, they should NOT depend on things that 4 | are specific to railrl or rllab. 5 | 6 | The only exception is when an external dependency is explicit in the name of 7 | the module, e.g. rllab_util can depend on rllab. hyperopt can depend on 8 | hyperopt, etc. 
9 | """ -------------------------------------------------------------------------------- /rlkit/utils/device_util.py: -------------------------------------------------------------------------------- 1 | from absl import logging # NOQA 2 | 3 | import torch 4 | 5 | import rlkit.torch.pytorch_util as ptu 6 | 7 | 8 | def set_device(use_gpu, gpu_id=None): 9 | if use_gpu: 10 | assert torch.cuda.is_available() 11 | 12 | if gpu_id is None: 13 | gpu_id = 0 14 | 15 | else: 16 | gpu_id = -1 17 | 18 | ptu.set_gpu_mode(mode=use_gpu, gpu_id=gpu_id) 19 | logging.info('Device: %r', ptu.device) 20 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.data_collector.base import ( 2 | DataCollector, 3 | PathCollector, 4 | StepCollector, 5 | ) 6 | from rlkit.samplers.data_collector.path_collector import ( 7 | MdpPathCollector, 8 | ObsDictPathCollector, 9 | GoalConditionedPathCollector, 10 | VAEWrappedEnvPathCollector, 11 | ) 12 | from rlkit.samplers.data_collector.step_collector import ( 13 | GoalConditionedStepCollector 14 | ) 15 | -------------------------------------------------------------------------------- /rlkit/utils/logging.config: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root 3 | 4 | [handlers] 5 | keys=consoleHandler 6 | 7 | [formatters] 8 | keys=simpleFormatter 9 | 10 | [logger_root] 11 | level=DEBUG 12 | handlers=consoleHandler 13 | 14 | [handler_consoleHandler] 15 | class=StreamHandler 16 | level=DEBUG 17 | formatter=simpleFormatter 18 | args=(sys.stdout,) 19 | 20 | [formatter_simpleFormatter] 21 | format=[%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)s] %(message)s 22 | datefmt=%m-%d %H:%M:%S 23 | -------------------------------------------------------------------------------- /rlkit/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class Policy(object, metaclass=abc.ABCMeta): 4 | """ 5 | General policy interface. 6 | """ 7 | @abc.abstractmethod 8 | def get_action(self, observation): 9 | """ 10 | 11 | :param observation: 12 | :return: action, debug_dictionary 13 | """ 14 | pass 15 | 16 | def reset(self): 17 | pass 18 | 19 | 20 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 21 | def set_num_steps_total(self, t): 22 | pass 23 | -------------------------------------------------------------------------------- /rlkit/utils/logging.py: -------------------------------------------------------------------------------- 1 | """Logging utilites. 
2 | """ 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import sys # NOQA 9 | import os.path 10 | import logging 11 | import logging.config 12 | 13 | 14 | try: 15 | # if 'absl' not in sys.modules: 16 | config_path = os.path.join(os.path.dirname(__file__), 'logging.config') 17 | logging.config.fileConfig(config_path) 18 | except Exception: 19 | print('Unable to set the formatters for logging.') 20 | 21 | 22 | logger = logging.getLogger('root') 23 | -------------------------------------------------------------------------------- /rlkit/core/loss.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | 5 | LossStatistics = OrderedDict 6 | 7 | 8 | class LossFunction(object, metaclass=abc.ABCMeta): 9 | @abc.abstractmethod 10 | def compute_loss(self, batch, skip_statistics=False, **kwargs): 11 | """Returns loss and statistics given a batch of data. 12 | batch : Data to compute loss of 13 | skip_statistics: Whether statistics should be calculated. If True, then 14 | an empty dict is returned for the statistics. 15 | 16 | Returns: (loss, stats) tuple. 17 | """ 18 | pass 19 | -------------------------------------------------------------------------------- /rlkit/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # from envs.env_utils import register_environments 2 | from gym.envs.registration import register 3 | 4 | register(id='OneDPoint-v0', entry_point='railrl.envs.oned_point:OneDPoint') 5 | register(id='TwoDPoint-v0', entry_point='railrl.envs.twod_point:TwoDPoint') 6 | register(id='TwoDPointRandomInit-v0', 7 | entry_point='railrl.envs.twod_point_random_init:TwoDPointRandomInit') 8 | register(id='TwoDMaze-v0', entry_point='railrl.envs.twod_maze:TwoDMaze') 9 | register(id='WaterMaze-v0', entry_point='railrl.envs.water_maze:WaterMaze') 10 | register(id='Sawyer-v0', entry_point='railrl.envs.sawyer_env:SawyerEnv') 11 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/image_mujoco_env_with_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.wrappers.image_mujoco_env import ImageMujocoEnv 5 | 6 | 7 | class ImageMujocoWithObsEnv(ImageMujocoEnv): 8 | def __init__(self, env, **kwargs): 9 | super().__init__(env, **kwargs) 10 | self.observation_space = Box( 11 | low=0.0, 12 | high=1.0, 13 | shape=(self.image_length * self.history_length 14 | + self.wrapped_env.obs_dim,)) 15 | 16 | def _get_obs(self, history_flat, true_state): 17 | return np.concatenate([history_flat, true_state]) 18 | -------------------------------------------------------------------------------- /rlkit/util/ray_util.py: -------------------------------------------------------------------------------- 1 | from ray import serialization, utils 2 | 3 | 4 | def set_serialization_mode_to_pickle(cls): 5 | """ 6 | Whenever this class is serialized by ray, it will default to using pickle 7 | serialization (__setstate__ and __getstate__) 8 | 9 | WARNING: This will only work if the driver is serializing and workers 10 | are de-serializing. 
11 | 12 | :param cls: class instance or the Class itself 13 | """ 14 | if cls not in serialization.type_to_class_id: 15 | serialization.add_class_to_whitelist( 16 | cls, 17 | utils.random_string(), 18 | pickle=True, 19 | ) 20 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/noop.py: -------------------------------------------------------------------------------- 1 | from rlkit.exploration_strategies.base import RawExplorationStrategy 2 | import numpy as np 3 | 4 | 5 | class NoopStrategy(RawExplorationStrategy): 6 | """ 7 | Exploration strategy that does nothing other than clip the action. 8 | """ 9 | 10 | def __init__(self, action_space, **kwargs): 11 | self.action_space = action_space 12 | 13 | def get_action(self, t, observation, policy, **kwargs): 14 | return policy.get_action(observation) 15 | 16 | def get_action_from_raw_action(self, action, **kwargs): 17 | return np.clip(action, self.action_space.low, self.action_space.high) 18 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.rand_param_envs.hopper_rand_params import HopperRandParamsEnv 2 | 3 | 4 | class HopperRandParamsWrappedEnv(HopperRandParamsEnv): 5 | def __init__(self, n_tasks=2, randomize_tasks=True): 6 | super(HopperRandParamsWrappedEnv, self).__init__() 7 | self.tasks = self.sample_tasks(n_tasks) 8 | self.reset_task(0) 9 | 10 | def get_all_task_idx(self): 11 | return range(len(self.tasks)) 12 | 13 | def reset_task(self, idx): 14 | self._task = self.tasks[idx] 15 | self._goal = idx 16 | self.set_task(self._task) 17 | self.reset() 18 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/walker_rand_params_wrapper.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import Walker2DRandParamsEnv 2 | 3 | 4 | class WalkerRandParamsWrappedEnv(Walker2DRandParamsEnv): 5 | def __init__(self, n_tasks=2, randomize_tasks=True): 6 | super(WalkerRandParamsWrappedEnv, self).__init__() 7 | self.tasks = self.sample_tasks(n_tasks) 8 | self.reset_task(0) 9 | 10 | def get_all_task_idx(self): 11 | return range(len(self.tasks)) 12 | 13 | def reset_task(self, idx): 14 | self._task = self.tasks[idx] 15 | self._goal = idx 16 | self.set_task(self._task) 17 | self.reset() 18 | -------------------------------------------------------------------------------- /rlkit/demos/her_td3bc.py: -------------------------------------------------------------------------------- 1 | from rlkit.data_management.obs_dict_replay_buffer import \ 2 | ObsDictRelabelingBuffer 3 | from rlkit.torch.her.her import HER 4 | from rlkit.demos.td3_bc import TD3BC 5 | 6 | 7 | class HerTD3BC(HER, TD3BC): 8 | def __init__( 9 | self, 10 | *args, 11 | td3_kwargs, 12 | her_kwargs, 13 | base_kwargs, 14 | **kwargs 15 | ): 16 | HER.__init__( 17 | self, 18 | **her_kwargs, 19 | ) 20 | TD3BC.__init__(self, *args, **kwargs, **td3_kwargs, **base_kwargs) 21 | assert isinstance( 22 | self.replay_buffer, ObsDictRelabelingBuffer 23 | ) 24 | -------------------------------------------------------------------------------- /rlkit/networks/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 
7 | @torch.no_grad() 8 | def variance_scaling_init_(tensor, scale=1, mode="fan_avg", distribution="uniform"): 9 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor) 10 | 11 | if mode == "fan_in": 12 | scale /= fan_in 13 | 14 | elif mode == "fan_out": 15 | scale /= fan_out 16 | 17 | else: 18 | scale /= (fan_in + fan_out) / 2 19 | 20 | if distribution == "normal": 21 | std = math.sqrt(scale) 22 | 23 | return tensor.normal_(0, std) 24 | 25 | else: 26 | bound = math.sqrt(3 * scale) 27 | 28 | return tensor.uniform_(-bound, bound) 29 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/reward_wrapper_env.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.proxy_env import ProxyEnv 2 | 3 | 4 | class RewardWrapperEnv(ProxyEnv): 5 | """Substitute a different reward function""" 6 | 7 | def __init__( 8 | self, 9 | env, 10 | compute_reward_fn, 11 | ): 12 | ProxyEnv.__init__(self, env) 13 | self.spec = env.spec # hack for hand envs 14 | self.compute_reward_fn = compute_reward_fn 15 | 16 | def step(self, action): 17 | next_obs, reward, done, info = self._wrapped_env.step(action) 18 | info["env_reward"] = reward 19 | reward = self.compute_reward_fn(next_obs, reward, done, info) 20 | return next_obs, reward, done, info 21 | -------------------------------------------------------------------------------- /rlkit/demos/play_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sys 4 | # print(sys.path) 5 | sys.path.remove("/opt/ros/kinetic/lib/python2.7/dist-packages") 6 | 7 | import cv2 8 | import sys 9 | import pickle 10 | 11 | def play_demos(path): 12 | data = pickle.load(open(path, "rb")) 13 | # data = np.load(path, allow_pickle=True) 14 | 15 | for traj in data: 16 | obs = traj["observations"] 17 | 18 | for o in obs: 19 | img = o["image_observation"].reshape(3, 500, 300)[:, 60:, :240].transpose() 20 | img = img[:, :, ::-1] 21 | cv2.imshow('window', img) 22 | cv2.waitKey(100) 23 | 24 | if __name__ == '__main__': 25 | demo_path = sys.argv[1] 26 | play_demos(demo_path) 27 | -------------------------------------------------------------------------------- /rlkit/demos/spacemouse/input_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | Should be run on a machine connected to a spacemouse 3 | """ 4 | 5 | from robosuite.devices import SpaceMouse 6 | import time 7 | import Pyro4 8 | from rlkit.launchers import config 9 | # HOSTNAME = config.SPACEMOUSE_HOSTNAME 10 | HOSTNAME = "192.168.1.4" 11 | 12 | Pyro4.config.SERIALIZERS_ACCEPTED = set( 13 | ['pickle', 'json', 'marshal', 'serpent']) 14 | Pyro4.config.SERIALIZER = 'pickle' 15 | 16 | nameserver = Pyro4.locateNS(host=HOSTNAME) 17 | uri = nameserver.lookup("example.greeting") 18 | device_state = Pyro4.Proxy(uri) 19 | device = SpaceMouse() 20 | while True: 21 | state = device.get_controller_state() 22 | print(state) 23 | time.sleep(0.1) 24 | device_state.set_state(state) 25 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_normal.py: -------------------------------------------------------------------------------- 1 | from gym.envs.mujoco import AntEnv 2 | 3 | 4 | class AntNormal(AntEnv): 5 | def __init__( 6 | self, 7 | *args, 8 | n_tasks=2, # number of distinct tasks in this domain, shoudl equal sum of train and eval tasks 9 | randomize_tasks=True, # shuffle the 
tasks after creating them 10 | **kwargs 11 | ): 12 | self.tasks = [0 for _ in range(n_tasks)] 13 | self._goal = 0 14 | super().__init__(*args, **kwargs) 15 | 16 | def get_all_task_idx(self): 17 | return self.tasks 18 | 19 | def reset_task(self, idx): 20 | # not tasks. just give the same reward every time step. 21 | pass 22 | 23 | def sample_tasks(self, num_tasks): 24 | return [0 for _ in range(num_tasks)] 25 | -------------------------------------------------------------------------------- /rlkit/data_management/wrappers/concat_to_obs_wrapper.py: -------------------------------------------------------------------------------- 1 | from rlkit.data_management.wrappers.proxy_buffer import ProxyBuffer 2 | import numpy as np 3 | 4 | class ConcatToObsWrapper(ProxyBuffer): 5 | def __init__(self, replay_buffer, keys_to_concat): 6 | """keys_to_concat: list of strings""" 7 | super().__init__(replay_buffer) 8 | self.keys_to_concat = keys_to_concat 9 | 10 | def random_batch(self, batch_size): 11 | batch = self._wrapped_buffer.random_batch(batch_size) 12 | obs = batch['observations'] 13 | next_obs = batch['next_observations'] 14 | to_concat = [batch[key] for key in self.keys_to_concat] 15 | batch['observations'] = np.concatenate([obs] + to_concat, axis=1) 16 | batch['next_observations'] = np.concatenate([next_obs] + to_concat, axis=1) 17 | return batch 18 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from rlkit.exploration_strategies.base import RawExplorationStrategy 4 | 5 | 6 | class EpsilonGreedy(RawExplorationStrategy): 7 | """ 8 | Take a random discrete action with some probability. 
9 | """ 10 | def __init__(self, action_space, prob_random_action=0.1): 11 | self.prob_random_action = prob_random_action 12 | self.action_space = action_space 13 | 14 | def get_action(self, t, policy, *args, **kwargs): 15 | action, agent_info = policy.get_action(*args, **kwargs) 16 | return self.get_action_from_raw_action(action), agent_info 17 | 18 | def get_action_from_raw_action(self, action, **kwargs): 19 | if random.random() <= self.prob_random_action: 20 | return self.action_space.sample() 21 | return action 22 | -------------------------------------------------------------------------------- /rlkit/demos/source/path_loader.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn as nn 7 | import torch.nn.functional as F 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.core.eval_util import create_stats_ordered_dict 11 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 12 | 13 | from rlkit.util.io import load_local_or_remote_file 14 | 15 | import random 16 | from rlkit.torch.core import np_to_pytorch_batch 17 | from rlkit.data_management.path_builder import PathBuilder 18 | 19 | # import matplotlib 20 | # matplotlib.use('TkAgg') 21 | # import matplotlib.pyplot as plt 22 | 23 | from rlkit.core import logger 24 | 25 | import glob 26 | 27 | class PathLoader: 28 | """ 29 | Loads demonstrations and/or off-policy data into a Trainer 30 | """ 31 | 32 | def load_demos(self, ): 33 | pass 34 | -------------------------------------------------------------------------------- /rlkit/policies/action_repeat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.policies.base import Policy 4 | 5 | 6 | class ActionRepeatPolicy(Policy): 7 | """ 8 | General policy interface. 
9 | """ 10 | def __init__(self, policy, repeat_prob=.5): 11 | self._policy = policy 12 | self._repeat_prob = repeat_prob 13 | self._last_action = None 14 | 15 | def get_action(self, observation): 16 | """ 17 | 18 | :param observation: 19 | :return: action, debug_dictionary 20 | """ 21 | action = self._policy.get_action(observation) 22 | if ( 23 | self._last_action is not None 24 | and np.random.uniform() <= self._repeat_prob 25 | ): 26 | action = self._last_action 27 | self._last_action = action 28 | return self._last_action 29 | 30 | def reset(self): 31 | self._last_action = None -------------------------------------------------------------------------------- /rlkit/envs/assets/oned_point.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/point_env.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | from gym.spaces import Box 3 | 4 | import numpy as np 5 | 6 | 7 | class PointEnv(Env): 8 | @property 9 | def observation_space(self): 10 | return Box(low=-np.inf, high=np.inf, shape=(2,)) 11 | 12 | @property 13 | def action_space(self): 14 | return Box(low=-0.1, high=0.1, shape=(2,)) 15 | 16 | def reset(self): 17 | self._state = np.random.uniform(-1, 1, size=(2,)) 18 | observation = np.copy(self._state) 19 | return observation 20 | 21 | def step(self, action): 22 | self._state = self._state + action 23 | x, y = self._state 24 | reward = - (x ** 2 + y ** 2) ** 0.5 25 | done = abs(x) < 0.01 and abs(y) < 0.01 26 | next_observation = np.copy(self._state) 27 | return next_observation, reward, done 28 | 29 | def render(self, **kwargs): 30 | print('current state:', self._state) 31 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class DataCollector(object, metaclass=abc.ABCMeta): 5 | def end_epoch(self, epoch): 6 | pass 7 | 8 | def get_diagnostics(self): 9 | return {} 10 | 11 | def get_snapshot(self): 12 | return {} 13 | 14 | @abc.abstractmethod 15 | def get_epoch_paths(self): 16 | pass 17 | 18 | 19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def collect_new_paths( 22 | self, 23 | max_path_length, 24 | num_steps, 25 | discard_incomplete_paths, 26 | ): 27 | pass 28 | 29 | 30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta): 31 | @abc.abstractmethod 32 | def collect_new_steps( 33 | self, 34 | max_path_length, 35 | num_steps, 36 | discard_incomplete_paths, 37 | ): 38 | pass 39 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/discretize_env.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from gym import Env 5 | from gym.spaces import Discrete 6 | 7 | from rlkit.envs.proxy_env import ProxyEnv 8 | 9 | 10 | class DiscretizeEnv(ProxyEnv, Env): 11 | def __init__(self, wrapped_env, num_bins): 12 | super().__init__(wrapped_env) 13 | low = self.wrapped_env.action_space.low 14 | high = self.wrapped_env.action_space.high 15 | action_ranges = [ 16 | np.linspace(low[i], high[i], num_bins) 17 | for i in range(len(low)) 18 | ] 19 | self.idx_to_continuous_action = [ 20 | np.array(x) for x in 
itertools.product(*action_ranges) 21 | ] 22 | self.action_space = Discrete(len(self.idx_to_continuous_action)) 23 | 24 | def step(self, action): 25 | continuous_action = self.idx_to_continuous_action[action] 26 | return super().step(continuous_action) 27 | 28 | 29 | -------------------------------------------------------------------------------- /rlkit/torch/transforms/_pil_constants.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from PIL import Image 3 | 4 | # See https://pillow.readthedocs.io/en/stable/releasenotes/9.1.0.html#deprecations 5 | # TODO: Remove this file once PIL minimal version is >= 9.1 6 | 7 | if tuple(int(part) for part in PIL.__version__.split(".")) >= (9, 1): 8 | BICUBIC = Image.Resampling.BICUBIC 9 | BILINEAR = Image.Resampling.BILINEAR 10 | LINEAR = Image.Resampling.BILINEAR 11 | NEAREST = Image.Resampling.NEAREST 12 | 13 | AFFINE = Image.Transform.AFFINE 14 | FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT 15 | FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM 16 | PERSPECTIVE = Image.Transform.PERSPECTIVE 17 | else: 18 | BICUBIC = Image.BICUBIC 19 | BILINEAR = Image.BILINEAR 20 | NEAREST = Image.NEAREST 21 | LINEAR = Image.LINEAR 22 | 23 | AFFINE = Image.AFFINE 24 | FLIP_LEFT_RIGHT = Image.FLIP_LEFT_RIGHT 25 | FLIP_TOP_BOTTOM = Image.FLIP_TOP_BOTTOM 26 | PERSPECTIVE = Image.PERSPECTIVE -------------------------------------------------------------------------------- /rlkit/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """Imports wrappers 2 | 3 | TODO: DEPRECATE. This pattern imports unecessary modules, creating unnecessary 4 | software dependencies. Just import the specific module instead. 5 | """ 6 | 7 | from rlkit.envs.wrappers.discretize_env import DiscretizeEnv 8 | from rlkit.envs.wrappers.history_env import HistoryEnv 9 | # from rlkit.envs.wrappers.image_mujoco_env import ImageMujocoEnv 10 | # from rlkit.envs.wrappers.image_mujoco_env_with_obs import ImageMujocoWithObsEnv 11 | from rlkit.envs.wrappers.normalized_box_env import NormalizedBoxEnv 12 | from rlkit.envs.proxy_env import ProxyEnv 13 | from rlkit.envs.wrappers.reward_wrapper_env import RewardWrapperEnv 14 | from rlkit.envs.wrappers.stack_observation_env import StackObservationEnv 15 | 16 | 17 | __all__ = [ 18 | 'DiscretizeEnv', 19 | 'HistoryEnv', 20 | # 'ImageMujocoEnv', 21 | # 'ImageMujocoWithObsEnv', 22 | 'NormalizedBoxEnv', 23 | 'ProxyEnv', 24 | 'RewardWrapperEnv', 25 | 'StackObservationEnv', 26 | ] 27 | -------------------------------------------------------------------------------- /rlkit/torch/her/her.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 4 | 5 | 6 | class HERTrainer(TorchTrainer): 7 | def __init__(self, base_trainer: TorchTrainer): 8 | super().__init__() 9 | self._base_trainer = base_trainer 10 | 11 | def train_from_torch(self, batch): 12 | obs = batch['observations'] 13 | next_obs = batch['next_observations'] 14 | goals = batch['resampled_goals'] 15 | batch['observations'] = torch.cat((obs, goals), dim=1) 16 | batch['next_observations'] = torch.cat((next_obs, goals), dim=1) 17 | self._base_trainer.train_from_torch(batch) 18 | 19 | def get_diagnostics(self): 20 | return self._base_trainer.get_diagnostics() 21 | 22 | def end_epoch(self, epoch): 23 | self._base_trainer.end_epoch(epoch) 24 | 25 | @property 26 | def networks(self): 27 | return 
self._base_trainer.networks 28 | 29 | def get_snapshot(self): 30 | return self._base_trainer.get_snapshot() 31 | -------------------------------------------------------------------------------- /rlkit/envs/gridcraft/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def flat_to_one_hot(val, ndim): 4 | """ 5 | 6 | >>> flat_to_one_hot(2, ndim=4) 7 | array([ 0., 0., 1., 0.]) 8 | >>> flat_to_one_hot(4, ndim=5) 9 | array([ 0., 0., 0., 0., 1.]) 10 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5) 11 | array([[ 0., 0., 1., 0., 0.], 12 | [ 0., 0., 0., 0., 1.], 13 | [ 0., 0., 0., 1., 0.]]) 14 | """ 15 | shape =np.array(val).shape 16 | v = np.zeros(shape + (ndim,)) 17 | if len(shape) == 1: 18 | v[np.arange(shape[0]), val] = 1.0 19 | else: 20 | v[val] = 1.0 21 | return v 22 | 23 | def one_hot_to_flat(val): 24 | """ 25 | >>> one_hot_to_flat(np.array([0,0,0,0,1])) 26 | 4 27 | >>> one_hot_to_flat(np.array([0,0,1,0])) 28 | 2 29 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]])) 30 | array([2, 0, 1]) 31 | """ 32 | idxs = np.array(np.where(val == 1.0))[-1] 33 | if len(val.shape) == 1: 34 | return int(idxs) 35 | return idxs -------------------------------------------------------------------------------- /rlkit/envs/mujoco/twod_point_random_init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.mujoco.mujoco_env import MujocoEnv 4 | 5 | TARGET = np.array([0.2, 0]) 6 | 7 | 8 | class TwoDPointRandomInit(MujocoEnv): 9 | def __init__(self): 10 | self.init_serialization(locals()) 11 | super().__init__('twod_point.xml') 12 | 13 | def _step(self, a): 14 | self.do_simulation(a, self.frame_skip) 15 | ob = self._get_obs() 16 | pos = ob[0:2] 17 | dist = np.linalg.norm(pos - TARGET) 18 | reward = - (dist + 1e-2*np.linalg.norm(a)) 19 | done = False 20 | return ob, reward, done, {} 21 | 22 | def reset_model(self): 23 | qpos = self.np_random.uniform(size=self.model.nq, low=-0.25, high=0.25) 24 | qvel = self.np_random.uniform(size=self.model.nv, low=-0.25, high=0.25) 25 | self.set_state(qpos, qvel) 26 | return self._get_obs() 27 | 28 | def _get_obs(self): 29 | return np.concatenate([self.model.data.qpos]).ravel() 30 | 31 | def viewer_setup(self): 32 | pass 33 | -------------------------------------------------------------------------------- /rlkit/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.spaces import Box, Discrete, Tuple 4 | 5 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 6 | 7 | 8 | def get_asset_full_path(file_name): 9 | return os.path.join(ENV_ASSET_DIR, file_name) 10 | 11 | 12 | def get_dim(space): 13 | if isinstance(space, Box): 14 | return space.low.size 15 | elif isinstance(space, Discrete): 16 | return space.n 17 | elif isinstance(space, Tuple): 18 | return sum(get_dim(subspace) for subspace in space.spaces) 19 | elif hasattr(space, 'flat_dim'): 20 | return space.flat_dim 21 | else: 22 | raise TypeError("Unknown space: {}".format(space)) 23 | 24 | 25 | def gym_env(name): 26 | from rllab.envs.gym_env import GymEnv 27 | return GymEnv(name, 28 | record_video=False, 29 | log_dir='/tmp/gym-test', # Ignore gym log. 
30 | record_log=False) 31 | 32 | def mode(env, mode_type): 33 | try: 34 | getattr(env, mode_type)() 35 | except AttributeError: 36 | pass 37 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/modular_env.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Callable, Any, Dict, List 3 | 4 | import gym.spaces 5 | 6 | Path = Dict 7 | Diagnostics = Dict 8 | Context = Any 9 | ContextualDiagnosticsFn = Callable[ 10 | [List[Path], List[Context]], 11 | Diagnostics, 12 | ] 13 | 14 | 15 | class RewardFn(object, metaclass=abc.ABCMeta): 16 | """Some reward function""" 17 | @abc.abstractmethod 18 | def __call__( 19 | self, 20 | action, 21 | next_state: dict, 22 | ): 23 | pass 24 | 25 | 26 | class ModularEnv(gym.Wrapper): 27 | """An env where you can separately specify the reward and transition.""" 28 | def __init__( 29 | self, 30 | env: gym.Env, 31 | reward_fn: RewardFn, 32 | ): 33 | super().__init__(env) 34 | self.reward_fn = reward_fn 35 | 36 | def step(self, action): 37 | obs, reward, done, info = super().step(action) 38 | if self.reward_fn: 39 | reward = self.reward_fn(action, obs) 40 | return obs, reward, done, info 41 | -------------------------------------------------------------------------------- /rlkit/envs/images/text_renderer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PIL import ImageDraw 3 | from PIL import ImageFont 4 | import numpy as np 5 | 6 | 7 | from rlkit.envs.images import Renderer 8 | 9 | 10 | class TextRenderer(Renderer): 11 | """I gave up! See plot_renderer.TextRenderer""" 12 | 13 | def __init__(self, text, *args, 14 | text_color='white', 15 | background_color='black', 16 | **kwargs): 17 | super().__init__(*args, 18 | create_image_format='HWC', 19 | **kwargs) 20 | 21 | font = ImageFont.truetype( 22 | '/usr/share/fonts/truetype/ubuntu-font-family/Ubuntu-R.ttf', 100) 23 | _, h, w = self.image_chw 24 | self._img = Image.new('RGB', (w, h), background_color) 25 | self._draw_interface = ImageDraw.Draw(self._img) 26 | self._draw_interface.text((0, 0), text, fill=text_color, font=font) 27 | self._np_img = np.array(self._img).copy() 28 | 29 | def _create_image(self, *args, **kwargs): 30 | return self._np_img 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chongyi Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rlkit/torch/networks/ae_tanh_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from rlkit.torch import pytorch_util as ptu 4 | from rlkit.torch.networks import MlpPolicy 5 | 6 | 7 | class AETanhPolicy(MlpPolicy): 8 | """ 9 | A helper class since most policies have a tanh output activation. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | ae, 15 | env, 16 | history_length, 17 | *args, 18 | **kwargs 19 | ): 20 | super().__init__(*args, **kwargs, output_activation=torch.tanh) 21 | self.ae = ae 22 | self.history_length = history_length 23 | self.env = env 24 | 25 | def get_action(self, obs_np): 26 | obs = obs_np 27 | obs = ptu.from_numpy(obs) 28 | image_obs, fc_obs = self.env.split_obs(obs) 29 | latent_obs = self.ae.history_encoder(image_obs, self.history_length) 30 | if fc_obs is not None: 31 | latent_obs = torch.cat((latent_obs, fc_obs), dim=1) 32 | obs_np = ptu.get_numpy(latent_obs)[0] 33 | actions = self.get_actions(obs_np[None]) 34 | return actions[0, :], {} 35 | 36 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/half_cheetah.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_ 3 | 4 | class HalfCheetahEnv(HalfCheetahEnv_): 5 | def _get_obs(self): 6 | return np.concatenate([ 7 | self.sim.data.qpos.flat[1:], 8 | self.sim.data.qvel.flat, 9 | self.get_body_com("torso").flat, 10 | ]).astype(np.float32).flatten() 11 | 12 | def viewer_setup(self): 13 | camera_id = self.model.camera_name2id('track') 14 | self.viewer.cam.type = 2 15 | self.viewer.cam.fixedcamid = camera_id 16 | self.viewer.cam.distance = self.model.stat.extent * 0.35 17 | # Hide the overlay 18 | self.viewer._hide_overlay = True 19 | 20 | def render(self, mode='human', width=500, height=500, **kwargs): 21 | if mode == 'rgb_array': 22 | self._get_viewer(mode).render(width=width, height=height) 23 | data = self._get_viewer(mode).read_pixels(width, height, depth=False)[::-1, :, :] 24 | return data 25 | elif mode == 'human': 26 | self._get_viewer(mode).render() 27 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/reacher_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from gym.envs.mujoco.reacher import ReacherEnv as GymReacherEnv 4 | 5 | from rlkit.core import logger as default_logger 6 | from rlkit.core.eval_util import get_stat_in_paths, create_stats_ordered_dict 7 | 8 | 9 | class ReacherEnv(GymReacherEnv): 10 | def log_diagnostics(self, paths, logger=default_logger): 11 | statistics = OrderedDict() 12 | for name_in_env_infos, name_to_log in [ 13 | ('reward_dist', 'Distance Reward'), 14 | ('reward_ctrl', 'Action Reward'), 15 | ]: 16 | stat = get_stat_in_paths(paths, 'env_infos', name_in_env_infos) 17 | statistics.update(create_stats_ordered_dict( 18 | name_to_log, 19 | stat, 20 | )) 21 | distances = get_stat_in_paths(paths, 'env_infos', 'reward_dist') 22 | 
statistics.update(create_stats_ordered_dict( 23 | "Final Distance Reward", 24 | [ds[-1] for ds in distances], 25 | )) 26 | for key, value in statistics.items(): 27 | logger.record_tabular(key, value) 28 | 29 | 30 | -------------------------------------------------------------------------------- /rlkit/core/distribution.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict 3 | from gym import Space 4 | 5 | 6 | class DictDistribution(object, metaclass=abc.ABCMeta): 7 | 8 | @abc.abstractmethod 9 | def sample(self, batch_size: int): 10 | pass 11 | 12 | @property 13 | @abc.abstractmethod 14 | def spaces(self) -> Dict[str, Space]: 15 | pass 16 | 17 | def __call__(self, *args, **kwargs): 18 | """For backward compatibility with DictDistributionGenerator""" 19 | return self 20 | 21 | 22 | class DictDistributionGenerator(DictDistribution, metaclass=abc.ABCMeta): 23 | 24 | def __call__(self, *args, **kwargs) -> DictDistribution: 25 | raise NotImplementedError 26 | 27 | 28 | class DictDistributionClosure(DictDistributionGenerator): 29 | """Fills in args to a DictDistribution""" 30 | 31 | def __init__(self, clz, *args, **kwargs): 32 | self.clz = clz 33 | self.args = args 34 | self.kwargs = kwargs 35 | 36 | def __call__(self, **extra_kwargs) -> DictDistribution: 37 | self.kwargs.update(**extra_kwargs) 38 | return self.clz( 39 | *self.args, 40 | **self.kwargs, 41 | ) 42 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/twod_point.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.mujoco.mujoco_env import MujocoEnv 4 | 5 | TARGET = np.array([0.2, 0]) 6 | 7 | 8 | class TwoDPoint(MujocoEnv): 9 | def __init__(self): 10 | self.init_serialization(locals()) 11 | super().__init__('twod_point.xml') 12 | 13 | def _step(self, a): 14 | self.do_simulation(a, self.frame_skip) 15 | ob = self._get_obs() 16 | pos = ob[0:2] 17 | dist = np.linalg.norm(pos - TARGET) 18 | reward = - (dist + 1e-2 * np.linalg.norm(a)) 19 | done = False 20 | return ob, reward, done, {} 21 | 22 | def reset_model(self): 23 | qpos = self.init_qpos + np.random.uniform(size=self.model.nq, low=-0.01, 24 | high=0.01) 25 | qvel = self.init_qvel + np.random.uniform(size=self.model.nv, low=-0.01, 26 | high=0.01) 27 | self.set_state(qpos, qvel) 28 | return self._get_obs() 29 | 30 | def _get_obs(self): 31 | return np.concatenate([self.model.data.qpos]).ravel() 32 | 33 | def viewer_setup(self): 34 | pass 35 | 36 | 37 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_strategy.py: -------------------------------------------------------------------------------- 1 | from rlkit.exploration_strategies.base import RawExplorationStrategy 2 | import numpy as np 3 | 4 | 5 | class GaussianStrategy(RawExplorationStrategy): 6 | """ 7 | This strategy adds Gaussian noise to the action taken by the deterministic policy. 8 | 9 | Based on the rllab implementation.
10 | """ 11 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None, 12 | decay_period=1000000): 13 | assert len(action_space.shape) == 1 14 | self._max_sigma = max_sigma 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._min_sigma = min_sigma 18 | self._decay_period = decay_period 19 | self._action_space = action_space 20 | 21 | def get_action_from_raw_action(self, action, t=None, **kwargs): 22 | sigma = ( 23 | self._max_sigma - (self._max_sigma - self._min_sigma) * 24 | min(1.0, t * 1.0 / self._decay_period) 25 | ) 26 | return np.clip( 27 | action + np.random.normal(size=len(action)) * sigma, 28 | self._action_space.low, 29 | self._action_space.high, 30 | ) 31 | -------------------------------------------------------------------------------- /rlkit/envs/time_limited_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class TimeLimitedEnv(ProxyEnv): 8 | def __init__(self, wrapped_env, horizon): 9 | self._wrapped_env = wrapped_env 10 | self._horizon = horizon 11 | 12 | self._observation_space = Box( 13 | np.hstack((self._wrapped_env.observation_space.low, [0])), 14 | np.hstack((self._wrapped_env.observation_space.high, [1])), 15 | ) 16 | self._t = 0 17 | 18 | @property 19 | def observation_space(self): 20 | return self._observation_space 21 | 22 | @property 23 | def horizon(self): 24 | return self._horizon 25 | 26 | def step(self, action): 27 | obs, reward, done, info = self._wrapped_env.step(action) 28 | self._t += 1 29 | done = done or self._t == self.horizon 30 | new_obs = np.hstack((obs, float(self._t) / self.horizon)) 31 | return new_obs, reward, done, info 32 | 33 | def reset(self, **kwargs): 34 | obs = self._wrapped_env.reset(**kwargs) 35 | self._t = 0 36 | new_obs = np.hstack((obs, self._t)) 37 | return new_obs 38 | -------------------------------------------------------------------------------- /rlkit/envs/gridcraft/custom_test.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.gridcraft import REW_ARENA_64 2 | from rlkit.envs.gridcraft.grid_env import GridEnv 3 | from rlkit.envs.gridcraft.grid_spec import * 4 | from rlkit.envs.gridcraft.mazes import MAZE_ANY_START1 5 | import gym.spaces.prng as prng 6 | import numpy as np 7 | 8 | if __name__ == "__main__": 9 | prng.seed(2) 10 | 11 | maze_spec = \ 12 | spec_from_string("SOOOO#R#OO\\"+ 13 | "OSOOO#2##O\\" + 14 | "###OO#3O#O\\" + 15 | "OOOOO#OO#O\\" + 16 | "OOOOOOOOOO\\" 17 | ) 18 | 19 | #maze_spec = spec_from_sparse_locations(50, 50, {START: [(25,25)], REWARD: [(45,45)]}) 20 | # maze_spec = REW_ARENA_64 21 | maze_spec = MAZE_ANY_START1 22 | 23 | env = GridEnv(maze_spec, one_hot=True, add_eyes=True, coordinate_wise=True) 24 | 25 | s = env.reset() 26 | #env.render() 27 | 28 | obses = [] 29 | for t in range(10): 30 | a = env.action_space.sample() 31 | obs, r, done, infos = env.step(a, verbose=True) 32 | obses.append(obs) 33 | obses = np.array(obses) 34 | 35 | paths = [{'observations': obses}] 36 | env.plot_trajs(paths) -------------------------------------------------------------------------------- /rlkit/envs/mujoco/twod_maze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.mujoco.mujoco_env import MujocoEnv 4 | 5 | INIT_POS = np.array([0.2,0.15]) 6 | TARGET = np.array([0.2, -0.15]) + INIT_POS 7 | DIST_THRESH = 0.05 8 | 9 | 10 
| class TwoDMaze(MujocoEnv): 11 | def __init__(self): 12 | self.init_serialization(locals()) 13 | super().__init__('twod_maze.xml') 14 | 15 | def _step(self, a): 16 | self.do_simulation(a, self.frame_skip) 17 | ob = self._get_obs() 18 | pos = ob[0:2] 19 | dist = np.linalg.norm(pos - TARGET) 20 | dist_cost = 1 if dist>DIST_THRESH else dist/DIST_THRESH 21 | reward = - (dist_cost)# + 1e-2*np.linalg.norm(a)) 22 | #print(reward, dist) 23 | done = False 24 | return ob, reward, done, {} 25 | 26 | def reset_model(self): 27 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01) 28 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 29 | self.set_state(qpos, qvel) 30 | return self._get_obs() 31 | 32 | def _get_obs(self): 33 | return np.concatenate([self.model.data.qpos]).ravel() 34 | 35 | def viewer_setup(self): 36 | pass 37 | -------------------------------------------------------------------------------- /rlkit/envs/supervised_learning_env.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | 5 | class RecurrentSupervisedLearningEnv(metaclass=abc.ABCMeta): 6 | """ 7 | An environment that's really just a supervised learning task. 8 | """ 9 | 10 | @abc.abstractmethod 11 | def get_batch(self, batch_size): 12 | """ 13 | 14 | :param batch_size: Size of the batch size 15 | :return: tuple (X, Y) where 16 | X is a numpy array of size ( 17 | batch_size, self.sequence_length, self.feature_dim 18 | ) 19 | Y is a numpy array of size ( 20 | batch_size, self.sequence_length, self.target_dim 21 | ) 22 | """ 23 | pass 24 | 25 | @property 26 | @abc.abstractmethod 27 | def feature_dim(self): 28 | """ 29 | :return: Integer. Dimension of the features. 30 | """ 31 | pass 32 | 33 | @property 34 | @abc.abstractmethod 35 | def target_dim(self): 36 | """ 37 | :return: Integer. Dimension of the target. 38 | """ 39 | pass 40 | 41 | @property 42 | @abc.abstractmethod 43 | def sequence_length(self): 44 | """ 45 | :return: Integer. Dimension of the target. 46 | """ 47 | pass 48 | -------------------------------------------------------------------------------- /rlkit/core/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | LOG_DIR = os.getcwd() 3 | 4 | 5 | class Wrapper(object): 6 | """ 7 | Mixin for deferring attributes to a wrapped, inner object. 8 | """ 9 | 10 | def __init__(self, inner): 11 | self.inner = inner 12 | 13 | def __getattr__(self, attr): 14 | """ 15 | Dispatch attributes by their status as magic, members, or missing. 16 | - magic is handled by the standard getattr 17 | - existing attributes are returned 18 | - missing attributes are deferred to the inner object. 19 | """ 20 | # don't make magic any more magical 21 | is_magic = attr.startswith('__') and attr.endswith('__') 22 | if is_magic: 23 | return super().__getattr__(attr) 24 | try: 25 | # try to return the attribute... 26 | return self.__dict__[attr] 27 | except: 28 | # ...and defer to the inner dataset if it's not here 29 | return getattr(self.inner, attr) 30 | 31 | 32 | class SimpleWrapper(object): 33 | """ 34 | Mixin for deferring attributes to a wrapped, inner object. 
35 | """ 36 | 37 | def __init__(self, inner): 38 | self._inner = inner 39 | 40 | def __getattr__(self, attr): 41 | if attr == '_inner': 42 | raise AttributeError() 43 | return getattr(self._inner, attr) 44 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/oned_point.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | 5 | from rlkit.envs.env_utils import get_asset_full_path 6 | 7 | 8 | class OneDPoint(mujoco_env.MujocoEnv, utils.EzPickle): 9 | def __init__(self): 10 | utils.EzPickle.__init__(self) 11 | mujoco_env.MujocoEnv.__init__(self, get_asset_full_path('oned_point.xml'), 2) 12 | 13 | def _step(self, a): 14 | self.do_simulation(a, self.frame_skip) 15 | ob = self._get_obs() 16 | pos = ob[0] 17 | reward = 1 if pos > 0.5 else 0 #pos #(pos)**2 / 10. 18 | #reward = a[0] 19 | #notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= .2) 20 | #done = not notdone 21 | done = False 22 | return ob, reward, done, {} 23 | 24 | def reset_model(self): 25 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01) 26 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 27 | self.set_state(qpos, qvel) 28 | return self._get_obs() 29 | 30 | def _get_obs(self): 31 | return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel() 32 | 33 | def viewer_setup(self): 34 | v = self.viewer 35 | v.cam.trackbodyid=0 36 | v.cam.distance = v.model.stat.extent 37 | -------------------------------------------------------------------------------- /rlkit/utils/timer.py: -------------------------------------------------------------------------------- 1 | """TensorFlow Task Generators API.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import time 8 | 9 | 10 | class Timer(object): 11 | 12 | def __init__(self, keys): 13 | assert isinstance(keys, list) 14 | self._keys = keys 15 | self._keys = keys 16 | 17 | self.reset() 18 | 19 | def reset(self): 20 | self._start_time = { 21 | key: None for key in self._keys 22 | } 23 | self._time_acc = { 24 | key: 0.0 for key in self._keys 25 | } 26 | self._count = { 27 | key: 0 for key in self._keys 28 | } 29 | 30 | @property 31 | def time_acc(self): 32 | return self._time_acc 33 | 34 | def tic(self, key): 35 | assert key in self._keys 36 | self._start_time[key] = time.time() 37 | 38 | def toc(self, key): 39 | assert self._start_time[key] is not None 40 | self._time_acc[key] += time.time() - self._start_time[key] 41 | self._count[key] += 1 42 | self._start_time[key] = None 43 | 44 | def accumulated_time(self, key): 45 | return self._time_acc[key] 46 | 47 | def average_time(self, key): 48 | assert self._count[key] > 0 49 | return self._time_acc[key] / self._count[key] 50 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/flat_to_dict.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from gym.spaces import Dict 4 | 5 | from rlkit.core.util import SimpleWrapper 6 | from rlkit.policies.base import Policy 7 | 8 | 9 | class FlatToDictEnv(gym.Wrapper): 10 | """Wrap an environment that returns a flat obs to return a dict.""" 11 | def __init__(self, env, observation_key): 12 | super().__init__(env) 13 | self.observation_key = 
observation_key 14 | new_ob_space = { 15 | self.observation_key: self.observation_space 16 | } 17 | self.observation_space = Dict(new_ob_space) 18 | 19 | def step(self, action): 20 | obs, reward, done, info = super().step(action) 21 | return {self.observation_key: obs}, reward, done, info 22 | 23 | def reset(self): 24 | obs = super().reset() 25 | return {self.observation_key: obs} 26 | 27 | 28 | class FlatToDictPolicy(SimpleWrapper, Policy): 29 | """Wrap a policy that expect a flat obs so expects a dict obs.""" 30 | 31 | def __init__(self, policy, observation_key): 32 | super().__init__(policy) 33 | self.policy = policy 34 | self.observation_key = observation_key 35 | 36 | def get_action(self, observation, *args, **kwargs): 37 | flat_ob = observation[self.observation_key] 38 | return self.policy.get_action(flat_ob, *args, **kwargs) 39 | -------------------------------------------------------------------------------- /rlkit/demos/her_bc.py: -------------------------------------------------------------------------------- 1 | from rlkit.data_management.obs_dict_replay_buffer import \ 2 | ObsDictRelabelingBuffer 3 | from rlkit.torch.her.her import HER 4 | from rlkit.demos.behavior_clone import BehaviorClone 5 | 6 | 7 | class HerBC(HER, BehaviorClone): 8 | def __init__( 9 | self, 10 | *args, 11 | td3_kwargs, 12 | her_kwargs, 13 | base_kwargs, 14 | **kwargs 15 | ): 16 | HER.__init__( 17 | self, 18 | **her_kwargs, 19 | ) 20 | BehaviorClone.__init__(self, *args, **kwargs, **td3_kwargs, **base_kwargs) 21 | assert isinstance( 22 | self.replay_buffer, ObsDictRelabelingBuffer 23 | ) 24 | 25 | def _handle_rollout_ending(self): 26 | """Don't add anything to rollout buffer""" 27 | self._n_rollouts_total += 1 28 | # if len(self._current_path_builder) > 0: 29 | # path = self._current_path_builder.get_all_stacked() 30 | # self.replay_buffer.add_path(path) 31 | # self._exploration_paths.append(path) 32 | # self._current_path_builder = PathBuilder() 33 | 34 | def _handle_path(self, path): 35 | """Don't add anything to rollout buffer""" 36 | self._n_rollouts_total += 1 37 | # self.replay_buffer.add_path(path) 38 | self._exploration_paths.append(path) 39 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/hopper_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from gym.envs.mujoco import HopperEnv as GymHopperEnv 4 | 5 | from rlkit.core import logger as default_logger 6 | from rlkit.core.eval_util import get_stat_in_paths, create_stats_ordered_dict 7 | 8 | 9 | class HopperEnv(GymHopperEnv): 10 | def _step(self, a): 11 | ob, reward, done, _ = super()._step(a) 12 | posafter, height, ang = self.model.data.qpos[0:3, 0] 13 | return ob, reward, done, { 14 | 'posafter': posafter, 15 | 'height': height, 16 | 'angle': ang, 17 | } 18 | 19 | def log_diagnostics(self, paths, logger=default_logger): 20 | statistics = OrderedDict() 21 | for name_in_env_infos, name_to_log in [ 22 | ('posafter', 'Position'), 23 | ('height', 'Height'), 24 | ('angle', 'Angle'), 25 | ]: 26 | stats = get_stat_in_paths(paths, 'env_infos', name_in_env_infos) 27 | statistics.update(create_stats_ordered_dict( 28 | name_to_log, 29 | stats, 30 | )) 31 | statistics.update(create_stats_ordered_dict( 32 | "Final " + name_to_log, 33 | [s[-1] for s in stats], 34 | )) 35 | for key, value in statistics.items(): 36 | logger.record_tabular(key, value) 37 | 38 | 
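Note (editor's sketch, not a file in this repository): a minimal usage example for the `FlatToDictEnv` / `FlatToDictPolicy` wrappers defined in `rlkit/envs/wrappers/flat_to_dict.py` above. It assumes the old four-tuple gym `step`/`reset` API used throughout this codebase; `RandomPolicy` and the `Pendulum-v0` environment id are placeholders for any flat-observation policy and environment.

```
import gym

from rlkit.envs.wrappers.flat_to_dict import FlatToDictEnv, FlatToDictPolicy


class RandomPolicy:
    """Stand-in flat-observation policy used only for this sketch."""
    def __init__(self, action_space):
        self._action_space = action_space

    def get_action(self, flat_obs):
        return self._action_space.sample(), {}


env = FlatToDictEnv(gym.make('Pendulum-v0'), observation_key='observation')
obs = env.reset()                       # {'observation': <flat ndarray>}

policy = FlatToDictPolicy(RandomPolicy(env.action_space),
                          observation_key='observation')
action, agent_info = policy.get_action(obs)   # dict obs is unwrapped internally
obs, reward, done, info = env.step(action)    # dict obs returned again
```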
-------------------------------------------------------------------------------- /rlkit/envs/wrappers/stack_observation_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class StackObservationEnv(ProxyEnv): 8 | """ 9 | Env wrapper for passing history of observations as the new observation 10 | """ 11 | 12 | def __init__( 13 | self, 14 | env, 15 | stack_obs=1, 16 | ): 17 | ProxyEnv.__init__(self, env) 18 | self.stack_obs = stack_obs 19 | low = env.observation_space.low 20 | high = env.observation_space.high 21 | self.obs_dim = low.size 22 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim)) 23 | self.observation_space = Box( 24 | low=np.repeat(low, stack_obs), 25 | high=np.repeat(high, stack_obs), 26 | ) 27 | 28 | def reset(self): 29 | self._last_obs = np.zeros((self.stack_obs, self.obs_dim)) 30 | next_obs = self._wrapped_env.reset() 31 | self._last_obs[-1, :] = next_obs 32 | return self._last_obs.copy().flatten() 33 | 34 | def step(self, action): 35 | next_obs, reward, done, info = self._wrapped_env.step(action) 36 | self._last_obs = np.vstack(( 37 | self._last_obs[1:, :], 38 | next_obs 39 | )) 40 | return self._last_obs.copy().flatten(), reward, done, info 41 | 42 | 43 | -------------------------------------------------------------------------------- /rlkit/samplers/in_place.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.util import rollout 2 | 3 | 4 | class InPlacePathSampler(object): 5 | """ 6 | A sampler that does not serialization for sampling. Instead, it just uses 7 | the current policy and environment as-is. 8 | 9 | WARNING: This will affect the environment! So 10 | ``` 11 | sampler = InPlacePathSampler(env, ...) 12 | sampler.obtain_samples # this has side-effects: env will change! 13 | ``` 14 | """ 15 | def __init__(self, env, policy, max_samples, max_path_length, render=False): 16 | self.env = env 17 | self.policy = policy 18 | self.max_path_length = max_path_length 19 | self.max_samples = max_samples 20 | self.render = render 21 | assert max_samples >= max_path_length, "Need max_samples >= max_path_length" 22 | 23 | def start_worker(self): 24 | pass 25 | 26 | def shutdown_worker(self): 27 | pass 28 | 29 | def obtain_samples(self): 30 | paths = [] 31 | n_steps_total = 0 32 | while n_steps_total + self.max_path_length <= self.max_samples: 33 | path = rollout( 34 | self.env, self.policy, max_path_length=self.max_path_length, 35 | animated=self.render 36 | ) 37 | paths.append(path) 38 | n_steps_total += len(path['observations']) 39 | return paths 40 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_and_epsilon.py: -------------------------------------------------------------------------------- 1 | import random 2 | from rlkit.exploration_strategies.base import RawExplorationStrategy 3 | import numpy as np 4 | 5 | 6 | class GaussianAndEpsilonStrategy(RawExplorationStrategy): 7 | """ 8 | With probability epsilon, take a completely random action. 9 | with probability 1-epsilon, add Gaussian noise to the action taken by a 10 | deterministic policy. 
11 | """ 12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None, 13 | decay_period=1000000): 14 | assert len(action_space.shape) == 1 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._max_sigma = max_sigma 18 | self._epsilon = epsilon 19 | self._min_sigma = min_sigma 20 | self._decay_period = decay_period 21 | self._action_space = action_space 22 | 23 | def get_action_from_raw_action(self, action, t=None, **kwargs): 24 | if random.random() < self._epsilon: 25 | return self._action_space.sample() 26 | else: 27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period) 28 | return np.clip( 29 | action + np.random.normal(size=len(action)) * sigma, 30 | self._action_space.low, 31 | self._action_space.high, 32 | ) 33 | -------------------------------------------------------------------------------- /rlkit/core/ray_csv_logger.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | 4 | from ray.tune.logger import CSVLogger 5 | 6 | 7 | class SequentialCSVLogger(CSVLogger): 8 | """CSVLogger to be used with SequentialRayExperiment 9 | 10 | on receiving a log_dict with next_algo=True, a new csv progress file we be 11 | used. 12 | """ 13 | def _init(self): 14 | self.created_logger = False 15 | # We need to create an initial self._file to make Ray happy... 16 | self.setup_new_logger('temp_progress.csv') 17 | 18 | def setup_new_logger(self, csv_fname): 19 | self.csv_fname = csv_fname 20 | if self.created_logger: 21 | self._file.close() 22 | self.created_logger = True 23 | progress_file = os.path.join(self.logdir, csv_fname) 24 | self._continuing = os.path.exists(progress_file) 25 | self._file = open(progress_file, "a") 26 | self._csv_out = None 27 | 28 | def on_result(self, result): 29 | if 'log_fname' in result and result['log_fname'] != self.csv_fname: 30 | self.setup_new_logger(result['log_fname']) 31 | """ 32 | Sort the keys to enforce deterministic ordering during resuming. 
33 | """ 34 | sorted_result = OrderedDict() 35 | for key in sorted(result.keys()): 36 | sorted_result[key] = result[key] 37 | super().on_result(sorted_result) 38 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_multitask_base.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.ant import AntEnv 2 | 3 | 4 | class MultitaskAntEnv(AntEnv): 5 | def __init__(self, task={}, n_tasks=2, **kwargs): 6 | self._task = task 7 | self.tasks = self.sample_tasks(n_tasks) 8 | self._goal = self.tasks[0]['goal'] 9 | super(MultitaskAntEnv, self).__init__(**kwargs) 10 | 11 | """ 12 | def step(self, action): 13 | xposbefore = self.sim.data.qpos[0] 14 | self.do_simulation(action, self.frame_skip) 15 | xposafter = self.sim.data.qpos[0] 16 | 17 | forward_vel = (xposafter - xposbefore) / self.dt 18 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel) 19 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 20 | 21 | observation = self._get_obs() 22 | reward = forward_reward - ctrl_cost 23 | done = False 24 | infos = dict(reward_forward=forward_reward, 25 | reward_ctrl=-ctrl_cost, task=self._task) 26 | return (observation, reward, done, infos) 27 | """ 28 | 29 | 30 | def get_all_task_idx(self): 31 | return range(len(self.tasks)) 32 | 33 | def reset_task(self, idx): 34 | try: 35 | self._task = self.tasks[idx] 36 | except IndexError as e: 37 | import ipdb; ipdb.set_trace() 38 | self._goal = self._task['goal'] # assume parameterization of task by single vector 39 | self.reset() 40 | -------------------------------------------------------------------------------- /rlkit/data_management/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PathBuilder(dict): 5 | """ 6 | Usage: 7 | ``` 8 | path_builder = PathBuilder() 9 | path.add_sample( 10 | observations=1, 11 | actions=2, 12 | next_observations=3, 13 | ... 14 | ) 15 | path.add_sample( 16 | observations=4, 17 | actions=5, 18 | next_observations=6, 19 | ... 20 | ) 21 | 22 | path = path_builder.get_all_stacked() 23 | 24 | path['observations'] 25 | # output: [1, 4] 26 | path['actions'] 27 | # output: [2, 5] 28 | ``` 29 | 30 | Note that the key should be "actions" and not "action" since the 31 | resulting dictionary will have those keys. 32 | """ 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self._path_length = 0 37 | 38 | def add_all(self, **key_to_value): 39 | for k, v in key_to_value.items(): 40 | if k not in self: 41 | self[k] = [v] 42 | else: 43 | self[k].append(v) 44 | self._path_length += 1 45 | 46 | def get_all_stacked(self): 47 | output_dict = dict() 48 | for k, v in self.items(): 49 | output_dict[k] = stack_list(v) 50 | return output_dict 51 | 52 | def __len__(self): 53 | return self._path_length 54 | 55 | 56 | def stack_list(lst): 57 | if isinstance(lst[0], dict): 58 | return lst 59 | else: 60 | return np.array(lst) 61 | -------------------------------------------------------------------------------- /rlkit/utils/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | global paths 4 | 5 | 6 | class PathBuilder(dict): 7 | """ 8 | Usage: 9 | ``` 10 | path_builder = PathBuilder() 11 | path.add_sample( 12 | observations=1, 13 | actions=2, 14 | next_observations=3, 15 | ... 16 | ) 17 | path.add_sample( 18 | observations=4, 19 | actions=5, 20 | next_observations=6, 21 | ... 
22 | ) 23 | 24 | path = path_builder.get_all_stacked() 25 | 26 | path['observations'] 27 | # output: [1, 4] 28 | path['actions'] 29 | # output: [2, 5] 30 | ``` 31 | 32 | Note that the key should be "actions" and not "action" since the 33 | resulting dictionary will have those keys. 34 | """ 35 | 36 | def __init__(self): 37 | super().__init__() 38 | self._path_length = 0 39 | 40 | def add_all(self, **key_to_value): 41 | for k, v in key_to_value.items(): 42 | if k not in self: 43 | self[k] = [v] 44 | else: 45 | self[k].append(v) 46 | self._path_length += 1 47 | 48 | def get_all_stacked(self): 49 | output_dict = dict() 50 | for k, v in self.items(): 51 | output_dict[k] = stack_list(v) 52 | return output_dict 53 | 54 | def __len__(self): 55 | return self._path_length 56 | 57 | 58 | def stack_list(lst): 59 | if isinstance(lst[0], dict): 60 | return lst 61 | else: 62 | return np.array(lst) 63 | -------------------------------------------------------------------------------- /rlkit/demos/source/hand_demo_source.py: -------------------------------------------------------------------------------- 1 | from rlkit.demos.source.demo_source import DemoSource 2 | import pickle 3 | 4 | from rlkit.data_management.path_builder import PathBuilder 5 | 6 | from rlkit.util.io import load_local_or_remote_file 7 | 8 | class HandDemoSource(DemoSource): 9 | def __init__(self, filename): 10 | self.data = load_local_or_remote_file(filename) 11 | 12 | def load_paths(self): 13 | paths = [] 14 | for i in range(len(self.data)): 15 | p = self.data[i] 16 | H = len(p["observations"]) - 1 17 | 18 | path_builder = PathBuilder() 19 | 20 | for t in range(H): 21 | p["observations"][t] 22 | 23 | ob = path["observations"][t, :] 24 | action = path["actions"][t, :] 25 | reward = path["rewards"][t] 26 | next_ob = path["observations"][t+1, :] 27 | terminal = 0 28 | agent_info = {} # todo (need to unwrap each key) 29 | env_info = {} # todo (need to unwrap each key) 30 | 31 | path_builder.add_all( 32 | observations=ob, 33 | actions=action, 34 | rewards=reward, 35 | next_observations=next_ob, 36 | terminals=terminal, 37 | agent_infos=agent_info, 38 | env_infos=env_info, 39 | ) 40 | 41 | path = path_builder.get_all_stacked() 42 | paths.append(path) 43 | return paths 44 | -------------------------------------------------------------------------------- /rlkit/envs/simple/point.py: -------------------------------------------------------------------------------- 1 | import os 2 | from gym import spaces 3 | import numpy as np 4 | import gym 5 | 6 | 7 | class Point(gym.Env): 8 | """Superclass for all MuJoCo environments. 
9 | """ 10 | 11 | def __init__(self, n=2, action_scale=0.2, fixed_goal=None): 12 | self.fixed_goal = fixed_goal 13 | self.n = n 14 | self.action_scale = action_scale 15 | self.goal = np.zeros((n,)) 16 | self.state = np.zeros((n,)) 17 | 18 | @property 19 | def action_space(self): 20 | return spaces.Box( 21 | low=-1*np.ones((self.n,)), 22 | high=1*np.ones((self.n,)) 23 | ) 24 | 25 | @property 26 | def observation_space(self): 27 | return spaces.Box( 28 | low=-5*np.ones((2 * self.n,)), 29 | high=5*np.ones((2 * self.n,)) 30 | ) 31 | 32 | def reset(self): 33 | self.state = np.zeros((self.n,)) 34 | if self.fixed_goal is None: 35 | self.goal = np.random.uniform(-5, 5, size=(self.n,)) 36 | else: 37 | self.goal = np.array(self.fixed_goal) 38 | return self._get_obs() 39 | 40 | def step(self, action): 41 | action = np.clip(action, -1, 1) * self.action_scale 42 | new_state = self.state + action 43 | new_state = np.clip(new_state, -5, 5) 44 | self.state = new_state 45 | reward = -np.linalg.norm(new_state - self.goal) 46 | 47 | return self._get_obs(), reward, False, {} 48 | 49 | def _get_obs(self): 50 | return np.concatenate([self.state, self.goal]) 51 | -------------------------------------------------------------------------------- /rlkit/data_management/wrappers/proxy_buffer.py: -------------------------------------------------------------------------------- 1 | from rlkit.data_management.replay_buffer import ReplayBuffer 2 | 3 | class ProxyBuffer(ReplayBuffer): 4 | def __init__(self, replay_buffer): 5 | self._wrapped_buffer = replay_buffer 6 | 7 | def add_sample(self, *args, **kwargs): 8 | self._wrapped_buffer.add_sample(*args, **kwargs) 9 | 10 | def terminate_episode(self, *args, **kwargs): 11 | self._wrapped_buffer.terminate_episode(*args, **kwargs) 12 | 13 | def num_steps_can_sample(self, *args, **kwargs): 14 | self._wrapped_buffer.num_steps_can_sample(*args, **kwargs) 15 | 16 | def add_path(self, *args, **kwargs): 17 | self._wrapped_buffer.add_path(*args, **kwargs) 18 | 19 | def add_paths(self, *args, **kwargs): 20 | self._wrapped_buffer.add_paths(*args, **kwargs) 21 | 22 | def random_batch(self, *args, **kwargs): 23 | self._wrapped_buffer.random_batch(*args, **kwargs) 24 | 25 | def get_diagnostics(self, *args, **kwargs): 26 | return self._wrapped_buffer.get_diagnostics(*args, **kwargs) 27 | 28 | def get_snapshot(self, *args, **kwargs): 29 | return self._wrapped_buffer.get_snapshot(*args, **kwargs) 30 | 31 | def end_epoch(self, *args, **kwargs): 32 | return self._wrapped_buffer.end_epoch(*args, **kwargs) 33 | 34 | @property 35 | def wrapped_buffer(self): 36 | return self._wrapped_buffer 37 | 38 | def __getattr__(self, attr): 39 | if attr == '_wrapped_buffer': 40 | raise AttributeError() 41 | return getattr(self._wrapped_buffer, attr) 42 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/history_env.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | from gym import Env 5 | from gym.spaces import Box 6 | 7 | from rlkit.envs.proxy_env import ProxyEnv 8 | 9 | 10 | class HistoryEnv(ProxyEnv, Env): 11 | def __init__(self, wrapped_env, history_len): 12 | super().__init__(wrapped_env) 13 | self.history_len = history_len 14 | 15 | high = np.inf * np.ones( 16 | self.history_len * self.observation_space.low.size) 17 | low = -high 18 | self.observation_space = Box(low=low, 19 | high=high, 20 | ) 21 | self.history = deque(maxlen=self.history_len) 22 | 23 | 
def step(self, action): 24 | state, reward, done, info = super().step(action) 25 | self.history.append(state) 26 | flattened_history = self._get_history().flatten() 27 | return flattened_history, reward, done, info 28 | 29 | def reset(self, **kwargs): 30 | state = super().reset() 31 | self.history = deque(maxlen=self.history_len) 32 | self.history.append(state) 33 | flattened_history = self._get_history().flatten() 34 | return flattened_history 35 | 36 | def _get_history(self): 37 | observations = list(self.history) 38 | 39 | obs_count = len(observations) 40 | for _ in range(self.history_len - obs_count): 41 | dummy = np.zeros(self._wrapped_env.observation_space.low.size) 42 | observations.append(dummy) 43 | return np.c_[observations] 44 | 45 | 46 | -------------------------------------------------------------------------------- /rlkit/core/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from collections import defaultdict 4 | 5 | 6 | class Timer: 7 | def __init__(self, return_global_times=False): 8 | self.stamps = None 9 | self.epoch_start_time = None 10 | self.global_start_time = time.time() 11 | self._return_global_times = return_global_times 12 | 13 | self.reset() 14 | 15 | def reset(self): 16 | self.stamps = defaultdict(lambda: 0) 17 | self.start_times = {} 18 | self.epoch_start_time = time.time() 19 | 20 | def start_timer(self, name, unique=True): 21 | if unique: 22 | assert name not in self.start_times.keys() 23 | self.start_times[name] = time.time() 24 | 25 | def stop_timer(self, name): 26 | assert name in self.start_times.keys() 27 | start_time = self.start_times[name] 28 | end_time = time.time() 29 | self.stamps[name] += (end_time - start_time) 30 | 31 | def get_times(self): 32 | global_times = {} 33 | cur_time = time.time() 34 | global_times['epoch_time'] = (cur_time - self.epoch_start_time) 35 | if self._return_global_times: 36 | global_times['global_time'] = (cur_time - self.global_start_time) 37 | return { 38 | **self.stamps.copy(), 39 | **global_times, 40 | } 41 | 42 | @property 43 | def return_global_times(self): 44 | return self._return_global_times 45 | 46 | @return_global_times.setter 47 | def return_global_times(self, value): 48 | self._return_global_times = value 49 | 50 | 51 | timer = Timer() 52 | -------------------------------------------------------------------------------- /rlkit/data_management/dataset_logger_fn.py: -------------------------------------------------------------------------------- 1 | """Creates a callable that can be passed as a post_epoch_fn to an algorithm 2 | for evaluating on specific data.""" 3 | 4 | from rlkit.core.logging import add_prefix 5 | 6 | from rlkit.torch.core import np_to_pytorch_batch 7 | import rlkit.torch.pytorch_util as ptu 8 | 9 | import numpy as np 10 | 11 | class DatasetLoggerFn: 12 | def __init__(self, dataset, fn, prefix="", batch_size=64, *args, **kwargs): 13 | self.dataset = dataset 14 | self.fn = fn 15 | self.prefix = prefix 16 | self.batch_size = batch_size 17 | self.args = args 18 | self.kwargs = kwargs 19 | 20 | def __call__(self, algo): 21 | batch = self.dataset.random_batch(self.batch_size) 22 | batch = np_to_pytorch_batch(batch) 23 | log_dict = self.fn(batch, *self.args, **self.kwargs) 24 | return add_prefix(log_dict, self.prefix) 25 | 26 | def run_bc_batch(batch, policy): 27 | o = batch["observations"] 28 | u = batch["actions"] 29 | # g = batch["resampled_goals"] 30 | # og = torch.cat((o, g), dim=1) 31 | og = o 32 | # pred_u, *_ = self.policy(og) 
33 | dist = policy(og) 34 | pred_u, log_pi = dist.rsample_and_logprob() 35 | stats = dist.get_diagnostics() 36 | 37 | mse = (pred_u - u) ** 2 38 | mse_loss = np.mean(ptu.get_numpy(mse.mean())) 39 | 40 | policy_logpp = dist.log_prob(u, ) 41 | logp_loss = -policy_logpp.mean() 42 | policy_loss = np.mean(ptu.get_numpy(logp_loss)) 43 | 44 | return dict( 45 | bc_loss=policy_loss, 46 | mse_loss=mse_loss, 47 | **stats 48 | ) -------------------------------------------------------------------------------- /rlkit/data_management/wrappers/replay_buffer_wrapper.py: -------------------------------------------------------------------------------- 1 | from rlkit.data_management.replay_buffer import ReplayBuffer 2 | 3 | class ProxyBuffer(ReplayBuffer): 4 | def __init__(self, replay_buffer): 5 | self._wrapped_buffer = wrapped_buffer 6 | 7 | def add_sample(self, *args, **kwargs): 8 | self._wrapped_buffer.add_sample(*args, **kwargs) 9 | 10 | def terminate_episode(self, *args, **kwargs): 11 | self._wrapped_buffer.terminate_episode(*args, **kwargs) 12 | 13 | def num_steps_can_sample(self, *args, **kwargs): 14 | self._wrapped_buffer.num_steps_can_sample(*args, **kwargs) 15 | 16 | def add_path(self, *args, **kwargs): 17 | self._wrapped_buffer.add_path(*args, **kwargs) 18 | 19 | def add_paths(self, *args, **kwargs): 20 | self._wrapped_buffer.add_paths(*args, **kwargs) 21 | 22 | @abc.abstractmethod 23 | def random_batch(self, *args, **kwargs): 24 | self._wrapped_buffer.random_batch(*args, **kwargs) 25 | 26 | def get_diagnostics(self, *args, **kwargs): 27 | return self._wrapped_buffer.get_diagnostics(*args, **kwargs) 28 | 29 | def get_snapshot(self, *args, **kwargs): 30 | return self._wrapped_buffer.get_snapshot(*args, **kwargs) 31 | 32 | def end_epoch(self, *args, **kwargs): 33 | return self._wrapped_buffer.end_epoch(*args, **kwargs) 34 | 35 | @property 36 | def wrapped_buffer(self): 37 | return self._wrapped_buffer 38 | 39 | def __getattr__(self, attr): 40 | if attr == '_wrapped_buffer': 41 | raise AttributeError() 42 | return getattr(self._wrapped_buffer, attr) 43 | -------------------------------------------------------------------------------- /rlkit/samplers/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rlkit.samplers.rollout_functions import deprecated_rollout as normal_rollout 3 | 4 | def rollout(*args, **kwargs): 5 | # TODO Steven: remove pointer 6 | return normal_rollout(*args, **kwargs) 7 | 8 | def split_paths(paths): 9 | """ 10 | Stack multiples obs/actions/etc. from different paths 11 | :param paths: List of paths, where one path is something returned from 12 | the rollout functino above. 13 | :return: Tuple. Every element will have shape batch_size X DIM, including 14 | the rewards and terminal flags. 
15 | """ 16 | rewards = [path["rewards"].reshape(-1, 1) for path in paths] 17 | terminals = [path["terminals"].reshape(-1, 1) for path in paths] 18 | actions = [path["actions"] for path in paths] 19 | obs = [path["observations"] for path in paths] 20 | next_obs = [path["next_observations"] for path in paths] 21 | rewards = np.vstack(rewards) 22 | terminals = np.vstack(terminals) 23 | obs = np.vstack(obs) 24 | actions = np.vstack(actions) 25 | next_obs = np.vstack(next_obs) 26 | assert len(rewards.shape) == 2 27 | assert len(terminals.shape) == 2 28 | assert len(obs.shape) == 2 29 | assert len(actions.shape) == 2 30 | assert len(next_obs.shape) == 2 31 | return rewards, terminals, obs, actions, next_obs 32 | 33 | 34 | def split_paths_to_dict(paths): 35 | rewards, terminals, obs, actions, next_obs = split_paths(paths) 36 | return dict( 37 | rewards=rewards, 38 | terminals=terminals, 39 | observations=obs, 40 | actions=actions, 41 | next_observations=next_obs, 42 | ) 43 | -------------------------------------------------------------------------------- /rlkit/demos/spacemouse/README.md: -------------------------------------------------------------------------------- 1 | # Spacemouse Demonstrations 2 | 3 | This code allows an agent (including physical robots) to be controlled by a 7 degree-of-freedom 3Dconnexion Spacemouse device. 4 | 5 | ## Setup Instructions 6 | 7 | ### Spacemouse (on a Mac) 8 | 9 | 1. Clone robosuite ([https://github.com/anair13/robosuite](https://github.com/anair13/robosuite)) and add it to the python path 10 | 2. Run `pip install hidapi` and install Spacemouse drivers 11 | 2. Ensure you can run the following file (and 12 | see input values from the spacemouse): `robosuite/devices/spacemouse.py` 13 | 4. Follow the example in `railrl/demos/collect_demo.py` to collect demonstrations 14 | 15 | ### Server 16 | We haven't been able to install Spacemouse drivers for Linux but instead we use a Spacemouse on a Mac ("client") and send messages over a network to a Linux machine ("server"). 17 | 18 | #### Setup 19 | On the client, run the setup above. On the server, run: 20 | 1. Run `pip install Pyro4` 21 | 2. Make sure the hostname in `railrl/demos/spacemouse/config.py` is correct (in the example I use gauss1.banatao.berkeley.edu). This hostname needs to be visible (eg. you can ping it) from both the client and server 22 | 23 | #### Run 24 | 25 | 1. On the server, start the nameserver: 26 | ```export PYRO_SERIALIZERS_ACCEPTED=serpent,json,marshal,pickle 27 | python -m Pyro4.naming -n euler1.dyn.berkeley.edu 28 | ``` 29 | 2. On the server, run a script that uses the `SpaceMouseExpert` imported from `railrl/demos/spacemouse/input_server.py` such as ```python experiments/ashvin/iros2019/collect_demos_spacemouse.py``` 30 | 2. 
On the client, run ```python railrl/demos/spacemouse/input_client.py``` 31 | -------------------------------------------------------------------------------- /rlkit/launchers/contextual/util.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | 3 | from gym.wrappers import TimeLimit 4 | 5 | from rlkit.core import logger 6 | from rlkit.visualization.video import dump_video 7 | 8 | 9 | def get_save_video_function( 10 | rollout_function, 11 | env, 12 | policy, 13 | save_video_period=10, 14 | imsize=48, 15 | tag="", 16 | **dump_video_kwargs 17 | ): 18 | logdir = logger.get_snapshot_dir() 19 | 20 | def save_video(algo, epoch): 21 | if epoch % save_video_period == 0 or epoch >= algo.num_epochs - 1: 22 | if tag is not None and len(tag) > 0: 23 | filename = 'video_{}_{epoch}_env.mp4'.format(tag, epoch=epoch) 24 | else: 25 | filename = 'video_{epoch}_env.mp4'.format(epoch=epoch) 26 | filename = osp.join( 27 | logdir, 28 | filename, 29 | ) 30 | dump_video(env, policy, filename, rollout_function, 31 | imsize=imsize, **dump_video_kwargs) 32 | return save_video 33 | 34 | 35 | def get_gym_env( 36 | env_id, 37 | env_class=None, 38 | env_kwargs=None, 39 | unwrap_timed_envs=False, 40 | ): 41 | if env_kwargs is None: 42 | env_kwargs = {} 43 | 44 | assert env_id or env_class 45 | 46 | if env_id: 47 | import gym 48 | import multiworld 49 | multiworld.register_all_envs() 50 | env = gym.make(env_id) 51 | else: 52 | env = env_class(**env_kwargs) 53 | 54 | if isinstance(env, TimeLimit) and unwrap_timed_envs: 55 | env = env.env 56 | 57 | return env 58 | -------------------------------------------------------------------------------- /rlkit/data_management/images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class _ImageNumpyArr: 5 | """ 6 | Wrapper for a numpy array. This code automatically normalizes/unormalizes 7 | the image internally. This process should be completely hidden from the 8 | user of this class. 9 | 10 | im_arr = image_numpy_wrapper.zeros(10) 11 | # img is normalized. 
ImageNumpyArr automatically stores as np.uint8 12 | im_arr[2] = img 13 | # ImageNumpyArr automatically normalizes the np.uint8 and returns as np.float 14 | img_2 = im_arr[2] 15 | """ 16 | 17 | def __init__(self, np_array): 18 | assert np_array.dtype == np.uint8 19 | self.np_array = np_array 20 | self.shape = self.np_array.shape 21 | self.size = self.np_array.size 22 | self.dtype = np.uint8 23 | 24 | def __getitem__(self, idxs): 25 | return normalize_image(self.np_array[idxs], dtype=np.float32) 26 | 27 | def __setitem__(self, idxs, value): 28 | if value.dtype != np.uint8: 29 | self.np_array[idxs] = unnormalize_image(value) 30 | else: 31 | self.np_array[idxs] = value 32 | 33 | 34 | def zeros(shape, *args, **kwargs): 35 | arr = np.zeros(shape, dtype=np.uint8) 36 | return _ImageNumpyArr(arr) 37 | 38 | 39 | def ones(shape, *args, **kwargs): 40 | arr = np.ones(shape, dtype=np.uint8) 41 | return _ImageNumpyArr(arr) 42 | 43 | 44 | def from_np(np_arr): 45 | return _ImageNumpyArr(np_arr) 46 | 47 | 48 | def normalize_image(image, dtype=np.float64): 49 | assert image.dtype == np.uint8 50 | return dtype(image) / 255.0 51 | 52 | 53 | def unnormalize_image(image): 54 | assert image.dtype != np.uint8 55 | return np.uint8(image * 255.0) 56 | -------------------------------------------------------------------------------- /rlkit/torch/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 5 | """ 6 | from rlkit.torch.networks.basic import ( 7 | Sigmoid, Clamp, SigmoidClamp, ConcatTuple, Detach, Flatten, FlattenEach, Split, Reshape, 8 | ) 9 | from rlkit.torch.networks.cnn import BasicCNN, CNN, MergedCNN, CNNPolicy, TwoChannelCNN, ConcatTwoChannelCNN 10 | from rlkit.torch.networks.dcnn import DCNN, TwoHeadDCNN 11 | from rlkit.torch.networks.deprecated_feedforward import ( 12 | FeedForwardPolicy, FeedForwardQFunction 13 | ) 14 | from rlkit.torch.networks.feat_point_mlp import FeatPointMlp 15 | from rlkit.torch.networks.image_state import ImageStatePolicy, ImageStateQ 16 | from rlkit.torch.networks.linear_transform import LinearTransform 17 | from rlkit.torch.networks.mlp import ( 18 | Mlp, ConcatMlp, MlpPolicy, TanhMlpPolicy, 19 | MlpQf, 20 | MlpQfWithObsProcessor, 21 | ConcatMultiHeadedMlp, 22 | ) 23 | from rlkit.torch.networks.pretrained_cnn import PretrainedCNN 24 | from rlkit.torch.networks.two_headed_mlp import TwoHeadMlp 25 | 26 | __all__ = [ 27 | 'Sigmoid', 28 | 'Clamp', 29 | 'SigmoidClamp', 30 | 'ConcatMlp', 31 | 'ConcatMultiHeadedMlp', 32 | 'ConcatTuple', 33 | 'BasicCNN', 34 | 'CNN', 35 | 'TwoChannelCNN', 36 | "ConcatTwoChannelCNN", 37 | 'CNNPolicy', 38 | 'DCNN', 39 | 'Detach', 40 | 'FeedForwardPolicy', 41 | 'FeedForwardQFunction', 42 | 'FeatPointMlp', 43 | 'Flatten', 44 | 'FlattenEach', 45 | 'LinearTransform', 46 | 'ImageStatePolicy', 47 | 'ImageStateQ', 48 | 'MergedCNN', 49 | 'Mlp', 50 | 'PretrainedCNN', 51 | 'Reshape', 52 | 'Split', 53 | 'TwoHeadDCNN', 54 | 'TwoHeadMlp', 55 | ] 56 | 57 | -------------------------------------------------------------------------------- /rlkit/envs/assets/twod_point.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /rlkit/envs/assets/twod_point_random_init.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 
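Note (editor's sketch, not a file in this repository): a short round-trip example for the `_ImageNumpyArr` wrapper in `rlkit/data_management/images.py` above, showing that writes are stored as `np.uint8` and reads come back as normalized `np.float32`. The buffer and image shapes are arbitrary examples.

```
import numpy as np

from rlkit.data_management import images

buf = images.zeros((10, 48, 48, 3))         # backed by an np.uint8 array
img = np.random.uniform(size=(48, 48, 3))   # float image in [0, 1)

buf[2] = img        # __setitem__ calls unnormalize_image -> uint8 storage
recovered = buf[2]  # __getitem__ calls normalize_image -> float32 in [0, 1]

assert recovered.dtype == np.float32
# recovered matches img up to the 1/255 quantization introduced by uint8 storage
```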
-------------------------------------------------------------------------------- /rlkit/samplers/data_collector/contextual_path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from functools import partial 3 | 4 | from rlkit.envs.contextual import ContextualEnv 5 | from rlkit.policies.base import Policy 6 | from rlkit.samplers.data_collector import MdpPathCollector 7 | from rlkit.samplers.rollout_functions import contextual_rollout 8 | 9 | 10 | class ContextualPathCollector(MdpPathCollector): 11 | def __init__( 12 | self, 13 | env: ContextualEnv, 14 | policy: Policy, 15 | max_num_epoch_paths_saved=None, 16 | observation_keys=('observation',), 17 | context_keys_for_policy='context', 18 | render=False, 19 | render_kwargs=None, 20 | obs_processor=None, 21 | rollout=contextual_rollout, 22 | **kwargs 23 | ): 24 | rollout_fn = partial( 25 | rollout, 26 | context_keys_for_policy=context_keys_for_policy, 27 | observation_keys=observation_keys, 28 | obs_processor=obs_processor, 29 | ) 30 | super().__init__( 31 | env, policy, 32 | max_num_epoch_paths_saved=max_num_epoch_paths_saved, 33 | render=render, 34 | render_kwargs=render_kwargs, 35 | rollout_fn=rollout_fn, 36 | **kwargs 37 | ) 38 | self._observation_keys = observation_keys 39 | self._context_keys_for_policy = context_keys_for_policy 40 | 41 | def end_epoch(self, epoch): 42 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 43 | 44 | def get_snapshot(self): 45 | snapshot = super().get_snapshot() 46 | snapshot.update( 47 | observation_keys=self._observation_keys, 48 | context_keys_for_policy=self._context_keys_for_policy, 49 | ) 50 | return snapshot 51 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv 4 | 5 | 6 | # Copy task structure from https://github.com/jonasrothfuss/ProMP/blob/master/meta_policy_search/envs/mujoco_envs/ant_rand_goal.py 7 | class AntGoalEnv(MultitaskAntEnv): 8 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True, **kwargs): 9 | self.quick_init(locals()) 10 | super(AntGoalEnv, self).__init__(task, n_tasks, **kwargs) 11 | 12 | def step(self, action): 13 | self.do_simulation(action, self.frame_skip) 14 | xposafter = np.array(self.get_body_com("torso")) 15 | 16 | goal_reward = -np.sum(np.abs(xposafter[:2] - self._goal)) # make it happy, not suicidal 17 | 18 | ctrl_cost = .1 * np.square(action).sum() 19 | contact_cost = 0.5 * 1e-3 * np.sum( 20 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 21 | survive_reward = 0.0 22 | reward = goal_reward - ctrl_cost - contact_cost + survive_reward 23 | state = self.state_vector() 24 | done = False 25 | ob = self._get_obs() 26 | return ob, reward, done, dict( 27 | goal_forward=goal_reward, 28 | reward_ctrl=-ctrl_cost, 29 | reward_contact=-contact_cost, 30 | reward_survive=survive_reward, 31 | ) 32 | 33 | def sample_tasks(self, num_tasks): 34 | a = np.random.random(num_tasks) * 2 * np.pi 35 | r = 3 * np.random.random(num_tasks) ** 0.5 36 | goals = np.stack((r * np.cos(a), r * np.sin(a)), axis=-1) 37 | tasks = [{'goal': goal} for goal in goals] 38 | return tasks 39 | 40 | def _get_obs(self): 41 | return np.concatenate([ 42 | self.sim.data.qpos.flat, 43 | self.sim.data.qvel.flat, 44 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 45 | 
]) 46 | -------------------------------------------------------------------------------- /rlkit/envs/robosuite_wrapper.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | from gym.spaces import Box 3 | import numpy as np 4 | import robosuite as suite 5 | from rlkit.core.serializeable import Serializable 6 | 7 | 8 | class RobosuiteStateWrapperEnv(Serializable, Env): 9 | def __init__(self, wrapped_env_id, observation_keys=('robot-state', 'object-state'), **wrapped_env_kwargs): 10 | Serializable.quick_init(self, locals()) 11 | self._wrapped_env = suite.make( 12 | wrapped_env_id, 13 | **wrapped_env_kwargs 14 | ) 15 | self.action_space = Box(self._wrapped_env.action_spec[0], self._wrapped_env.action_spec[1], dtype=np.float32) 16 | observation_dim = self._wrapped_env.observation_spec()['robot-state'].shape[0] \ 17 | + self._wrapped_env.observation_spec()['object-state'].shape[0] 18 | self.observation_space = Box( 19 | -np.inf * np.ones(observation_dim), 20 | np.inf * np.ones(observation_dim), 21 | dtype=np.float32, 22 | ) 23 | self._observation_keys = observation_keys 24 | 25 | def step(self, action): 26 | obs, reward, done, info = self._wrapped_env.step(action) 27 | obs = self.flatten_dict_obs(obs) 28 | return obs, reward, done, info 29 | 30 | def flatten_dict_obs(self, obs): 31 | obs = np.concatenate(tuple(obs[k] for k in self._observation_keys)) 32 | return obs 33 | 34 | def reset(self): 35 | obs = self._wrapped_env.reset() 36 | obs = self.flatten_dict_obs(obs) 37 | return obs 38 | 39 | def render(self): 40 | self._wrapped_env.render() 41 | 42 | def __getattr__(self, attr): 43 | if attr == '_wrapped_env': 44 | raise AttributeError() 45 | return getattr(self._wrapped_env, attr) 46 | 47 | def __str__(self): 48 | return '{}({})'.format(type(self).__name__, self.wrapped_env) 49 | 50 | -------------------------------------------------------------------------------- /rlkit/envs/assets/twod_maze.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /rlkit/envs/gripper_state_wrapper.py: -------------------------------------------------------------------------------- 1 | # import Dict 2 | 3 | import numpy as np 4 | from gym.spaces import Box, Dict 5 | 6 | from rlkit.envs.wrappers import ProxyEnv 7 | from roboverse.bullet.misc import quat_to_deg 8 | 9 | 10 | def process_gripper_state(gripper_state): 11 | gripper_pos = gripper_state[:3] 12 | gripper_ori = quat_to_deg(gripper_state[3:7]) / 360.0 13 | gripper_tips_distance = gripper_state[7:8] 14 | return np.concatenate([gripper_pos, gripper_ori, gripper_tips_distance], axis=0) 15 | 16 | 17 | class GripperStateWrappedEnv(ProxyEnv): 18 | def __init__(self, 19 | wrapped_env, 20 | state_key, 21 | step_keys_map=None, 22 | ): 23 | super().__init__(wrapped_env) 24 | self.state_key = state_key 25 | self.gripper_state_size = 7 26 | gripper_state_space = Box( 27 | -1 * np.ones(self.gripper_state_size), 28 | 1 * np.ones(self.gripper_state_size), 29 | dtype=np.float32, 30 | ) 31 | 32 | if step_keys_map is None: 33 | step_keys_map = {} 34 | self.step_keys_map = step_keys_map 35 | spaces = self.wrapped_env.observation_space.spaces 36 | for value in self.step_keys_map.values(): 37 | spaces[value] = gripper_state_space 38 | self.observation_space = Dict(spaces) 39 | 40 | def step(self, action): 41 | obs, reward, done, info = self.wrapped_env.step(action) 42 | new_obs = 
self._update_obs(obs) 43 | return new_obs, reward, done, info 44 | 45 | def _update_obs(self, obs): 46 | for key in self.step_keys_map: 47 | value = self.step_keys_map[key] 48 | gripper_state = obs[self.state_key] 49 | process_gripper_state(gripper_state) 50 | obs[value] = process_gripper_state(gripper_state) 51 | obs = {**obs, **self.reset_obs} 52 | return obs 53 | -------------------------------------------------------------------------------- /rlkit/envs/proxy_env.py: -------------------------------------------------------------------------------- 1 | from gym import Env 2 | 3 | 4 | class ProxyEnv(Env): 5 | def __init__(self, wrapped_env): 6 | self._wrapped_env = wrapped_env 7 | self.action_space = self._wrapped_env.action_space 8 | self.observation_space = self._wrapped_env.observation_space 9 | 10 | @property 11 | def wrapped_env(self): 12 | return self._wrapped_env 13 | 14 | def reset(self, **kwargs): 15 | return self._wrapped_env.reset(**kwargs) 16 | 17 | def step(self, action): 18 | return self._wrapped_env.step(action) 19 | 20 | def render(self, *args, **kwargs): 21 | return self._wrapped_env.render(*args, **kwargs) 22 | 23 | @property 24 | def horizon(self): 25 | return self._wrapped_env.horizon 26 | 27 | def terminate(self): 28 | if hasattr(self.wrapped_env, "terminate"): 29 | self.wrapped_env.terminate() 30 | 31 | def seed(self, _seed): 32 | return self.wrapped_env.seed(_seed) 33 | 34 | def __getattr__(self, attr): 35 | if attr == '_wrapped_env': 36 | raise AttributeError() 37 | if attr == 'planner': 38 | return self._planner 39 | if attr == 'set_vf': 40 | return self.set_vf 41 | return getattr(self._wrapped_env, attr) 42 | # try: 43 | # getattr(self, attr) 44 | # except Exception: 45 | # return getattr(self._wrapped_env, attr) 46 | 47 | def __getstate__(self): 48 | """ 49 | This is useful to override in case the wrapped env has some funky 50 | __getstate__ that doesn't play well with overriding __getattr__. 51 | 52 | The main problematic case is/was gym's EzPickle serialization scheme. 
53 | :return: 54 | """ 55 | return self.__dict__ 56 | 57 | def __setstate__(self, state): 58 | self.__dict__.update(state) 59 | 60 | def __str__(self): 61 | return '{}({})'.format(type(self).__name__, self.wrapped_env) 62 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv 4 | 5 | 6 | class AntDirEnv(MultitaskAntEnv): 7 | 8 | def __init__(self, task={}, n_tasks=2, forward_backward=False, randomize_tasks=True, **kwargs): 9 | self.quick_init(locals()) 10 | self.forward_backward = forward_backward 11 | super(AntDirEnv, self).__init__(task, n_tasks, **kwargs) 12 | 13 | def step(self, action): 14 | torso_xyz_before = np.array(self.get_body_com("torso")) 15 | 16 | direct = (np.cos(self._goal), np.sin(self._goal)) 17 | 18 | self.do_simulation(action, self.frame_skip) 19 | torso_xyz_after = np.array(self.get_body_com("torso")) 20 | torso_velocity = torso_xyz_after - torso_xyz_before 21 | forward_reward = np.dot((torso_velocity[:2]/self.dt), direct) 22 | 23 | ctrl_cost = .5 * np.square(action).sum() 24 | contact_cost = 0.5 * 1e-3 * np.sum( 25 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 26 | survive_reward = 1.0 27 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 28 | state = self.state_vector() 29 | notdone = np.isfinite(state).all() \ 30 | and state[2] >= 0.2 and state[2] <= 1.0 31 | done = not notdone 32 | ob = self._get_obs() 33 | return ob, reward, done, dict( 34 | reward_forward=forward_reward, 35 | reward_ctrl=-ctrl_cost, 36 | reward_contact=-contact_cost, 37 | reward_survive=survive_reward, 38 | torso_velocity=torso_velocity, 39 | ) 40 | 41 | def sample_tasks(self, num_tasks): 42 | if self.forward_backward: 43 | assert num_tasks == 2 44 | velocities = np.array([0., np.pi]) 45 | else: 46 | velocities = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,)) 47 | tasks = [{'goal': velocity} for velocity in velocities] 48 | return tasks 49 | -------------------------------------------------------------------------------- /rlkit/envs/images/env_renderer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | from rlkit.envs.images import Renderer 5 | 6 | 7 | class EnvRenderer(Renderer): 8 | # TODO: switch to env.render interface 9 | def __init__( 10 | self, 11 | init_camera=None, 12 | normalize_image=True, # most gym envs output uint8 13 | create_image_format='HWC', 14 | **kwargs 15 | ): 16 | """Render an image.""" 17 | super().__init__( 18 | normalize_image=normalize_image, 19 | create_image_format=create_image_format, 20 | **kwargs) 21 | self._init_camera = init_camera 22 | self._camera_is_initialized = False 23 | 24 | def _create_image(self, env): 25 | if not self._camera_is_initialized and self._init_camera is not None: 26 | env.initialize_camera(self._init_camera) 27 | self._camera_is_initialized = True 28 | 29 | return env.get_image( 30 | width=self.width, 31 | height=self.height, 32 | ) 33 | 34 | 35 | class GymEnvRenderer(EnvRenderer): 36 | def _create_image(self, env): 37 | if not self._camera_is_initialized and self._init_camera is not None: 38 | env.initialize_camera(self._init_camera) 39 | self._camera_is_initialized = True 40 | 41 | return env.render( 42 | mode='rgb_array', width=self.width, height=self.height 43 | ) 44 | 45 | 46 
| class GymSimRenderer(EnvRenderer): 47 | def _create_image(self, env): 48 | if not self._camera_is_initialized and self._init_camera is not None: 49 | env.initialize_camera(self._init_camera) 50 | self._camera_is_initialized = True 51 | 52 | return self._get_image(env, self.width, self.height, camera_name=None) 53 | 54 | def _get_image(self, env, width=84, height=84, camera_name=None): 55 | return env.sim.render( 56 | width=width, 57 | height=height, 58 | camera_name=camera_name, 59 | )[::-1,:,:] 60 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv 4 | 5 | 6 | class HopperRandParamsEnv(RandomEnv, utils.EzPickle): 7 | def __init__(self, log_scale_limit=3.0): 8 | RandomEnv.__init__(self, log_scale_limit, 'hopper.xml', 4) 9 | utils.EzPickle.__init__(self) 10 | 11 | def _step(self, a): 12 | posbefore = self.sim.data.qpos[0] 13 | self.do_simulation(a, self.frame_skip) 14 | posafter, height, ang = self.sim.data.qpos[0:3] 15 | alive_bonus = 1.0 16 | reward = (posafter - posbefore) / self.dt 17 | reward += alive_bonus 18 | reward -= 1e-3 * np.square(a).sum() 19 | s = self.state_vector() 20 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and 21 | (height > .7) and (abs(ang) < .2)) 22 | ob = self._get_obs() 23 | return ob, reward, done, {} 24 | 25 | def _get_obs(self): 26 | return np.concatenate([ 27 | self.sim.data.qpos.flat[1:], 28 | np.clip(self.sim.data.qvel.flat, -10, 10) 29 | ]) 30 | 31 | def reset_model(self): 32 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq) 33 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 34 | self.set_state(qpos, qvel) 35 | return self._get_obs() 36 | 37 | def viewer_setup(self): 38 | self.viewer.cam.trackbodyid = 2 39 | self.viewer.cam.distance = self.model.stat.extent * 0.75 40 | self.viewer.cam.lookat[2] += .8 41 | self.viewer.cam.elevation = -20 42 | 43 | if __name__ == "__main__": 44 | 45 | env = HopperRandParamsEnv() 46 | tasks = env.sample_tasks(40) 47 | while True: 48 | env.reset() 49 | env.set_task(np.random.choice(tasks)) 50 | print(env.model.body_mass) 51 | for _ in range(100): 52 | env.render() 53 | env.step(env.action_space.sample()) # take a random action 54 | 55 | -------------------------------------------------------------------------------- /rlkit/envs/assets/water_maze.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /rlkit/torch/networks/image_state.py: -------------------------------------------------------------------------------- 1 | from rlkit.policies.base import Policy 2 | from rlkit.torch.core import PyTorchModule, eval_np 3 | 4 | 5 | class ImageStatePolicy(PyTorchModule, Policy): 6 | """Switches between image or state inputs""" 7 | 8 | def __init__( 9 | self, 10 | image_conv_net, 11 | state_fc_net, 12 | ): 13 | super().__init__() 14 | 15 | assert image_conv_net is None or state_fc_net is None 16 | self.image_conv_net = image_conv_net 17 | self.state_fc_net = state_fc_net 18 | 19 | def forward(self, input, return_preactivations=False): 20 | if self.image_conv_net is not None: 21 | image = input[:, :21168] 22 | return 
self.image_conv_net(image) 23 | if self.state_fc_net is not None: 24 | state = input[:, 21168:] 25 | return self.state_fc_net(state) 26 | 27 | def get_action(self, obs_np): 28 | actions = self.get_actions(obs_np[None]) 29 | return actions[0, :], {} 30 | 31 | def get_actions(self, obs): 32 | return eval_np(self, obs) 33 | 34 | 35 | class ImageStateQ(PyTorchModule): 36 | """Switches between image or state inputs""" 37 | 38 | def __init__( 39 | self, 40 | # obs_dim, 41 | # action_dim, 42 | # goal_dim, 43 | image_conv_net, # assumed to be a MergedCNN 44 | state_fc_net, 45 | ): 46 | super().__init__() 47 | 48 | assert image_conv_net is None or state_fc_net is None 49 | # self.obs_dim = obs_dim 50 | # self.action_dim = action_dim 51 | # self.goal_dim = goal_dim 52 | self.image_conv_net = image_conv_net 53 | self.state_fc_net = state_fc_net 54 | 55 | def forward(self, input, action, return_preactivations=False): 56 | if self.image_conv_net is not None: 57 | image = input[:, :21168] 58 | return self.image_conv_net(image, action) 59 | if self.state_fc_net is not None: 60 | state = input[:, 21168:] # action + state 61 | return self.state_fc_net(state, action) 62 | 63 | 64 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from rlkit.core.serializeable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | def __init__( 20 | self, 21 | model_path, 22 | frame_skip=1, 23 | model_path_is_local=True, 24 | automatically_set_obs_and_action_space=False, 25 | ): 26 | if model_path_is_local: 27 | model_path = get_asset_xml(model_path) 28 | if automatically_set_obs_and_action_space: 29 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 30 | else: 31 | """ 32 | Code below is copy/pasted from MujocoEnv's __init__ function. 
33 | """ 34 | if model_path.startswith("/"): 35 | fullpath = model_path 36 | else: 37 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 38 | if not path.exists(fullpath): 39 | raise IOError("File %s does not exist" % fullpath) 40 | self.frame_skip = frame_skip 41 | self.model = mujoco_py.MjModel(fullpath) 42 | self.data = self.model.data 43 | self.viewer = None 44 | 45 | self.metadata = { 46 | 'render.modes': ['human', 'rgb_array'], 47 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 48 | } 49 | 50 | self.init_qpos = self.model.data.qpos.ravel().copy() 51 | self.init_qvel = self.model.data.qvel.ravel().copy() 52 | self._seed() 53 | 54 | def init_serialization(self, locals): 55 | Serializable.quick_init(self, locals) 56 | 57 | def log_diagnostics(self, *args, **kwargs): 58 | pass 59 | 60 | 61 | def get_asset_xml(xml_name): 62 | return os.path.join(ENV_ASSET_DIR, xml_name) 63 | -------------------------------------------------------------------------------- /rlkit/data_management/tau_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.data_management.her_replay_buffer import HerReplayBuffer 4 | from rlkit.util.np_util import truncated_geometric 5 | 6 | 7 | class TauReplayBuffer(HerReplayBuffer): 8 | def random_batch_random_tau(self, batch_size, max_tau): 9 | indices = np.random.randint(0, self._size, batch_size) 10 | next_obs_idxs = [] 11 | for i in indices: 12 | possible_next_obs_idxs = self._idx_to_future_obs_idx[i] 13 | # This is generally faster than random.choice. Makes you wonder what 14 | # random.choice is doing 15 | num_options = len(possible_next_obs_idxs) 16 | tau = np.random.randint(0, min(max_tau+1, num_options)) 17 | if num_options == 1: 18 | next_obs_i = 0 19 | else: 20 | if self.resampling_strategy == 'uniform': 21 | next_obs_i = int(np.random.randint(0, tau+1)) 22 | elif self.resampling_strategy == 'truncated_geometric': 23 | next_obs_i = int(truncated_geometric( 24 | p=self.truncated_geom_factor/tau, 25 | truncate_threshold=num_options-1, 26 | size=1, 27 | new_value=0 28 | )) 29 | else: 30 | raise ValueError("Invalid resampling strategy: {}".format( 31 | self.resampling_strategy 32 | )) 33 | next_obs_idxs.append(possible_next_obs_idxs[next_obs_i]) 34 | next_obs_idxs = np.array(next_obs_idxs) 35 | training_goals = self.env.convert_obs_to_goals( 36 | self._next_obs[next_obs_idxs] 37 | ) 38 | return dict( 39 | observations=self._observations[indices], 40 | actions=self._actions[indices], 41 | rewards=self._rewards[indices], 42 | terminals=self._terminals[indices], 43 | next_observations=self._next_obs[indices], 44 | training_goals=training_goals, 45 | num_steps_left=self._num_steps_left[indices], 46 | ) 47 | 48 | -------------------------------------------------------------------------------- /rlkit/data_management/ocm_subtraj_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rlkit.data_management.subtraj_replay_buffer import SubtrajReplayBuffer 4 | from rlkit.data_management.updatable_subtraj_replay_buffer import \ 5 | UpdatableSubtrajReplayBuffer 6 | from rlkit.util.np_util import subsequences 7 | 8 | 9 | class OcmSubtrajReplayBuffer(UpdatableSubtrajReplayBuffer): 10 | """ 11 | A replay buffer desired specifically for OneCharMem 12 | sub-trajectories 13 | """ 14 | 15 | def __init__( 16 | self, 17 | max_replay_buffer_size, 18 | env, 19 | subtraj_length, 20 | *args, 21 | 
**kwargs 22 | ): 23 | # TODO(vitchyr): Move this logic to environment 24 | self._target_numbers = np.zeros(max_replay_buffer_size, dtype='uint8') 25 | self._times = np.zeros(max_replay_buffer_size, dtype='uint8') 26 | super().__init__( 27 | max_replay_buffer_size, 28 | env, 29 | subtraj_length, 30 | *args, 31 | only_sample_at_start_of_episode=False, 32 | **kwargs 33 | ) 34 | 35 | def _add_sample(self, observation, action, reward, terminal, 36 | final_state, agent_info=None, env_info=None): 37 | if env_info is not None: 38 | if 'target_number' in env_info: 39 | self._target_numbers[self._top] = env_info['target_number'] 40 | if 'time' in env_info: 41 | self._times[self._top] = env_info['time'] 42 | super()._add_sample( 43 | observation, 44 | action, 45 | reward, 46 | terminal, 47 | final_state 48 | ) 49 | 50 | def _get_trajectories(self, start_indices): 51 | trajs = super()._get_trajectories(start_indices) 52 | trajs['target_numbers'] = subsequences( 53 | self._target_numbers, 54 | start_indices, 55 | self._subtraj_length, 56 | ) 57 | trajs['times'] = subsequences( 58 | self._times, 59 | start_indices, 60 | self._subtraj_length, 61 | ) 62 | return trajs 63 | -------------------------------------------------------------------------------- /rlkit/torch/sac/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn as nn 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.policies.base import ExplorationPolicy 11 | from rlkit.torch.core import torch_ify, elem_or_tuple_to_numpy 12 | from rlkit.torch.distributions import ( 13 | Delta, TanhNormal, MultivariateDiagonalNormal, GaussianMixture, GaussianMixtureFull, 14 | ) 15 | from rlkit.torch.networks import Mlp, CNN 16 | from rlkit.torch.networks.basic import MultiInputSequential 17 | from rlkit.torch.networks.stochastic.distribution_generator import ( 18 | DistributionGenerator 19 | ) 20 | 21 | 22 | class TorchStochasticPolicy( 23 | DistributionGenerator, 24 | ExplorationPolicy, metaclass=abc.ABCMeta 25 | ): 26 | def get_action(self, obs_np, ): 27 | actions = self.get_actions(obs_np[None]) 28 | return actions[0, :], {} 29 | 30 | def get_actions(self, obs_np, ): 31 | dist = self._get_dist_from_np(obs_np) 32 | actions = dist.sample() 33 | return elem_or_tuple_to_numpy(actions) 34 | 35 | def _get_dist_from_np(self, *args, **kwargs): 36 | torch_args = tuple(torch_ify(x) for x in args) 37 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 38 | dist = self(*torch_args, **torch_kwargs) 39 | return dist 40 | 41 | 42 | class PolicyFromDistributionGenerator( 43 | MultiInputSequential, 44 | TorchStochasticPolicy, 45 | ): 46 | """ 47 | Usage: 48 | ``` 49 | distribution_generator = FancyGenerativeModel() 50 | policy = PolicyFromBatchDistributionModule(distribution_generator) 51 | ``` 52 | """ 53 | pass 54 | 55 | 56 | class MakeDeterministic(TorchStochasticPolicy): 57 | def __init__( 58 | self, 59 | action_distribution_generator: DistributionGenerator, 60 | ): 61 | super().__init__() 62 | self._action_distribution_generator = action_distribution_generator 63 | 64 | def forward(self, *args, **kwargs): 65 | dist = self._action_distribution_generator.forward(*args, **kwargs) 66 | return Delta(dist.mle_estimate()) 67 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/ant.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Exact same as gym env, except that the gear ratio is 30 rather than 150. 3 | """ 4 | import numpy as np 5 | 6 | from gym.envs.mujoco import MujocoEnv 7 | 8 | from rlkit.envs.env_utils import get_asset_full_path 9 | 10 | 11 | class AntEnv(MujocoEnv): 12 | def __init__(self, use_low_gear_ratio=True): 13 | if use_low_gear_ratio: 14 | xml_path = 'low_gear_ratio_ant.xml' 15 | else: 16 | xml_path = 'normal_gear_ratio_ant.xml' 17 | super().__init__( 18 | get_asset_full_path(xml_path), 19 | frame_skip=5, 20 | ) 21 | 22 | def step(self, a): 23 | torso_xyz_before = self.get_body_com("torso") 24 | self.do_simulation(a, self.frame_skip) 25 | torso_xyz_after = self.get_body_com("torso") 26 | torso_velocity = torso_xyz_after - torso_xyz_before 27 | forward_reward = torso_velocity[0]/self.dt 28 | ctrl_cost = .5 * np.square(a).sum() 29 | contact_cost = 0.5 * 1e-3 * np.sum( 30 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 31 | survive_reward = 1.0 32 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 33 | state = self.state_vector() 34 | notdone = np.isfinite(state).all() \ 35 | and state[2] >= 0.2 and state[2] <= 1.0 36 | done = not notdone 37 | ob = self._get_obs() 38 | return ob, reward, done, dict( 39 | reward_forward=forward_reward, 40 | reward_ctrl=-ctrl_cost, 41 | reward_contact=-contact_cost, 42 | reward_survive=survive_reward, 43 | torso_velocity=torso_velocity, 44 | ) 45 | 46 | def _get_obs(self): 47 | return np.concatenate([ 48 | self.sim.data.qpos.flat[2:], 49 | self.sim.data.qvel.flat, 50 | ]) 51 | 52 | def reset_model(self): 53 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 54 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 55 | self.set_state(qpos, qvel) 56 | return self._get_obs() 57 | 58 | def viewer_setup(self): 59 | self.viewer.cam.distance = self.model.stat.extent * 0.5 60 | -------------------------------------------------------------------------------- /rlkit/envs/assets/small_water_maze.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 44 | -------------------------------------------------------------------------------- /rlkit/envs/memory/hidden_cartpole.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | from cached_property import cached_property 5 | 6 | from rlkit.core.eval_util import create_stats_ordered_dict 7 | from rlkit.samplers.util import split_paths 8 | from rllab.envs.box2d.cartpole_env import CartpoleEnv 9 | from rlkit.core import logger 10 | from sandbox.rocky.tf.spaces import Box 11 | 12 | 13 | class HiddenCartpoleEnv(CartpoleEnv): 14 | def __init__(self, num_steps=100, position_only=True): 15 | assert position_only, "I only added position_only due to some weird " \ 16 | "serialization bug" 17 | CartpoleEnv.__init__(self, position_only=position_only) 18 | self.num_steps = num_steps 19 | 20 | @cached_property 21 | def action_space(self): 22 | return Box(super().action_space.low, 23 | super().action_space.high) 24 | 25 | @cached_property 26 | def observation_space(self): 27 | return Box(super().observation_space.low, 28 | super().observation_space.high) 29 | 30 | @property 31 | def horizon(self): 32 | return self.num_steps 33 | 34 | @staticmethod 35 | def get_extra_info_dict_from_batch(batch): 36 | return dict() 37 | 38 | @staticmethod 39 | def 
get_flattened_extra_info_dict_from_subsequence_batch(batch): 40 | return dict() 41 | 42 | @staticmethod 43 | def get_last_extra_info_dict_from_subsequence_batch(batch): 44 | return dict() 45 | 46 | def log_diagnostics(self, paths, **kwargs): 47 | list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths) 48 | 49 | returns = [] 50 | for rewards in list_of_rewards: 51 | returns.append(np.sum(rewards)) 52 | last_statistics = OrderedDict() 53 | last_statistics.update(create_stats_ordered_dict( 54 | 'UndiscountedReturns', 55 | returns, 56 | )) 57 | last_statistics.update(create_stats_ordered_dict( 58 | 'Actions', 59 | actions, 60 | )) 61 | 62 | for key, value in last_statistics.items(): 63 | logger.record_tabular(key, value) 64 | return returns 65 | 66 | def is_current_done(self): 67 | return False 68 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/mujoco_env_murtaza.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym import error, spaces 4 | from gym.utils import seeding 5 | import numpy as np 6 | from os import path 7 | import gym 8 | import six 9 | 10 | try: 11 | import mujoco_py 12 | from mujoco_py.mjlib import mjlib 13 | except ImportError as e: 14 | raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e)) 15 | 16 | class MujocoEnv(gym.Env): 17 | """Superclass for all MuJoCo environments. 18 | """ 19 | 20 | def __init__(self, model_path, frame_skip, init_viewer=False): 21 | if model_path.startswith("/"): 22 | fullpath = model_path 23 | else: 24 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 25 | if not path.exists(fullpath): 26 | raise IOError("File %s does not exist" % fullpath) 27 | self.frame_skip = frame_skip 28 | self.model = mujoco_py.MjModel(fullpath) 29 | self.data = self.model.data 30 | self.viewer = None 31 | 32 | if init_viewer: 33 | self._viewer_bot.set_model(self.model) 34 | self._set_cam_position(self._viewer_bot, self.cam_pos) 35 | if init_viewer > 1: 36 | print("viewer 2 bring set up") 37 | self._viewer_bot2.set_model(self.model) 38 | self._set_cam_position(self._viewer_bot2, self.cam_pos2) 39 | 40 | self.metadata = { 41 | 'render.modes': ['human', 'rgb_array'], 42 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 43 | } 44 | 45 | self.init_qpos = self.model.data.qpos.ravel().copy() 46 | self.init_qvel = self.model.data.qvel.ravel().copy() 47 | observation, _reward, done, _info = self._step(np.zeros(self.model.nu)) 48 | assert not done 49 | self.obs_dim = observation.size 50 | 51 | bounds = self.model.actuator_ctrlrange.copy() 52 | low = bounds[:, 0] 53 | high = bounds[:, 1] 54 | self.action_space = spaces.Box(low, high) 55 | 56 | high = np.inf*np.ones(self.obs_dim) 57 | low = -high 58 | self.observation_space = spaces.Box(low, high) 59 | 60 | self._seed() -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv 4 | 5 | 6 | class Walker2DRandParamsEnv(RandomEnv, utils.EzPickle): 7 | def __init__(self, log_scale_limit=3.0): 8 | RandomEnv.__init__(self, log_scale_limit, 'walker2d.xml', 5) 9 | 
utils.EzPickle.__init__(self) 10 | 11 | def _step(self, a): 12 | # import ipdb; ipdb.set_trace() 13 | # posbefore = self.model.data.qpos[0, 0] 14 | posbefore = self.sim.data.qpos[0] 15 | self.do_simulation(a, self.frame_skip) 16 | # posafter, height, ang = self.model.data.qpos[0:3, 0] 17 | posafter, height, ang = self.sim.data.qpos[0:3] 18 | alive_bonus = 1.0 19 | reward = ((posafter - posbefore) / self.dt) 20 | reward += alive_bonus 21 | reward -= 1e-3 * np.square(a).sum() 22 | done = not (height > 0.8 and height < 2.0 and 23 | ang > -1.0 and ang < 1.0) 24 | ob = self._get_obs() 25 | return ob, reward, done, {} 26 | 27 | def _get_obs(self): 28 | # qpos = self.model.data.qpos 29 | # qvel = self.model.data.qvel 30 | qpos = self.sim.data.qpos 31 | qvel = self.sim.data.qvel 32 | return np.concatenate([qpos[1:], np.clip(qvel, -10, 10)]).ravel() 33 | 34 | def reset_model(self): 35 | self.set_state( 36 | self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq), 37 | self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 38 | ) 39 | return self._get_obs() 40 | 41 | def viewer_setup(self): 42 | self.viewer.cam.trackbodyid = 2 43 | self.viewer.cam.distance = self.model.stat.extent * 0.5 44 | self.viewer.cam.lookat[2] += .8 45 | self.viewer.cam.elevation = -20 46 | 47 | if __name__ == "__main__": 48 | 49 | env = Walker2DRandParamsEnv() 50 | tasks = env.sample_tasks(40) 51 | while True: 52 | env.reset() 53 | env.set_task(np.random.choice(tasks)) 54 | print(env.model.body_mass) 55 | for _ in range(100): 56 | env.render() 57 | env.step(env.action_space.sample()) # take a random action 58 | 59 | -------------------------------------------------------------------------------- /rlkit/core/serializeable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | https://github.com/rll/rllab 4 | """ 5 | 6 | import inspect 7 | import sys 8 | 9 | 10 | class Serializable(object): 11 | 12 | def __init__(self, *args, **kwargs): 13 | self.__args = args 14 | self.__kwargs = kwargs 15 | 16 | def quick_init(self, locals_): 17 | if getattr(self, "_serializable_initialized", False): 18 | return 19 | if sys.version_info >= (3, 0): 20 | spec = inspect.getfullargspec(self.__init__) 21 | # Exclude the first "self" parameter 22 | if spec.varkw: 23 | kwargs = locals_[spec.varkw].copy() 24 | else: 25 | kwargs = dict() 26 | if spec.kwonlyargs: 27 | for key in spec.kwonlyargs: 28 | kwargs[key] = locals_[key] 29 | else: 30 | spec = inspect.getargspec(self.__init__) 31 | if spec.keywords: 32 | kwargs = locals_[spec.keywords] 33 | else: 34 | kwargs = dict() 35 | if spec.varargs: 36 | varargs = locals_[spec.varargs] 37 | else: 38 | varargs = tuple() 39 | in_order_args = [locals_[arg] for arg in spec.args][1:] 40 | self.__args = tuple(in_order_args) + varargs 41 | self.__kwargs = kwargs 42 | setattr(self, "_serializable_initialized", True) 43 | 44 | def __getstate__(self): 45 | return {"__args": self.__args, "__kwargs": self.__kwargs} 46 | 47 | def __setstate__(self, d): 48 | # convert all __args to keyword-based arguments 49 | if sys.version_info >= (3, 0): 50 | spec = inspect.getfullargspec(self.__init__) 51 | else: 52 | spec = inspect.getargspec(self.__init__) 53 | in_order_args = spec.args[1:] 54 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 55 | self.__dict__.update(out.__dict__) 56 | 57 | @classmethod 58 | def clone(cls, obj, **kwargs): 59 | assert isinstance(obj, 
Serializable) 60 | d = obj.__getstate__() 61 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 62 | out = type(obj).__new__(type(obj)) 63 | out.__setstate__(d) 64 | return out 65 | -------------------------------------------------------------------------------- /rlkit/utils/image_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import torch 4 | from PIL import Image 5 | 6 | from rlkit.util.augment_util import create_aug_stack 7 | 8 | class ImageAugment: 9 | 10 | def __init__(self, 11 | image_size=48, 12 | augment_order = ['RandomResizedCrop', 'ColorJitter'], 13 | augment_probability=0.95, 14 | augment_params={ 15 | 'RandomResizedCrop': dict( 16 | scale=(0.9, 1.0), 17 | ratio=(0.9, 1.1), 18 | ), 19 | 'ColorJitter': dict( 20 | brightness=(0.75, 1.25), 21 | contrast=(0.9, 1.1), 22 | saturation=(0.9, 1.1), 23 | hue=(-0.1, 0.1), 24 | ), 25 | 'RandomCrop': dict( 26 | padding=4, 27 | padding_mode='edge' 28 | ), 29 | }, 30 | ): 31 | self._image_size = image_size 32 | 33 | self.augment_stack = create_aug_stack( 34 | augment_order, augment_params, size=(self._image_size, self._image_size) 35 | ) 36 | self.augment_probability = augment_probability 37 | 38 | def set_augment_params(self, img): 39 | if torch.rand(1) < self.augment_probability: 40 | self.augment_stack.set_params(img) 41 | else: 42 | self.augment_stack.set_default_params(img) 43 | 44 | def augment(self, img): 45 | img = self.augment_stack(img) 46 | return img 47 | 48 | def __call__(self, images, already_tranformed=True): 49 | if len(images.shape) == 4: 50 | batched = True 51 | elif len(images.shape) == 3: 52 | batched = False 53 | else: 54 | raise ValueError 55 | 56 | if already_tranformed: 57 | images += .5 58 | 59 | if self.augment_probability > 0: 60 | self.set_augment_params(images) 61 | images = self.augment(images) 62 | 63 | if already_tranformed: 64 | images -= 0.5 65 | 66 | if not batched: 67 | images = images[0] 68 | 69 | return images 70 | -------------------------------------------------------------------------------- /rlkit/torch/core.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | 7 | from rlkit.torch import pytorch_util as ptu 8 | 9 | 10 | class PyTorchModule(nn.Module, metaclass=abc.ABCMeta): 11 | """ 12 | Keeping wrapper around to be a bit more future-proof. 13 | """ 14 | pass 15 | 16 | 17 | def eval_np(module, *args, **kwargs): 18 | """ 19 | Eval this module with a numpy interface 20 | 21 | Same as a call to __call__ except all Variable input/outputs are 22 | replaced with numpy equivalents. 23 | 24 | Assumes the output is either a single object or a tuple of objects. 
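A minimal usage sketch (the Mlp sizes below are illustrative, not taken from this file):

            # Any PyTorchModule works; eval_np does the numpy -> torch -> numpy
            # round trip for a single call.
            qf = Mlp(hidden_sizes=[32, 32], output_size=1, input_size=4)
            q_values = eval_np(qf, np.zeros((8, 4)))  # (8, 1) numpy array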
25 | """ 26 | torch_args = tuple(torch_ify(x) for x in args) 27 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 28 | outputs = module(*torch_args, **torch_kwargs) 29 | return elem_or_tuple_to_numpy(outputs) 30 | 31 | 32 | def torch_ify(np_array_or_other): 33 | if isinstance(np_array_or_other, np.ndarray): 34 | return ptu.from_numpy(np_array_or_other) 35 | else: 36 | return np_array_or_other 37 | 38 | 39 | def np_ify(tensor_or_other): 40 | if isinstance(tensor_or_other, torch.autograd.Variable): 41 | return ptu.get_numpy(tensor_or_other) 42 | else: 43 | return tensor_or_other 44 | 45 | 46 | def _elem_or_tuple_to_variable(elem_or_tuple): 47 | if isinstance(elem_or_tuple, tuple): 48 | return tuple( 49 | _elem_or_tuple_to_variable(e) for e in elem_or_tuple 50 | ) 51 | return ptu.from_numpy(elem_or_tuple).float() 52 | 53 | 54 | def elem_or_tuple_to_numpy(elem_or_tuple): 55 | if isinstance(elem_or_tuple, tuple): 56 | return tuple(np_ify(x) for x in elem_or_tuple) 57 | else: 58 | return np_ify(elem_or_tuple) 59 | 60 | 61 | def _filter_batch(np_batch): 62 | for k, v in np_batch.items(): 63 | if v.dtype == np.bool: 64 | yield k, v.astype(int) 65 | else: 66 | yield k, v 67 | 68 | 69 | def np_to_pytorch_batch(np_batch): 70 | if isinstance(np_batch, dict): 71 | return { 72 | k: _elem_or_tuple_to_variable(x) 73 | for k, x in _filter_batch(np_batch) 74 | if x.dtype != np.dtype('O') # ignore object (e.g. dictionaries) 75 | } 76 | else: 77 | _elem_or_tuple_to_variable(np_batch) 78 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/ou_strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as nr 3 | 4 | from rlkit.exploration_strategies.base import RawExplorationStrategy 5 | 6 | 7 | class OUStrategy(RawExplorationStrategy): 8 | """ 9 | This strategy implements the Ornstein-Uhlenbeck process, which adds 10 | time-correlated noise to the actions taken by the deterministic policy. 11 | The OU process satisfies the following stochastic differential equation: 12 | dxt = theta*(mu - xt)*dt + sigma*dWt 13 | where Wt denotes the Wiener process 14 | 15 | Based on the rllab implementation. 
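In discrete time (dt = 1), evolve_state below applies the Euler-Maruyama update

        x_{t+1} = x_t + theta * (mu - x_t) + sigma * eps_t,    eps_t ~ N(0, I)

    and get_action_from_raw_action adds the resulting state to the policy's action
    before clipping to the action-space bounds, with sigma annealed linearly from
    max_sigma to min_sigma over decay_period steps.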
16 | """ 17 | 18 | def __init__( 19 | self, 20 | action_space, 21 | mu=0, 22 | theta=0.15, 23 | max_sigma=0.3, 24 | min_sigma=None, 25 | decay_period=100000, 26 | ): 27 | if min_sigma is None: 28 | min_sigma = max_sigma 29 | self.mu = mu 30 | self.theta = theta 31 | self.sigma = max_sigma 32 | self._max_sigma = max_sigma 33 | if min_sigma is None: 34 | min_sigma = max_sigma 35 | self._min_sigma = min_sigma 36 | self._decay_period = decay_period 37 | self.dim = np.prod(action_space.low.shape) 38 | self.low = action_space.low 39 | self.high = action_space.high 40 | self.state = np.ones(self.dim) * self.mu 41 | self.reset() 42 | 43 | def reset(self): 44 | self.state = np.ones(self.dim) * self.mu 45 | 46 | def evolve_state(self): 47 | x = self.state 48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x)) 49 | self.state = x + dx 50 | return self.state 51 | 52 | def get_action_from_raw_action(self, action, t=0, **kwargs): 53 | ou_state = self.evolve_state() 54 | self.sigma = ( 55 | self._max_sigma 56 | - (self._max_sigma - self._min_sigma) 57 | * min(1.0, t * 1.0 / self._decay_period) 58 | ) 59 | return np.clip(action + ou_state, self.low, self.high) 60 | 61 | def get_actions_from_raw_actions(self, actions, t=0, **kwargs): 62 | noise = ( 63 | self.state + self.theta * (self.mu - self.state) 64 | + self.sigma * nr.randn(*actions.shape) 65 | ) 66 | return np.clip(actions + noise, self.low, self.high) 67 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers/normalized_box_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box 3 | 4 | from rlkit.envs.proxy_env import ProxyEnv 5 | 6 | 7 | class NormalizedBoxEnv(ProxyEnv): 8 | """ 9 | Normalize action to in [-1, 1]. 10 | 11 | Optionally normalize observations and scale reward. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | env, 17 | reward_scale=1., 18 | obs_mean=None, 19 | obs_std=None, 20 | ): 21 | ProxyEnv.__init__(self, env) 22 | self._should_normalize = not (obs_mean is None and obs_std is None) 23 | if self._should_normalize: 24 | if obs_mean is None: 25 | obs_mean = np.zeros_like(env.observation_space.low) 26 | else: 27 | obs_mean = np.array(obs_mean) 28 | if obs_std is None: 29 | obs_std = np.ones_like(env.observation_space.low) 30 | else: 31 | obs_std = np.array(obs_std) 32 | self._reward_scale = reward_scale 33 | self._obs_mean = obs_mean 34 | self._obs_std = obs_std 35 | ub = np.ones(self._wrapped_env.action_space.shape) 36 | self.action_space = Box(-1 * ub, ub) 37 | 38 | def estimate_obs_stats(self, obs_batch, override_values=False): 39 | if self._obs_mean is not None and not override_values: 40 | raise Exception("Observation mean and std already set. To " 41 | "override, set override_values to True.") 42 | self._obs_mean = np.mean(obs_batch, axis=0) 43 | self._obs_std = np.std(obs_batch, axis=0) 44 | 45 | def _apply_normalize_obs(self, obs): 46 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 47 | 48 | def step(self, action): 49 | lb = self._wrapped_env.action_space.low 50 | ub = self._wrapped_env.action_space.high 51 | scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) 52 | scaled_action = np.clip(scaled_action, lb, ub) 53 | 54 | wrapped_step = self._wrapped_env.step(scaled_action) 55 | next_obs, reward, done, info = wrapped_step 56 | if self._should_normalize: 57 | next_obs = self._apply_normalize_obs(next_obs) 58 | return next_obs, reward * self._reward_scale, done, info 59 | 60 | def __str__(self): 61 | return "Normalized: %s" % self._wrapped_env 62 | 63 | 64 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rlkit.policies.base import ExplorationPolicy 4 | 5 | 6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta): 7 | @abc.abstractmethod 8 | def get_action(self, t, observation, policy, **kwargs): 9 | pass 10 | 11 | @abc.abstractmethod 12 | def get_actions(self, t, observation, policy, **kwargs): 13 | pass 14 | 15 | def reset(self): 16 | pass 17 | 18 | 19 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def get_action_from_raw_action(self, action, **kwargs): 22 | pass 23 | 24 | def get_actions_from_raw_actions(self, actions, **kwargs): 25 | raise NotImplementedError() 26 | 27 | def get_action(self, t, policy, *args, **kwargs): 28 | action, agent_info = policy.get_action(*args, **kwargs) 29 | return self.get_action_from_raw_action(action, t=t), agent_info 30 | 31 | def get_actions(self, t, observation, policy, **kwargs): 32 | actions = policy.get_actions(observation) 33 | return self.get_actions_from_raw_actions(actions, **kwargs) 34 | 35 | def reset(self): 36 | pass 37 | 38 | 39 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy): 40 | def __init__( 41 | self, 42 | exploration_strategy: ExplorationStrategy, 43 | policy, 44 | ): 45 | self.es = exploration_strategy 46 | self.policy = policy 47 | self.t = 0 48 | 49 | def set_num_steps_total(self, t): 50 | self.t = t 51 | 52 | def get_action(self, *args, **kwargs): 53 | return self.es.get_action(self.t, self.policy, *args, **kwargs) 54 | 55 | def get_actions(self, *args, **kwargs): 56 | return self.es.get_actions(self.t, self.policy, *args, **kwargs) 57 | 58 | def reset(self): 59 | self.es.reset() 60 | self.policy.reset() 61 | 62 | def get_param_values(self): 63 | return self.policy.get_param_values() 64 | 65 | def set_param_values(self, param_values): 66 | self.policy.set_param_values(param_values) 67 | 68 | def get_param_values_np(self): 69 | return self.policy.get_param_values_np() 70 | 71 | def set_param_values_np(self, param_values): 72 | self.policy.set_param_values_np(param_values) 73 | 74 | def to(self, device): 75 | self.policy.to(device) 76 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from rlkit.envs.env_utils import get_asset_full_path 9 | 10 | 11 | class MujocoEnv(mujoco_env.MujocoEnv): 12 | """ 13 | My own wrapper around MujocoEnv. 
14 | 15 | The caller needs to declare 16 | """ 17 | def __init__( 18 | self, 19 | model_path, 20 | frame_skip=1, 21 | model_path_is_local=True, 22 | automatically_set_obs_and_action_space=False, 23 | ): 24 | if model_path_is_local: 25 | model_path = get_asset_full_path(model_path) 26 | if automatically_set_obs_and_action_space: 27 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 28 | else: 29 | """ 30 | Code below is copy/pasted from MujocoEnv's __init__ function. 31 | """ 32 | if model_path.startswith("/"): 33 | fullpath = model_path 34 | else: 35 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 36 | if not path.exists(fullpath): 37 | raise IOError("File %s does not exist" % fullpath) 38 | self.frame_skip = frame_skip 39 | self.model = mujoco_py.load_model_from_path(fullpath) 40 | self.sim = mujoco_py.MjSim(self.model) 41 | self.data = self.sim.data 42 | self.viewer = None 43 | 44 | self.metadata = { 45 | 'render.modes': ['human', 'rgb_array'], 46 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 47 | } 48 | 49 | self.init_qpos = self.sim.data.qpos.ravel().copy() 50 | self.init_qvel = self.sim.data.qvel.ravel().copy() 51 | observation, _reward, done, _info = self.step(np.zeros(self.model.nu)) 52 | assert not done 53 | self.obs_dim = observation.size 54 | 55 | bounds = self.model.actuator_ctrlrange.copy() 56 | low = bounds[:, 0] 57 | high = bounds[:, 1] 58 | self.action_space = spaces.Box(low=low, high=high) 59 | 60 | high = np.inf*np.ones(self.obs_dim) 61 | low = -high 62 | self.observation_space = spaces.Box(low, high) 63 | 64 | self.seed() 65 | 66 | def log_diagnostics(self, paths): 67 | pass 68 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.envs.pearl_envs.ant_normal import AntNormal 2 | from rlkit.envs.pearl_envs.ant_dir import AntDirEnv 3 | from rlkit.envs.pearl_envs.ant_goal import AntGoalEnv 4 | from rlkit.envs.pearl_envs.half_cheetah_dir import HalfCheetahDirEnv 5 | from rlkit.envs.pearl_envs.half_cheetah_vel import HalfCheetahVelEnv 6 | from rlkit.envs.pearl_envs.hopper_rand_params_wrapper import \ 7 | HopperRandParamsWrappedEnv 8 | from rlkit.envs.pearl_envs.humanoid_dir import HumanoidDirEnv 9 | from rlkit.envs.pearl_envs.point_robot import PointEnv, SparsePointEnv 10 | from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import \ 11 | Walker2DRandParamsEnv 12 | from rlkit.envs.pearl_envs.walker_rand_params_wrapper import \ 13 | WalkerRandParamsWrappedEnv 14 | 15 | ENVS = {} 16 | 17 | 18 | def register_env(name): 19 | """Registers a env by name for instantiation in rlkit.""" 20 | 21 | def register_env_fn(fn): 22 | if name in ENVS: 23 | raise ValueError("Cannot register duplicate env {}".format(name)) 24 | if not callable(fn): 25 | raise TypeError("env {} must be callable".format(name)) 26 | ENVS[name] = fn 27 | return fn 28 | 29 | return register_env_fn 30 | 31 | def _register_env(name, fn): 32 | """Registers a env by name for instantiation in rlkit.""" 33 | if name in ENVS: 34 | raise ValueError("Cannot register duplicate env {}".format(name)) 35 | if not callable(fn): 36 | raise TypeError("env {} must be callable".format(name)) 37 | ENVS[name] = fn 38 | 39 | 40 | def register_pearl_envs(): 41 | _register_env('sparse-point-robot', SparsePointEnv) 42 | _register_env('ant-normal', AntNormal) 43 | _register_env('ant-dir', AntDirEnv) 44 | 
_register_env('ant-goal', AntGoalEnv) 45 | _register_env('cheetah-dir', HalfCheetahDirEnv) 46 | _register_env('cheetah-vel', HalfCheetahVelEnv) 47 | _register_env('humanoid-dir', HumanoidDirEnv) 48 | _register_env('point-robot', PointEnv) 49 | _register_env('walker-rand-params', WalkerRandParamsWrappedEnv) 50 | _register_env('hopper-rand-params', HopperRandParamsWrappedEnv) 51 | 52 | # automatically import any envs in the envs/ directory 53 | # for file in os.listdir(os.path.dirname(__file__)): 54 | # if file.endswith('.py') and not file.startswith('_'): 55 | # module = file[:file.find('.py')] 56 | # importlib.import_module('rlkit.envs.pearl_envs.' + module) 57 | -------------------------------------------------------------------------------- /rlkit/envs/make_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file provides a more uniform interface to gym.make(env_id) that handles 3 | imports and normalization 4 | """ 5 | 6 | import gym 7 | 8 | from rlkit.envs.wrappers.normalized_box_env import NormalizedBoxEnv 9 | 10 | DAPG_ENVS = [ 11 | 'pen-v0', 'pen-sparse-v0', 'pen-notermination-v0', 'pen-binary-v0', 'pen-binary-old-v0', 12 | 'door-v0', 'door-sparse-v0', 'door-binary-v0', 'door-binary-old-v0', 13 | 'relocate-v0', 'relocate-sparse-v0', 'relocate-binary-v0', 'relocate-binary-old-v0', 14 | 'hammer-v0', 'hammer-sparse-v0', 'hammer-binary-v0', 15 | ] 16 | 17 | D4RL_ENVS = [ 18 | "maze2d-open-v0", "maze2d-umaze-v0", "maze2d-medium-v0", "maze2d-large-v0", 19 | "maze2d-open-dense-v0", "maze2d-umaze-dense-v0", "maze2d-medium-dense-v0", "maze2d-large-dense-v0", 20 | "antmaze-umaze-v0", "antmaze-umaze-diverse-v0", "antmaze-medium-diverse-v0", 21 | "antmaze-medium-play-v0", "antmaze-large-diverse-v0", "antmaze-large-play-v0", 22 | "pen-human-v0", "pen-cloned-v0", "pen-expert-v0", "hammer-human-v0", "hammer-cloned-v0", "hammer-expert-v0", 23 | "door-human-v0", "door-cloned-v0", "door-expert-v0", "relocate-human-v0", "relocate-cloned-v0", "relocate-expert-v0", 24 | "halfcheetah-random-v0", "halfcheetah-medium-v0", "halfcheetah-expert-v0", "halfcheetah-mixed-v0", "halfcheetah-medium-expert-v0", 25 | "walker2d-random-v0", "walker2d-medium-v0", "walker2d-expert-v0", "walker2d-mixed-v0", "walker2d-medium-expert-v0", 26 | "hopper-random-v0", "hopper-medium-v0", "hopper-expert-v0", "hopper-mixed-v0", "hopper-medium-expert-v0", 27 | # 'halfcheetah-expert-v2', 28 | 'halfcheetah-medium-v2', 29 | 'halfcheetah-medium-replay-v2', 30 | 'halfcheetah-medium-expert-v2', 31 | # 'hopper-expert-v2', 32 | 'hopper-medium-v2', 33 | 'hopper-medium-replay-v2', 34 | 'hopper-medium-expert-v2', 35 | # 'walker2d-expert-v2', 36 | 'walker2d-medium-v2', 37 | 'walker2d-medium-replay-v2', 38 | 'walker2d-medium-expert-v2', 39 | ] 40 | 41 | def make(env_id=None, env_class=None, env_kwargs=None, normalize_env=True): 42 | assert env_id or env_class 43 | 44 | if env_class: 45 | env = env_class(**env_kwargs) 46 | elif env_id in DAPG_ENVS: 47 | import mj_envs 48 | assert normalize_env == False 49 | env = gym.make(env_id) 50 | elif env_id in D4RL_ENVS: 51 | import d4rl 52 | assert normalize_env == False 53 | env = gym.make(env_id) 54 | elif env_id: 55 | env = gym.make(env_id) 56 | 57 | if normalize_env: 58 | env = NormalizedBoxEnv(env) 59 | 60 | return env 61 | -------------------------------------------------------------------------------- /rlkit/learning/online_offline_split_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import 
random 2 | import numpy as np 3 | import torch 4 | 5 | import rlkit.data_management.images as image_np 6 | from rlkit import pythonplusplus as ppp 7 | 8 | # import time 9 | 10 | from rlkit.data_management.online_offline_split_replay_buffer import ( 11 | OnlineOfflineSplitReplayBuffer, 12 | ) 13 | 14 | 15 | def concat(*x): 16 | return np.concatenate(x, axis=0) 17 | 18 | 19 | class OnlineOfflineSplitReplayBuffer(OnlineOfflineSplitReplayBuffer): 20 | def __init__( 21 | self, 22 | offline_replay_buffer, 23 | online_replay_buffer, 24 | sample_online_fraction=None, 25 | online_mode=False, 26 | **kwargs 27 | ): 28 | super().__init__(offline_replay_buffer=offline_replay_buffer, 29 | online_replay_buffer=online_replay_buffer, 30 | sample_online_fraction=sample_online_fraction, 31 | online_mode=online_mode, 32 | **kwargs) 33 | 34 | def random_batch(self, batch_size): 35 | online_batch_size = min(int(self.sample_online_fraction * batch_size), 36 | self.online_replay_buffer.num_steps_can_sample()) 37 | offline_batch_size = batch_size - online_batch_size 38 | online_batch = self.online_replay_buffer.random_batch(online_batch_size) 39 | offline_batch = self.offline_replay_buffer.random_batch(offline_batch_size) 40 | 41 | # torch.cuda.synchronize() 42 | # start_time = time.time() 43 | # batch = dict() 44 | # for (key, online_batch_value) in online_batch.items(): 45 | # assert key in offline_batch 46 | # offline_batch_value = offline_batch[key] 47 | # if key == 'indices': 48 | # batch['online_indices'] = online_batch_value 49 | # batch['offline_indicies'] = offline_batch_value 50 | # else: 51 | # batch[key] = np.concatenate((online_batch_value, offline_batch_value), axis=0) 52 | # 53 | # if batch[key].dtype == np.uint8: 54 | # batch[key] = image_np.normalize_image(batch[key], dtype=np.float32) 55 | batch = ppp.treemap( 56 | concat, 57 | online_batch, 58 | offline_batch, 59 | atomic_type=np.ndarray) 60 | 61 | # torch.cuda.synchronize() 62 | # end_time = time.time() 63 | # print("Time to concatenate offline and online data: {} secs".format(end_time - start_time)) 64 | 65 | return batch 66 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from multiworld.envs.mujoco.cameras import create_camera_init 4 | from .mujoco_env import MujocoEnv 5 | 6 | 7 | class AntEnv(MujocoEnv): 8 | def __init__(self, use_low_gear_ratio=False): 9 | # self.init_serialization(locals()) 10 | if use_low_gear_ratio: 11 | xml_path = 'low_gear_ratio_ant.xml' 12 | else: 13 | xml_path = 'ant.xml' 14 | super().__init__( 15 | xml_path, 16 | frame_skip=5, 17 | automatically_set_obs_and_action_space=True, 18 | ) 19 | 20 | def step(self, a): 21 | torso_xyz_before = self.get_body_com("torso") 22 | self.do_simulation(a, self.frame_skip) 23 | torso_xyz_after = self.get_body_com("torso") 24 | torso_velocity = torso_xyz_after - torso_xyz_before 25 | forward_reward = torso_velocity[0]/self.dt 26 | ctrl_cost = 0. # .5 * np.square(a).sum() 27 | contact_cost = 0.5 * 1e-3 * np.sum( 28 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 29 | survive_reward = 0. 
# 1.0 30 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 31 | state = self.state_vector() 32 | notdone = np.isfinite(state).all() \ 33 | and state[2] >= 0.2 and state[2] <= 1.0 34 | done = not notdone 35 | ob = self._get_obs() 36 | return ob, reward, done, dict( 37 | reward_forward=forward_reward, 38 | reward_ctrl=-ctrl_cost, 39 | reward_contact=-contact_cost, 40 | reward_survive=survive_reward, 41 | torso_velocity=torso_velocity, 42 | ) 43 | 44 | def _get_obs(self): 45 | # this is gym ant obs, should use rllab? 46 | # if position is needed, override this in subclasses 47 | return np.concatenate([ 48 | self.sim.data.qpos.flat[2:], 49 | self.sim.data.qvel.flat, 50 | ]) 51 | 52 | def reset_model(self): 53 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 54 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 55 | self.set_state(qpos, qvel) 56 | return self._get_obs() 57 | 58 | def viewer_setup(self): 59 | self.camera_init = create_camera_init( 60 | lookat=(0, 0, 0), 61 | distance=10, 62 | elevation=-45, 63 | trackbodyid=self.sim.model.body_name2id('torso'), 64 | ) 65 | self.camera_init(self.viewer.cam) 66 | # self.viewer.cam.distance = self.model.stat.extent * 0.5 67 | # -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/humanoid_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import HumanoidEnv as HumanoidEnv 3 | 4 | 5 | def mass_center(model, sim): 6 | mass = np.expand_dims(model.body_mass, 1) 7 | xpos = sim.data.xipos 8 | return (np.sum(mass * xpos, 0) / np.sum(mass)) 9 | 10 | 11 | class HumanoidDirEnv(HumanoidEnv): 12 | 13 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True): 14 | self.tasks = self.sample_tasks(n_tasks) 15 | self.reset_task(0) 16 | super(HumanoidDirEnv, self).__init__() 17 | 18 | def step(self, action): 19 | pos_before = np.copy(mass_center(self.model, self.sim)[:2]) 20 | self.do_simulation(action, self.frame_skip) 21 | pos_after = mass_center(self.model, self.sim)[:2] 22 | 23 | alive_bonus = 5.0 24 | data = self.sim.data 25 | goal_direction = (np.cos(self._goal), np.sin(self._goal)) 26 | lin_vel_cost = 0.25 * np.sum(goal_direction * (pos_after - pos_before)) / self.model.opt.timestep 27 | quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum() 28 | quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum() 29 | quad_impact_cost = min(quad_impact_cost, 10) 30 | reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus 31 | qpos = self.sim.data.qpos 32 | done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) 33 | 34 | return self._get_obs(), reward, done, dict(reward_linvel=lin_vel_cost, 35 | reward_quadctrl=-quad_ctrl_cost, 36 | reward_alive=alive_bonus, 37 | reward_impact=-quad_impact_cost) 38 | 39 | def _get_obs(self): 40 | data = self.sim.data 41 | return np.concatenate([data.qpos.flat[2:], 42 | data.qvel.flat, 43 | data.cinert.flat, 44 | data.cvel.flat, 45 | data.qfrc_actuator.flat, 46 | data.cfrc_ext.flat]) 47 | 48 | def get_all_task_idx(self): 49 | return range(len(self.tasks)) 50 | 51 | def reset_task(self, idx): 52 | self._task = self.tasks[idx] 53 | self._goal = self._task['goal'] # assume parameterization of task by single vector 54 | 55 | def sample_tasks(self, num_tasks): 56 | # velocities = np.random.uniform(0., 1.0 * np.pi, size=(num_tasks,)) 57 | directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,)) 58 | tasks = [{'goal': 
d} for d in directions] 59 | return tasks 60 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/half_cheetah_vel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .half_cheetah import HalfCheetahEnv 4 | 5 | 6 | class HalfCheetahVelEnv(HalfCheetahEnv): 7 | """Half-cheetah environment with target velocity, as described in [1]. The 8 | code is adapted from 9 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand.py 10 | 11 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each 12 | time step a reward composed of a control cost and a penalty equal to the 13 | difference between its current velocity and the target velocity. The tasks 14 | are generated by sampling the target velocities from the uniform 15 | distribution on [0, 2]. 16 | 17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 19 | (https://arxiv.org/abs/1703.03400) 20 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 21 | model-based control", 2012 22 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 23 | """ 24 | def __init__(self, task={}, n_tasks=2, randomize_tasks=True): 25 | self._task = task 26 | self.tasks = self.sample_tasks(n_tasks) 27 | self._goal_vel = self.tasks[0].get('velocity', 0.0) 28 | self._goal = self._goal_vel 29 | super(HalfCheetahVelEnv, self).__init__() 30 | 31 | def step(self, action): 32 | xposbefore = self.sim.data.qpos[0] 33 | self.do_simulation(action, self.frame_skip) 34 | xposafter = self.sim.data.qpos[0] 35 | 36 | forward_vel = (xposafter - xposbefore) / self.dt 37 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel) 38 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 39 | 40 | observation = self._get_obs() 41 | reward = forward_reward - ctrl_cost 42 | done = False 43 | infos = dict(reward_forward=forward_reward, 44 | reward_ctrl=-ctrl_cost, task=self._task) 45 | return (observation, reward, done, infos) 46 | 47 | def sample_tasks(self, num_tasks): 48 | np.random.seed(1337) 49 | velocities = np.random.uniform(0.0, 3.0, size=(num_tasks,)) 50 | tasks = [{'velocity': velocity} for velocity in velocities] 51 | return tasks 52 | 53 | def get_all_task_idx(self): 54 | return range(len(self.tasks)) 55 | 56 | def reset_task(self, idx): 57 | self._task = self.tasks[idx] 58 | self._goal_vel = self._task['velocity'] 59 | self._goal = self._goal_vel 60 | self.reset() 61 | -------------------------------------------------------------------------------- /rlkit/util/inspect_q_util.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import rlkit.visualization.visualization_util as vu 3 | import numpy as np 4 | from rlkit.torch.core import np_ify, torch_ify 5 | 6 | fig_v_mean = None 7 | fig_v_std = None 8 | axes_v_mean = None 9 | axes_v_std = None 10 | 11 | def debug_q(ensemble_qs, policy, show_mean=True, show_std=True): 12 | global fig_v_mean, fig_v_std, axes_v_mean, axes_v_std 13 | 14 | if fig_v_mean is None: 15 | if show_mean: 16 | fig_v_mean, axes_v_mean = plt.subplots(3, 3, sharex='all', sharey='all', figsize=(9, 9)) 17 | fig_v_mean.canvas.set_window_title('Q Mean') 18 | # plt.suptitle("V Mean") 19 | if show_std: 20 | fig_v_std, axes_v_std = plt.subplots(3, 3, sharex='all', sharey='all', 
figsize=(9, 9)) 21 | fig_v_std.canvas.set_window_title('Q Std') 22 | # plt.suptitle("V Std") 23 | 24 | # obss = [(0, 0), (0, 0.75), (0, 1.75), (0, 1.25), (0, 4), (-4, 4), (4, -4), (4, 4), (-4, 4)] 25 | obss = [] 26 | for x in [-3, 0, 3]: 27 | for y in [3, 0, -3]: 28 | obss.append((x, y)) 29 | 30 | def create_eval_function(q, obs, ): 31 | def beta_eval(goals): 32 | # goals = np.array([[ 33 | # *goal 34 | # ]]) 35 | N = len(goals) 36 | observations = np.tile(obs, (N, 1)) 37 | new_obs = np.hstack((observations, goals)) 38 | actions = torch_ify(policy.get_action(new_obs)[0]) 39 | return np_ify(q(torch_ify(new_obs), actions)).flatten() 40 | return beta_eval 41 | 42 | for o in range(9): 43 | i = o % 3 44 | j = o // 3 45 | H = [] 46 | for b in range(5): 47 | q = ensemble_qs[b] 48 | 49 | rng = [-4, 4] 50 | resolution = 30 51 | 52 | obs = obss[o] 53 | 54 | heatmap = vu.make_heat_map(create_eval_function(q, obs, ), rng, rng, resolution=resolution, batch=True) 55 | H.append(heatmap.values) 56 | p, x, y, _ = heatmap 57 | if show_mean: 58 | h1 = vu.HeatMap(np.mean(H, axis=0), x, y, _) 59 | vu.plot_heatmap(h1, ax=axes_v_mean[i, j]) 60 | axes_v_mean[i, j].set_title("pos " + str(obs)) 61 | 62 | if show_std: 63 | h2 = vu.HeatMap(np.std(H, axis=0), x, y, _) 64 | vu.plot_heatmap(h2, ax=axes_v_std[i, j]) 65 | axes_v_std[i, j].set_title("pos " + str(obs)) 66 | 67 | # axes_v_mean[i, j].plot(range(100), range(100)) 68 | # axes_v_std[i, j].plot(range(100), range(100)) 69 | 70 | plt.draw() 71 | plt.pause(0.01) 72 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/half_cheetah_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .half_cheetah import HalfCheetahEnv 4 | 5 | 6 | class HalfCheetahDirEnv(HalfCheetahEnv): 7 | """Half-cheetah environment with target direction, as described in [1]. The 8 | code is adapted from 9 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand_direc.py 10 | 11 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each 12 | time step a reward composed of a control cost and a reward equal to its 13 | velocity in the target direction. The tasks are generated by sampling the 14 | target directions from a Bernoulli distribution on {-1, 1} with parameter 15 | 0.5 (-1: backward, +1: forward). 
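Concretely, the per-step reward computed in step() below is

        r_t = direction * (x_{t+1} - x_t) / dt - 0.05 * ||a_t||^2

    where direction in {-1, +1} is the current task and x is the position of the
    root joint along the forward axis.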
16 | 17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 19 | (https://arxiv.org/abs/1703.03400) 20 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 21 | model-based control", 2012 22 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 23 | """ 24 | def __init__(self, task={}, n_tasks=2, randomize_tasks=False): 25 | directions = [-1, 1] 26 | self.tasks = [{'direction': direction} for direction in directions] 27 | self._task = task 28 | self._goal_dir = task.get('direction', 1) 29 | self._goal = self._goal_dir 30 | super(HalfCheetahDirEnv, self).__init__() 31 | 32 | def step(self, action): 33 | xposbefore = self.sim.data.qpos[0] 34 | self.do_simulation(action, self.frame_skip) 35 | xposafter = self.sim.data.qpos[0] 36 | 37 | forward_vel = (xposafter - xposbefore) / self.dt 38 | forward_reward = self._goal_dir * forward_vel 39 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 40 | 41 | observation = self._get_obs() 42 | reward = forward_reward - ctrl_cost 43 | done = False 44 | infos = dict(reward_forward=forward_reward, 45 | reward_ctrl=-ctrl_cost, task=self._task) 46 | return (observation, reward, done, infos) 47 | 48 | def sample_tasks(self, num_tasks): 49 | directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1 50 | tasks = [{'direction': direction} for direction in directions] 51 | return tasks 52 | 53 | def get_all_task_idx(self): 54 | return list(range(len(self.tasks))) 55 | 56 | def reset_task(self, idx): 57 | self._task = self.tasks[idx] 58 | self._goal_dir = self._task['direction'] 59 | self._goal = self._goal_dir 60 | self.reset() 61 | -------------------------------------------------------------------------------- /rlkit/torch/networks/two_headed_mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from rlkit.pythonplusplus import identity 5 | from rlkit.torch import pytorch_util as ptu 6 | from rlkit.torch.core import PyTorchModule 7 | from rlkit.torch.networks.experimental import LayerNorm 8 | 9 | 10 | class TwoHeadMlp(PyTorchModule): 11 | def __init__( 12 | self, 13 | hidden_sizes, 14 | first_head_size, 15 | second_head_size, 16 | input_size, 17 | init_w=3e-3, 18 | hidden_activation=F.relu, 19 | output_activation=identity, 20 | hidden_init=ptu.fanin_init, 21 | b_init_value=0., 22 | layer_norm=False, 23 | layer_norm_kwargs=None, 24 | ): 25 | super().__init__() 26 | 27 | if layer_norm_kwargs is None: 28 | layer_norm_kwargs = dict() 29 | 30 | self.input_size = input_size 31 | self.first_head_size = first_head_size 32 | self.second_head_size = second_head_size 33 | self.hidden_activation = hidden_activation 34 | self.output_activation = output_activation 35 | self.layer_norm = layer_norm 36 | self.fcs = [] 37 | self.layer_norms = [] 38 | in_size = input_size 39 | 40 | for i, next_size in enumerate(hidden_sizes): 41 | fc = nn.Linear(in_size, next_size) 42 | in_size = next_size 43 | hidden_init(fc.weight) 44 | fc.bias.data.fill_(b_init_value) 45 | self.__setattr__("fc{}".format(i), fc) 46 | self.fcs.append(fc) 47 | 48 | if self.layer_norm: 49 | ln = LayerNorm(next_size) 50 | self.__setattr__("layer_norm{}".format(i), ln) 51 | self.layer_norms.append(ln) 52 | 53 | self.first_head = nn.Linear(in_size, self.first_head_size) 54 | self.first_head.weight.data.uniform_(-init_w, init_w) 55 | 56 | self.second_head = 
nn.Linear(in_size, self.second_head_size) 57 | self.second_head.weight.data.uniform_(-init_w, init_w) 58 | 59 | def forward(self, input, return_preactivations=False): 60 | h = input 61 | for i, fc in enumerate(self.fcs): 62 | h = fc(h) 63 | if self.layer_norm and i < len(self.fcs) - 1: 64 | h = self.layer_norms[i](h) 65 | h = self.hidden_activation(h) 66 | preactivation = self.first_head(h) 67 | first_output = self.output_activation(preactivation) 68 | preactivation = self.second_head(h) 69 | second_output = self.output_activation(preactivation) 70 | 71 | return first_output, second_output 72 | -------------------------------------------------------------------------------- /rlkit/data_management/split_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from rlkit.data_management.replay_buffer import ReplayBuffer 4 | 5 | 6 | class SplitReplayBuffer(ReplayBuffer): 7 | """ 8 | Split the data into a training and validation set. 9 | """ 10 | def __init__( 11 | self, 12 | train_replay_buffer: ReplayBuffer, 13 | validation_replay_buffer: ReplayBuffer, 14 | fraction_paths_in_train, 15 | ): 16 | self.train_replay_buffer = train_replay_buffer 17 | self.validation_replay_buffer = validation_replay_buffer 18 | self.fraction_paths_in_train = fraction_paths_in_train 19 | self.replay_buffer = self.train_replay_buffer 20 | 21 | def add_sample(self, *args, **kwargs): 22 | self.replay_buffer.add_sample(*args, **kwargs) 23 | 24 | def add_path(self, path): 25 | self.replay_buffer.add_path(path) 26 | self._randomly_set_replay_buffer() 27 | 28 | def num_steps_can_sample(self): 29 | return min( 30 | self.train_replay_buffer.num_steps_can_sample(), 31 | self.validation_replay_buffer.num_steps_can_sample(), 32 | ) 33 | 34 | def terminate_episode(self, *args, **kwargs): 35 | self.replay_buffer.terminate_episode(*args, **kwargs) 36 | self._randomly_set_replay_buffer() 37 | 38 | def _randomly_set_replay_buffer(self): 39 | if random.random() <= self.fraction_paths_in_train: 40 | self.replay_buffer = self.train_replay_buffer 41 | else: 42 | self.replay_buffer = self.validation_replay_buffer 43 | 44 | def get_replay_buffer(self, training=True): 45 | if training: 46 | return self.train_replay_buffer 47 | else: 48 | return self.validation_replay_buffer 49 | 50 | def random_batch(self, batch_size): 51 | return self.train_replay_buffer.random_batch(batch_size) 52 | 53 | def __getattr__(self, attrname): 54 | return getattr(self.replay_buffer, attrname) 55 | 56 | def __getstate__(self): 57 | # Do not save self.replay_buffer since it's a duplicate and seems to 58 | # cause joblib recursion issues. 
59 | return dict( 60 | train_replay_buffer=self.train_replay_buffer, 61 | validation_replay_buffer=self.validation_replay_buffer, 62 | fraction_paths_in_train=self.fraction_paths_in_train, 63 | ) 64 | 65 | def __setstate__(self, d): 66 | self.train_replay_buffer = d['train_replay_buffer'] 67 | self.validation_replay_buffer = d['validation_replay_buffer'] 68 | self.fraction_paths_in_train = d['fraction_paths_in_train'] 69 | self.replay_buffer = self.train_replay_buffer 70 | -------------------------------------------------------------------------------- /rlkit/visualization/plotter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from rlkit.torch.core import eval_np 5 | 6 | 7 | class QFPolicyPlotter(object): 8 | def __init__(self, qf, policy, obs_lst, default_action, n_samples): 9 | self._qf = qf 10 | self._policy = policy 11 | self._obs_lst = obs_lst 12 | self._default_action = default_action 13 | self._n_samples = n_samples 14 | 15 | self._var_inds = np.where(np.isnan(default_action))[0] 16 | assert len(self._var_inds) == 2 17 | 18 | n_plots = len(obs_lst) 19 | 20 | x_size = 5 * n_plots 21 | y_size = 5 22 | 23 | fig = plt.figure(figsize=(x_size, y_size)) 24 | self._ax_lst = [] 25 | for i in range(n_plots): 26 | ax = fig.add_subplot(100 + n_plots * 10 + i + 1) 27 | ax.set_xlim((-1, 1)) 28 | ax.set_ylim((-1, 1)) 29 | ax.grid(True) 30 | self._ax_lst.append(ax) 31 | 32 | self._line_objects = list() 33 | 34 | def draw(self): 35 | # noinspection PyArgumentList 36 | [h.remove() for h in self._line_objects] 37 | self._line_objects = list() 38 | 39 | self._plot_level_curves() 40 | self._plot_action_samples() 41 | 42 | plt.draw() 43 | plt.pause(0.001) 44 | 45 | def _plot_level_curves(self): 46 | # Create mesh grid. 47 | xs = np.linspace(-1, 1, 50) 48 | ys = np.linspace(-1, 1, 50) 49 | xgrid, ygrid = np.meshgrid(xs, ys) 50 | N = len(xs)*len(ys) 51 | 52 | # Copy default values along the first axis and replace nans with 53 | # the mesh grid points. 54 | actions = np.tile(self._default_action, (N, 1)) 55 | actions[:, self._var_inds[0]] = xgrid.ravel() 56 | actions[:, self._var_inds[1]] = ygrid.ravel() 57 | 58 | for ax, obs in zip(self._ax_lst, self._obs_lst): 59 | repeated_obs = np.repeat( 60 | obs[None], 61 | actions.shape[0], 62 | axis=0, 63 | ) 64 | qs = eval_np(self._qf, repeated_obs, actions) 65 | qs = qs.reshape(xgrid.shape) 66 | 67 | cs = ax.contour(xgrid, ygrid, qs, 20) 68 | self._line_objects += cs.collections 69 | self._line_objects += ax.clabel( 70 | cs, inline=1, fontsize=10, fmt='%.2f') 71 | 72 | def _plot_action_samples(self): 73 | for ax, obs in zip(self._ax_lst, self._obs_lst): 74 | actions = self._policy.get_actions( 75 | np.ones((self._n_samples, 1)) * obs[None, :]) 76 | 77 | x, y = actions[:, 0], actions[:, 1] 78 | self._line_objects += ax.plot(x, y, 'b*') 79 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rlkit.torch.pytorch_util as ptu 3 | import numpy as np 4 | 5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer 6 | 7 | 8 | class TorchNormalizer(Normalizer): 9 | """ 10 | Update with np array, but de/normalize pytorch Tensors. 
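Both directions are affine maps using the statistics accumulated on the numpy side:
    normalize computes

        clamp((v - mean) / std, -clip_range, clip_range)

    and denormalize inverts it (without clipping) as mean + v * std.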
11 | """ 12 | def normalize(self, v, clip_range=None): 13 | if not self.synchronized: 14 | self.synchronize() 15 | if clip_range is None: 16 | clip_range = self.default_clip_range 17 | mean = ptu.np_to_var(self.mean, requires_grad=False) 18 | std = ptu.np_to_var(self.std, requires_grad=False) 19 | if v.dim() == 2: 20 | # Unsqueeze along the batch use automatic broadcasting 21 | mean = mean.unsqueeze(0) 22 | std = std.unsqueeze(0) 23 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 24 | 25 | def denormalize(self, v): 26 | if not self.synchronized: 27 | self.synchronize() 28 | mean = ptu.np_to_var(self.mean, requires_grad=False) 29 | std = ptu.np_to_var(self.std, requires_grad=False) 30 | if v.dim() == 2: 31 | mean = mean.unsqueeze(0) 32 | std = std.unsqueeze(0) 33 | return mean + v * std 34 | 35 | 36 | class TorchFixedNormalizer(FixedNormalizer): 37 | def normalize(self, v, clip_range=None): 38 | if clip_range is None: 39 | clip_range = self.default_clip_range 40 | mean = ptu.np_to_var(self.mean, requires_grad=False) 41 | std = ptu.np_to_var(self.std, requires_grad=False) 42 | if v.dim() == 2: 43 | # Unsqueeze along the batch use automatic broadcasting 44 | mean = mean.unsqueeze(0) 45 | std = std.unsqueeze(0) 46 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 47 | 48 | def normalize_scale(self, v): 49 | """ 50 | Only normalize the scale. Do not subtract the mean. 51 | """ 52 | std = ptu.np_to_var(self.std, requires_grad=False) 53 | if v.dim() == 2: 54 | std = std.unsqueeze(0) 55 | return v / std 56 | 57 | def denormalize(self, v): 58 | mean = ptu.np_to_var(self.mean, requires_grad=False) 59 | std = ptu.np_to_var(self.std, requires_grad=False) 60 | if v.dim() == 2: 61 | mean = mean.unsqueeze(0) 62 | std = std.unsqueeze(0) 63 | return mean + v * std 64 | 65 | def denormalize_scale(self, v): 66 | """ 67 | Only denormalize the scale. Do not add the mean. 68 | """ 69 | std = ptu.np_to_var(self.std, requires_grad=False) 70 | if v.dim() == 2: 71 | std = std.unsqueeze(0) 72 | return v * std 73 | -------------------------------------------------------------------------------- /rlkit/data_management/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayBuffer(object, metaclass=abc.ABCMeta): 5 | """ 6 | A class used to save and replay data. 7 | """ 8 | 9 | @abc.abstractmethod 10 | def add_sample(self, observation, action, reward, next_observation, 11 | terminal, **kwargs): 12 | """ 13 | Add a transition tuple. 14 | """ 15 | pass 16 | 17 | @abc.abstractmethod 18 | def terminate_episode(self): 19 | """ 20 | Let the replay buffer know that the episode has terminated in case some 21 | special book-keeping has to happen. 22 | :return: 23 | """ 24 | pass 25 | 26 | @abc.abstractmethod 27 | def num_steps_can_sample(self, **kwargs): 28 | """ 29 | :return: # of unique items that can be sampled. 30 | """ 31 | pass 32 | 33 | def add_path(self, path): 34 | """ 35 | Add a path to the replay buffer. 36 | 37 | This default implementation naively goes through every step, but you 38 | may want to optimize this. 39 | 40 | NOTE: You should NOT call "terminate_episode" after calling add_path. 41 | It's assumed that this function handles the episode termination. 
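The default implementation below expects path to provide equal-length sequences
        under the keys "observations", "actions", "rewards", "next_observations",
        "terminals", "agent_infos", and "env_infos", and feeds them element-wise
        into add_sample.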
42 | 43 | :param path: Dict like one output by rlkit.samplers.util.rollout 44 | """ 45 | for i, ( 46 | obs, 47 | action, 48 | reward, 49 | next_obs, 50 | terminal, 51 | agent_info, 52 | env_info 53 | ) in enumerate(zip( 54 | path["observations"], 55 | path["actions"], 56 | path["rewards"], 57 | path["next_observations"], 58 | path["terminals"], 59 | path["agent_infos"], 60 | path["env_infos"], 61 | )): 62 | self.add_sample( 63 | observation=obs, 64 | action=action, 65 | reward=reward, 66 | next_observation=next_obs, 67 | terminal=terminal, 68 | agent_info=agent_info, 69 | env_info=env_info, 70 | ) 71 | self.terminate_episode() 72 | 73 | def add_paths(self, paths): 74 | for path in paths: 75 | self.add_path(path) 76 | 77 | @abc.abstractmethod 78 | def random_batch(self, batch_size): 79 | """ 80 | Return a batch of size `batch_size`. 81 | :param batch_size: 82 | :return: 83 | """ 84 | pass 85 | 86 | def get_diagnostics(self): 87 | return {} 88 | 89 | def get_snapshot(self): 90 | return {} 91 | 92 | def end_epoch(self, epoch): 93 | return 94 | -------------------------------------------------------------------------------- /rlkit/torch/networks/experimental.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains some self-contained modules. May depend on pytorch_util. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from rlkit.torch import pytorch_util as ptu 7 | 8 | 9 | class OuterProductLinear(nn.Module): 10 | def __init__(self, in_features1, in_features2, out_features, bias=True): 11 | super().__init__() 12 | self.fc = nn.Linear( 13 | (in_features1 + 1) * (in_features2 + 1), 14 | out_features, 15 | bias=bias, 16 | ) 17 | 18 | def forward(self, in1, in2): 19 | out_product_flat = ptu.double_moments(in1, in2) 20 | return self.fc(out_product_flat) 21 | 22 | 23 | class SelfOuterProductLinear(OuterProductLinear): 24 | def __init__(self, in_features, out_features, bias=True): 25 | super().__init__(in_features, in_features, out_features, bias=bias) 26 | 27 | def forward(self, input): 28 | return super().forward(input, input) 29 | 30 | 31 | class BatchSquareDiagonal(nn.Module): 32 | """ 33 | Compute x^T diag(`diag_values`) x 34 | """ 35 | def __init__(self, vector_size): 36 | super().__init__() 37 | self.vector_size = vector_size 38 | self.diag_mask = ptu.Variable(torch.diag(torch.ones(vector_size)), 39 | requires_grad=False) 40 | 41 | def forward(self, vector, diag_values): 42 | M = ptu.batch_diag(diag_values=diag_values, diag_mask=self.diag_mask) 43 | return ptu.batch_square_vector(vector=vector, M=M) 44 | 45 | 46 | class HuberLoss(nn.Module): 47 | def __init__(self, delta=1): 48 | super().__init__() 49 | self.huber_loss_delta1 = nn.SmoothL1Loss() 50 | self.delta = delta 51 | 52 | def forward(self, x, x_hat): 53 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta) 54 | return loss * self.delta * self.delta 55 | 56 | 57 | class LayerNorm(nn.Module): 58 | """ 59 | Simple 1D LayerNorm.
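Normalizes each sample over its last dimension: output = (x - mean) / (std + eps), optionally multiplied by a learned `scale_param` and shifted by a learned `center_param` (matching `forward` below).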
60 | """ 61 | def __init__(self, features, center=True, scale=False, eps=1e-6): 62 | super().__init__() 63 | self.center = center 64 | self.scale = scale 65 | self.eps = eps 66 | if self.scale: 67 | self.scale_param = nn.Parameter(torch.ones(features)) 68 | else: 69 | self.scale_param = None 70 | if self.center: 71 | self.center_param = nn.Parameter(torch.zeros(features)) 72 | else: 73 | self.center_param = None 74 | 75 | def forward(self, x): 76 | mean = x.mean(-1, keepdim=True) 77 | std = x.std(-1, keepdim=True) 78 | output = (x - mean) / (std + self.eps) 79 | if self.scale: 80 | output = output * self.scale_param 81 | if self.center: 82 | output = output + self.center_param 83 | return output 84 | -------------------------------------------------------------------------------- /rlkit/envs/images/insert_image_env.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import gym 4 | import numpy as np 5 | from gym.spaces import Box, Dict 6 | 7 | from rlkit.envs.images.env_renderer import EnvRenderer 8 | 9 | # getting Product 10 | 11 | 12 | def prod(val): 13 | res = 1 14 | for ele in val: 15 | res *= ele 16 | return res 17 | 18 | 19 | class InsertImagesEnv(gym.Wrapper): 20 | """ 21 | Add an image to the observation. Usage: 22 | 23 | ``` 24 | obs = env.reset() 25 | print(obs.keys()) # ['observations'] 26 | 27 | new_env = InsertImageEnv( 28 | env, 29 | { 30 | 'image_observation': renderer_one, 31 | 'debugging_img': renderer_two, 32 | }, 33 | ) 34 | obs = new_env.reset() 35 | print(obs.keys()) # ['observations', 'image_observation', 'debugging_img'] 36 | ``` 37 | """ 38 | 39 | def __init__( 40 | self, 41 | wrapped_env: gym.Env, 42 | renderers: typing.Dict[str, EnvRenderer], 43 | ): 44 | super().__init__(wrapped_env) 45 | spaces = self.env.observation_space.spaces.copy() 46 | for image_key, renderer in renderers.items(): 47 | if renderer.image_is_normalized: 48 | img_space = Box( 49 | 0, 1, (prod(renderer.image_shape), ), dtype=np.float32) 50 | else: 51 | img_space = Box( 52 | 0, 255, (prod(renderer.image_shape), ), dtype=np.uint8) 53 | spaces[image_key] = img_space 54 | self.renderers = renderers 55 | self.observation_space = Dict(spaces) 56 | self.action_space = self.env.action_space 57 | 58 | def step(self, action): 59 | obs, reward, done, info = self.env.step(action) 60 | self._update_obs(obs) 61 | return obs, reward, done, info 62 | 63 | def reset(self): 64 | obs = self.env.reset() 65 | self._update_obs(obs) 66 | return obs 67 | 68 | def get_observation(self): 69 | obs = self.env.get_observation() 70 | self._update_obs(obs) 71 | return obs 72 | 73 | def _update_obs(self, obs): 74 | for image_key, renderer in self.renderers.items(): 75 | obs[image_key] = renderer(self.env) 76 | 77 | 78 | class InsertImageEnv(InsertImagesEnv): 79 | """ 80 | Add an image to the observation. 
Usage: 81 | ``` 82 | obs = env.reset() 83 | print(obs.keys()) # ['observations'] 84 | 85 | new_env = InsertImageEnv(env, renderer, image_key='pretty_picture') 86 | obs = new_env.reset() 87 | print(obs.keys()) # ['observations', 'pretty_picture'] 88 | ``` 89 | """ 90 | 91 | def __init__( 92 | self, 93 | wrapped_env: gym.Env, 94 | renderer: EnvRenderer, 95 | image_key='image_observation', 96 | ): 97 | super().__init__(wrapped_env, {image_key: renderer}) 98 | -------------------------------------------------------------------------------- /rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv 4 | import os 5 | 6 | class PR2Env(RandomEnv, utils.EzPickle): 7 | 8 | FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets/pr2.xml') 9 | 10 | def __init__(self, log_scale_limit=1.): 11 | self.viewer = None 12 | RandomEnv.__init__(self, log_scale_limit, 'pr2.xml', 4) 13 | utils.EzPickle.__init__(self) 14 | 15 | def _get_obs(self): 16 | return np.concatenate([ 17 | self.model.data.qpos.flat[:7], 18 | self.model.data.qvel.flat[:7], # Do not include the velocity of the target (should be 0). 19 | self.get_tip_position().flat, 20 | self.get_vec_tip_to_goal().flat, 21 | ]) 22 | 23 | def get_tip_position(self): 24 | return self.model.data.site_xpos[0] 25 | 26 | def get_vec_tip_to_goal(self): 27 | tip_position = self.get_tip_position() 28 | goal_position = self.goal 29 | vec_tip_to_goal = goal_position - tip_position 30 | return vec_tip_to_goal 31 | 32 | @property 33 | def goal(self): 34 | return self.model.data.qpos.flat[-3:] 35 | 36 | def _step(self, action): 37 | 38 | self.do_simulation(action, self.frame_skip) 39 | 40 | vec_tip_to_goal = self.get_vec_tip_to_goal() 41 | distance_tip_to_goal = np.linalg.norm(vec_tip_to_goal) 42 | 43 | reward = - distance_tip_to_goal 44 | 45 | state = self.state_vector() 46 | notdone = np.isfinite(state).all() 47 | done = not notdone 48 | 49 | ob = self._get_obs() 50 | 51 | return ob, reward, done, {} 52 | 53 | def reset_model(self): 54 | qpos = self.init_qpos 55 | qvel = self.init_qvel 56 | goal = np.random.uniform((0.2, -0.4, 0.5), (0.5, 0.4, 1.5)) 57 | qpos[-3:] = goal 58 | qpos[:7] += self.np_random.uniform(low=-.005, high=.005, size=7) 59 | qvel[:7] += self.np_random.uniform(low=-.005, high=.005, size=7) 60 | self.set_state(qpos, qvel) 61 | return self._get_obs() 62 | 63 | def viewer_setup(self): 64 | self.viewer.cam.distance = self.model.stat.extent * 2 65 | # self.viewer.cam.lookat[2] += .8 66 | self.viewer.cam.elevation = -50 67 | # self.viewer.cam.lookat[0] = self.model.stat.center[0] 68 | # self.viewer.cam.lookat[1] = self.model.stat.center[1] 69 | # self.viewer.cam.lookat[2] = self.model.stat.center[2] 70 | 71 | 72 | if __name__ == "__main__": 73 | 74 | env = PR2Env() 75 | tasks = env.sample_tasks(40) 76 | while True: 77 | env.reset() 78 | env.set_task(np.random.choice(tasks)) 79 | print(env.model.body_mass) 80 | for _ in range(100): 81 | env.render() 82 | env.step(env.action_space.sample()) 83 | -------------------------------------------------------------------------------- /rlkit/envs/mujoco/pusher.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | from gym.envs.mujoco import PusherEnv as GymPusherEnv 5 | 6 | from rlkit.core.eval_util import 
create_stats_ordered_dict, get_stat_in_paths 7 | from rlkit.core import logger 8 | 9 | 10 | class PusherEnv(GymPusherEnv): 11 | def __init__(self): 12 | self.goal_cylinder_relative_x = 0 13 | self.goal_cylinder_relative_y = 0 14 | super().__init__() 15 | 16 | def reset_model(self): 17 | qpos = self.init_qpos 18 | goal_xy = np.array([ 19 | self.goal_cylinder_relative_x, 20 | self.goal_cylinder_relative_y, 21 | ]) 22 | 23 | while True: 24 | self.cylinder_pos = np.concatenate([ 25 | self.np_random.uniform(low=-0.3, high=0, size=1), 26 | self.np_random.uniform(low=-0.2, high=0.2, size=1)]) 27 | if np.linalg.norm(self.cylinder_pos - goal_xy) > 0.17: 28 | break 29 | 30 | qpos[-4:-2] = self.cylinder_pos 31 | # y-axis comes first in the xml 32 | qpos[-2] = self.goal_cylinder_relative_y 33 | qpos[-1] = self.goal_cylinder_relative_x 34 | qvel = self.init_qvel + self.np_random.uniform(low=-0.005, 35 | high=0.005, 36 | size=self.model.nv) 37 | qvel[-4:] = 0 38 | self.set_state(qpos, qvel) 39 | return self._get_obs() 40 | 41 | def _step(self, a): 42 | arm_to_object = ( 43 | self.get_body_com("tips_arm") - self.get_body_com("object") 44 | ) 45 | object_to_goal = ( 46 | self.get_body_com("object") - self.get_body_com("goal") 47 | ) 48 | arm_to_goal = ( 49 | self.get_body_com("tips_arm") - self.get_body_com("goal") 50 | ) 51 | obs, reward, done, info_dict = super()._step(a) 52 | info_dict['arm to object distance'] = np.linalg.norm(arm_to_object) 53 | info_dict['object to goal distance'] = np.linalg.norm(object_to_goal) 54 | info_dict['arm to goal distance'] = np.linalg.norm(arm_to_goal) 55 | return obs, reward, done, info_dict 56 | 57 | def log_diagnostics(self, paths): 58 | statistics = OrderedDict() 59 | 60 | for stat_name in [ 61 | 'arm to object distance', 62 | 'object to goal distance', 63 | 'arm to goal distance', 64 | ]: 65 | stat = get_stat_in_paths( 66 | paths, 'env_infos', stat_name 67 | ) 68 | statistics.update(create_stats_ordered_dict( 69 | stat_name, stat 70 | )) 71 | 72 | for key, value in statistics.items(): 73 | logger.record_tabular(key, value) 74 | 75 | def _set_goal_xy(self, xy): 76 | # Based on XML 77 | self.goal_cylinder_relative_x = xy[0] - 0.45 78 | self.goal_cylinder_relative_y = xy[1] + 0.05 79 | -------------------------------------------------------------------------------- /rlkit/envs/contextual/task_conditioned.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | 5 | from rlkit.envs.contextual.goal_conditioned import ( 6 | GoalDictDistributionFromMultitaskEnv, 7 | ) 8 | from rlkit.samplers.data_collector.contextual_path_collector import ( 9 | ContextualPathCollector 10 | ) 11 | 12 | from gym.spaces import Box 13 | from rlkit.samplers.rollout_functions import contextual_rollout 14 | 15 | class TaskGoalDictDistributionFromMultitaskEnv( 16 | GoalDictDistributionFromMultitaskEnv): 17 | def __init__( 18 | self, 19 | *args, 20 | task_key='task_id', 21 | task_ids=None, 22 | **kwargs 23 | ): 24 | super().__init__(*args, **kwargs) 25 | self.task_key = task_key 26 | self._spaces[task_key] = Box( 27 | low=np.zeros(1), 28 | high=np.ones(1)) 29 | self.task_ids = np.array(task_ids) 30 | 31 | def sample(self, batch_size: int): 32 | goals = super().sample(batch_size) 33 | idxs = np.random.choice(len(self.task_ids), batch_size) 34 | goals[self.task_key] = self.task_ids[idxs].reshape(-1, 1) 35 | return goals 36 | 37 | class TaskPathCollector(ContextualPathCollector): 38 | def __init__( 39 | self, 
40 | *args, 41 | task_key=None, 42 | max_path_length=100, 43 | task_ids=None, 44 | rotate_freq=0.0, 45 | **kwargs 46 | ): 47 | super().__init__(*args, **kwargs) 48 | self.rotate_freq = rotate_freq 49 | self.rollout_tasks = [] 50 | 51 | def obs_processor(o): 52 | if len(self.rollout_tasks) > 0: 53 | task = self.rollout_tasks[0] 54 | self.rollout_tasks = self.rollout_tasks[1:] 55 | o[task_key] = task 56 | self._env._rollout_context_batch[task_key] = task[None] 57 | 58 | combined_obs = [o[self._observation_key]] 59 | for k in self._context_keys_for_policy: 60 | combined_obs.append(o[k]) 61 | return np.concatenate(combined_obs, axis=0) 62 | 63 | def reset_postprocess_func(): 64 | rotate = (np.random.uniform() < self.rotate_freq) 65 | self.rollout_tasks = [] 66 | if rotate: 67 | num_steps_per_task = max_path_length // len(task_ids) 68 | self.rollout_tasks = np.ones((max_path_length, 1)) * (len(task_ids) - 1) 69 | for (idx, id) in enumerate(task_ids): 70 | start = idx * num_steps_per_task 71 | end = start + num_steps_per_task 72 | self.rollout_tasks[start:end] = id 73 | 74 | self._rollout_fn = partial( 75 | contextual_rollout, 76 | context_keys_for_policy=self._context_keys_for_policy, 77 | observation_key=self._observation_key, 78 | obs_processor=obs_processor, 79 | reset_postprocess_func=reset_postprocess_func, 80 | ) -------------------------------------------------------------------------------- /rlkit/networks/gaussian_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | 5 | import rlkit.torch.pytorch_util as ptu 6 | from rlkit.torch.sac.policies.base import TorchStochasticPolicy 7 | from rlkit.torch.sac.policies.gaussian_policy import ( 8 | LOG_SIG_MAX, 9 | LOG_SIG_MIN, 10 | ) 11 | from rlkit.torch.distributions import MultivariateDiagonalNormal 12 | from rlkit.torch.networks import CNN 13 | 14 | 15 | class GaussianCNNPolicy(CNN, TorchStochasticPolicy): 16 | def __init__( 17 | self, 18 | hidden_sizes, 19 | obs_dim, 20 | action_dim, 21 | std=None, 22 | init_w=1e-3, 23 | min_log_std=None, 24 | max_log_std=None, 25 | std_architecture="shared", 26 | output_activation=None, 27 | **kwargs 28 | ): 29 | super().__init__( 30 | hidden_sizes=hidden_sizes, 31 | output_size=action_dim, 32 | init_w=init_w, 33 | output_activation=output_activation, 34 | **kwargs 35 | ) 36 | self.min_log_std = min_log_std 37 | self.max_log_std = max_log_std 38 | self.log_std = None 39 | self.std = std 40 | self.std_architecture = std_architecture 41 | if std is None: 42 | if self.std_architecture == "shared": 43 | last_hidden_size = obs_dim 44 | if len(hidden_sizes) > 0: 45 | last_hidden_size = hidden_sizes[-1] 46 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim) 47 | self.last_fc_log_std.weight.data.uniform_(-init_w, init_w) 48 | self.last_fc_log_std.bias.data.uniform_(-init_w, init_w) 49 | elif self.std_architecture == "values": 50 | self.log_std_logits = nn.Parameter( 51 | ptu.zeros(action_dim, requires_grad=True)) 52 | else: 53 | raise ValueError(self.std_architecture) 54 | else: 55 | self.log_std = np.log(std) 56 | assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX 57 | 58 | def forward(self, obs): 59 | h = super().forward(obs, return_last_activations=True) 60 | mean = self.last_fc(h) 61 | if self.output_activation is not None: 62 | mean = self.output_activation(mean) 63 | if self.std is None: 64 | if self.std_architecture == "shared": 65 | log_std = torch.sigmoid(self.last_fc_log_std(h)) 66 
| elif self.std_architecture == "values": 67 | log_std = torch.sigmoid(self.log_std_logits) 68 | else: 69 | raise ValueError(self.std_architecture) 70 | log_std = self.min_log_std + log_std * ( 71 | self.max_log_std - self.min_log_std) 72 | std = torch.exp(log_std) 73 | else: 74 | std = torch.from_numpy(np.array([self.std, ])).float().to( 75 | ptu.device) 76 | 77 | return MultivariateDiagonalNormal(mean, std) 78 | -------------------------------------------------------------------------------- /rlkit/envs/dual_encoder_wrapper.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | import rlkit.torch.pytorch_util as ptu 4 | from gym.spaces import Box, Dict 5 | from rlkit.envs.vae_wrappers import VAEWrappedEnv 6 | from rlkit.envs.wrappers import ProxyEnv 7 | 8 | from rlkit.envs.encoder_wrappers import Encoder 9 | 10 | class DualEncoderWrappedEnv(ProxyEnv): 11 | def __init__(self, 12 | wrapped_env, 13 | model: Encoder, 14 | input_model, 15 | step_keys_map=None, 16 | reset_keys_map=None, 17 | conditional_input_model=False, 18 | ): 19 | super().__init__(wrapped_env) 20 | self.model = model 21 | self.input_model = input_model 22 | self.conditional_input_model = conditional_input_model 23 | self.representation_size = self.model.representation_size 24 | self.input_representation_size = self.input_model.representation_size 25 | latent_space = Box( 26 | -10 * np.ones(self.representation_size), 27 | 10 * np.ones(self.representation_size), 28 | dtype=np.float32, 29 | ) 30 | input_latent_space = Box( 31 | -10 * np.ones(self.input_representation_size), 32 | 10 * np.ones(self.input_representation_size), 33 | dtype=np.float32, 34 | ) 35 | 36 | if step_keys_map is None: 37 | step_keys_map = {} 38 | if reset_keys_map is None: 39 | reset_keys_map = {} 40 | self.step_keys_map = step_keys_map 41 | self.reset_keys_map = reset_keys_map 42 | spaces = self.wrapped_env.observation_space.spaces 43 | for value in self.step_keys_map.values(): 44 | spaces[value] = latent_space 45 | for value in self.reset_keys_map.values(): 46 | spaces[value] = latent_space 47 | 48 | spaces['input_latent'] = input_latent_space 49 | self.observation_space = Dict(spaces) 50 | self.reset_obs = {} 51 | 52 | def step(self, action): 53 | self.model.eval() 54 | obs, reward, done, info = self.wrapped_env.step(action) 55 | new_obs = self._update_obs(obs) 56 | return new_obs, reward, done, info 57 | 58 | def _update_obs(self, obs): 59 | self.model.eval() 60 | for key in self.step_keys_map: 61 | value = self.step_keys_map[key] 62 | obs[value] = self.model.encode_one_np(obs[key]) 63 | obs = {**obs, **self.reset_obs} 64 | 65 | if self.conditional_input_model: 66 | obs['input_latent'] = self.input_model.encode_one_np(obs['image_observation'], self._initial_img) 67 | else: 68 | obs['input_latent'] = self.input_model.encode_one_np(obs['image_observation']) 69 | 70 | return obs 71 | 72 | def reset(self): 73 | self.model.eval() 74 | obs = self.wrapped_env.reset() 75 | self._initial_img = obs["image_observation"] 76 | for key in self.reset_keys_map: 77 | value = self.reset_keys_map[key] 78 | self.reset_obs[value] = self.model.encode_one_np(obs[key]) 79 | obs = self._update_obs(obs) 80 | return obs 81 | -------------------------------------------------------------------------------- /rlkit/launchers/doodad_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import NamedTuple 4 | import random 5 | 6 | 
import __main__ as main 7 | import numpy as np 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.core import logger, setup_logger 11 | from rlkit.launchers import config 12 | 13 | from doodad.wrappers.easy_launch import save_doodad_config, sweep_function, DoodadConfig 14 | from rlkit.launchers.launcher_util import set_seed 15 | 16 | 17 | class AutoSetup: 18 | """ 19 | Automatically set up: 20 | 1. the logger 21 | 2. the GPU mode 22 | 3. the seed 23 | :param exp_function: some function that should not depend on `logger_config` 24 | nor `seed`. 25 | :param unpack_variant: do you call exp_function with `**variant`? 26 | :return: function output 27 | """ 28 | def __init__(self, exp_function, unpack_variant=True): 29 | self.exp_function = exp_function 30 | self.unpack_variant = unpack_variant 31 | 32 | def __call__(self, doodad_config: DoodadConfig, variant): 33 | save_doodad_config(doodad_config) 34 | variant_to_save = variant.copy() 35 | variant_to_save['doodad_info'] = doodad_config.extra_launch_info 36 | exp_name = doodad_config.extra_launch_info['exp_name'] 37 | seed = variant.pop('seed', 0) 38 | set_seed(seed) 39 | ptu.set_gpu_mode(doodad_config.use_gpu) 40 | # Reopening the files is necessary because blobfuse only syncs files 41 | # when they're closed. For details, see 42 | # https://github.com/Azure/azure-storage-fuse#if-your-workload-is-not-read-only 43 | reopen_files_on_flush = True 44 | # Might as well always have it on; if you did not want to, you could do: 45 | # mode = doodad_config.extra_launch_info['mode'] 46 | # reopen_files_on_flush = mode == 'azure' 47 | setup_logger( 48 | logger, 49 | exp_name=exp_name, 50 | base_log_dir=None, 51 | log_dir=doodad_config.output_directory, 52 | seed=seed, 53 | variant=variant, 54 | reopen_files_on_flush=reopen_files_on_flush, 55 | ) 56 | variant.pop('logger_config', None) 57 | variant.pop('exp_id', None) 58 | variant.pop('run_id', None) 59 | if self.unpack_variant: 60 | self.exp_function(**variant) 61 | else: 62 | self.exp_function(variant) 63 | 64 | 65 | def run_experiment( 66 | method_call, 67 | params, 68 | default_params, 69 | exp_name='default', 70 | mode='local', 71 | wrap_fn_with_auto_setup=True, 72 | unpack_variant=True, 73 | **kwargs 74 | ): 75 | if wrap_fn_with_auto_setup: 76 | method_call = AutoSetup( 77 | method_call, 78 | unpack_variant=unpack_variant, 79 | ) 80 | sweep_function( 81 | method_call, 82 | params, 83 | default_params=default_params, 84 | mode=mode, 85 | log_path=exp_name, 86 | add_time_to_run_id='in_front', 87 | extra_launch_info={'exp_name': exp_name, 'mode': mode}, 88 | **kwargs 89 | ) 90 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/joint_path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict 3 | 4 | from rlkit.core.logging import add_prefix 5 | from rlkit.samplers.data_collector import PathCollector 6 | 7 | 8 | class JointPathCollector(PathCollector): 9 | EVENLY = 'evenly' 10 | 11 | def __init__( 12 | self, 13 | path_collectors: Dict[str, PathCollector], 14 | divide_num_steps_strategy=EVENLY, 15 | ): 16 | """ 17 | :param path_collectors: Dictionary of path collectors 18 | :param divide_num_steps_strategy: How the steps are divided among the 19 | path collectors.
20 | Valid values: 21 | - 'evenly': divide `num_steps` evenly among the collectors 22 | """ 23 | sorted_collectors = OrderedDict() 24 | # Sort the path collectors to have a canonical ordering 25 | for k in sorted(path_collectors): 26 | sorted_collectors[k] = path_collectors[k] 27 | self.path_collectors = sorted_collectors 28 | self.divide_num_steps_strategy = divide_num_steps_strategy 29 | if divide_num_steps_strategy not in {self.EVENLY}: 30 | raise ValueError(divide_num_steps_strategy) 31 | 32 | def collect_new_paths(self, max_path_length, num_steps, 33 | discard_incomplete_paths, 34 | **kwargs): 35 | paths = [] 36 | if self.divide_num_steps_strategy == self.EVENLY: 37 | num_steps_per_collector = num_steps // len(self.path_collectors) 38 | else: 39 | raise ValueError(self.divide_num_steps_strategy) 40 | for name, collector in self.path_collectors.items(): 41 | paths += collector.collect_new_paths( 42 | max_path_length=max_path_length, 43 | num_steps=num_steps_per_collector, 44 | discard_incomplete_paths=discard_incomplete_paths, 45 | **kwargs 46 | ) 47 | return paths 48 | 49 | def end_epoch(self, epoch): 50 | for collector in self.path_collectors.values(): 51 | collector.end_epoch(epoch) 52 | 53 | def get_diagnostics(self): 54 | diagnostics = OrderedDict() 55 | num_steps = 0 56 | num_paths = 0 57 | for name, collector in self.path_collectors.items(): 58 | stats = collector.get_diagnostics() 59 | num_steps += stats['num steps total'] 60 | num_paths += stats['num paths total'] 61 | diagnostics.update( 62 | add_prefix(stats, name, divider='/'), 63 | ) 64 | diagnostics['num steps total'] = num_steps 65 | diagnostics['num paths total'] = num_paths 66 | return diagnostics 67 | 68 | def get_snapshot(self): 69 | snapshot = {} 70 | for name, collector in self.path_collectors.items(): 71 | snapshot.update( 72 | add_prefix(collector.get_snapshot(), name, divider='/'), 73 | ) 74 | return snapshot 75 | 76 | def get_epoch_paths(self): 77 | paths = {} 78 | for name, collector in self.path_collectors.items(): 79 | paths[name] = collector.get_epoch_paths() 80 | return paths 81 | -------------------------------------------------------------------------------- /rlkit/util/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | try: 6 | from StringIO import StringIO # Python 2.7 7 | except ImportError: 8 | from io import BytesIO # Python 3.x 9 | 10 | 11 | class TensorboardLogger(object): 12 | def __init__(self, log_dir): 13 | """Create a summary writer logging to log_dir.""" 14 | self.writer = tf.summary.FileWriter(log_dir) 15 | 16 | def scalar_summary(self, tag, value, step): 17 | """Log a scalar variable.""" 18 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 19 | self.writer.add_summary(summary, step) 20 | 21 | def image_summary(self, tag, images, step): 22 | """Log a list of images.""" 23 | 24 | img_summaries = [] 25 | for i, img in enumerate(images): 26 | # Write the image to an in-memory buffer 27 | try: 28 | s = StringIO() 29 | except NameError: # StringIO is only imported on Python 2 30 | s = BytesIO() 31 | scipy.misc.toimage(img).save(s, format="png") 32 | 33 | # Create an Image object 34 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 35 | height=img.shape[0], 36 | width=img.shape[1]) 37 | # Create a Summary value 38 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 39 |
40 | # Create and write Summary 41 | summary = tf.Summary(value=img_summaries) 42 | self.writer.add_summary(summary, step) 43 | 44 | 45 | def histo_summary(self, tag, values, step, bins=1000): 46 | """Log a histogram of the tensor of values.""" 47 | 48 | # Create a histogram using numpy 49 | # Original code: 50 | # counts, bin_edges = np.histogram(values, bins=bins) 51 | # 52 | # Now I use the following to get more detailed data 53 | # https://stackoverflow.com/questions/39418380/histogram-with-equal-number-of-points-in-each-bin 54 | counts, bin_edges = np.histogram( 55 | values, 56 | bins=histedges_equalN(values.flatten(), bins) 57 | ) 58 | 59 | # Fill the fields of the histogram proto 60 | hist = tf.HistogramProto() 61 | hist.min = float(np.min(values)) 62 | hist.max = float(np.max(values)) 63 | hist.num = int(np.prod(values.shape)) 64 | hist.sum = float(np.sum(values)) 65 | hist.sum_squares = float(np.sum(values**2)) 66 | 67 | # Drop the start of the first bin 68 | bin_edges = bin_edges[1:] 69 | 70 | # Add bin edges and counts 71 | for edge in bin_edges: 72 | hist.bucket_limit.append(edge) 73 | for c in counts: 74 | hist.bucket.append(c) 75 | 76 | # Create and write Summary 77 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 78 | self.writer.add_summary(summary, step) 79 | self.writer.flush() 80 | 81 | 82 | def histedges_equalN(x, nbin): 83 | npt = len(x) 84 | return np.interp(np.linspace(0, npt, nbin + 1), 85 | np.arange(npt), 86 | np.sort(x)) 87 | --------------------------------------------------------------------------------