├── rlkit ├── __init__.py ├── torch │ ├── __init__.py │ ├── dsac │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── risk.py │ │ ├── networks.py │ │ └── policies.py │ ├── sac │ │ ├── __init__.py │ │ ├── policies.py │ │ └── sac.py │ ├── td3 │ │ ├── __init__.py │ │ ├── networks.py │ │ └── td3.py │ ├── td4 │ │ ├── __init__.py │ │ └── utils.py │ ├── data_management │ │ ├── __init__.py │ │ └── normalizer.py │ ├── core.py │ ├── torch_rl_algorithm.py │ ├── data.py │ ├── distributions.py │ ├── networks.py │ └── pytorch_util.py ├── policies │ ├── __init__.py │ ├── simple.py │ ├── base.py │ └── argmax.py ├── samplers │ ├── __init__.py │ ├── data_collector │ │ ├── __init__.py │ │ ├── base.py │ │ ├── vec_path_collector.py │ │ ├── vec_step_collector.py │ │ ├── path_collector.py │ │ └── step_collector.py │ ├── util.py │ └── rollout_functions.py ├── data_management │ ├── __init__.py │ ├── path_builder.py │ ├── env_replay_buffer.py │ ├── replay_buffer.py │ ├── simple_replay_buffer.py │ ├── normalizer.py │ └── torch_replay_buffer.py ├── exploration_strategies │ ├── __init__.py │ ├── epsilon_greedy.py │ ├── gaussian_strategy.py │ ├── gaussian_and_epsilon_strategy.py │ ├── base.py │ └── ou_strategy.py ├── core │ ├── __init__.py │ ├── trainer.py │ ├── serializable.py │ ├── batch_rl_algorithm.py │ ├── online_rl_algorithm.py │ ├── eval_util.py │ ├── rl_algorithm.py │ └── vec_online_rl_algorithm.py ├── launchers │ ├── __init__.py │ └── conf.py ├── envs │ ├── __init__.py │ ├── env_utils.py │ ├── wrappers.py │ └── vecenv.py └── util │ ├── ml_util.py │ ├── io.py │ ├── video.py │ └── hyperparameter.py ├── .gitignore ├── utils.py ├── configs ├── td3-normal │ ├── ant.yaml │ ├── hopper.yaml │ ├── humanoid.yaml │ ├── walker2d.yaml │ ├── halfcheetah.yaml │ └── bipedalwalkerhardcore.yaml ├── td4-normal-iqn-neutral │ ├── ant.yaml │ ├── hopper.yaml │ ├── humanoid.yaml │ ├── walker2d.yaml │ ├── halfcheetah.yaml │ └── bipedalwalkerhardcore.yaml ├── sac-normal │ ├── ant.yaml │ ├── hopper.yaml │ ├── humanoid.yaml │ ├── walker2d.yaml │ ├── halfcheetah.yaml │ └── bipedalwalkerhardcore.yaml └── dsac-normal-iqn-neutral │ ├── ant.yaml │ ├── hopper.yaml │ ├── humanoid.yaml │ ├── walker2d.yaml │ ├── halfcheetah.yaml │ └── bipedalwalkerhardcore.yaml ├── LICENSE ├── README.md ├── sac.py ├── td3.py ├── dsac.py └── td4.py /rlkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/policies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/dsac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/sac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/td3/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/td4/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */*/mjkey.txt 2 | **/.DS_STORE 3 | **/*.pyc 4 | **/*.swp 5 | rlkit/launchers/config.py 6 | rlkit/launchers/conf_private.py 7 | MANIFEST 8 | *.egg-info 9 | \.idea/ 10 | -------------------------------------------------------------------------------- /rlkit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout rlkit. 3 | """ 4 | from rlkit.core.logging import logger 5 | 6 | __all__ = ['logger'] 7 | 8 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | def load_config(path): 5 | try: 6 | with open(path, 'r', encoding="utf-8") as f: 7 | config = yaml.safe_load(f) 8 | except IOError: 9 | print(f"No Such File: {path}") 10 | raise 11 | return config 12 | -------------------------------------------------------------------------------- /rlkit/core/trainer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Trainer(object, metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def train(self, data): 7 | pass 8 | 9 | def end_epoch(self, epoch): 10 | pass 11 | 12 | def get_snapshot(self): 13 | return {} 14 | 15 | def get_diagnostics(self): 16 | return {} 17 | -------------------------------------------------------------------------------- /rlkit/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains 'launchers', which are self-contained functions that take 3 | one dictionary and run a full experiment. The dictionary configures the 4 | experiment. 5 | 6 | It is important that the functions are completely self-contained (i.e. they 7 | import their own modules) so that they can be serialized. 8 | """ 9 | -------------------------------------------------------------------------------- /rlkit/policies/simple.py: -------------------------------------------------------------------------------- 1 | from rlkit.policies.base import Policy 2 | 3 | 4 | class RandomPolicy(Policy): 5 | """ 6 | Policy that samples a random action from the action space.
7 | """ 8 | 9 | def __init__(self, action_space): 10 | self.action_space = action_space 11 | 12 | def get_action(self, obs): 13 | return self.action_space.sample(), {} 14 | -------------------------------------------------------------------------------- /rlkit/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.wrappers import TimeLimit 3 | 4 | from rlkit.envs.wrappers import CustomInfoEnv, NormalizedBoxEnv 5 | 6 | 7 | def make_env(name): 8 | env = gym.make(name) 9 | # Remove TimeLimit Wrapper 10 | if isinstance(env, TimeLimit): 11 | env = env.unwrapped 12 | env = CustomInfoEnv(env) 13 | env = NormalizedBoxEnv(env) 14 | return env 15 | -------------------------------------------------------------------------------- /configs/td3-normal/ant.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Ant-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | discount: 0.99 16 | version: normal 17 | -------------------------------------------------------------------------------- /configs/td3-normal/hopper.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Hopper-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | discount: 0.99 16 | version: normal 17 | -------------------------------------------------------------------------------- /configs/td3-normal/humanoid.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Humanoid-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | discount: 0.99 16 | version: normal 17 | -------------------------------------------------------------------------------- /configs/td3-normal/walker2d.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Walker2d-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | discount: 0.99 16 | version: normal 17 | -------------------------------------------------------------------------------- /configs/td3-normal/halfcheetah.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 
10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: HalfCheetah-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | discount: 0.99 16 | version: normal 17 | -------------------------------------------------------------------------------- /configs/td3-normal/bipedalwalkerhardcore.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 2000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 3000 8 | num_trains_per_train_loop: 3000 9 | env: BipedalWalkerHardcore-v3 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | discount: 0.99 16 | version: normal 17 | -------------------------------------------------------------------------------- /rlkit/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(object, metaclass=abc.ABCMeta): 5 | """ 6 | General policy interface. 7 | """ 8 | @abc.abstractmethod 9 | def get_action(self, observation): 10 | """ 11 | 12 | :param observation: 13 | :return: action, debug_dictionary 14 | """ 15 | pass 16 | 17 | def reset(self): 18 | pass 19 | 20 | 21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 22 | def set_num_steps_total(self, t): 23 | pass 24 | -------------------------------------------------------------------------------- /configs/td4-normal-iqn-neutral/ant.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Ant-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | tau_type: iqn 19 | zf_lr: 0.0003 20 | version: normal-iqn-neutral 21 | -------------------------------------------------------------------------------- /configs/td4-normal-iqn-neutral/hopper.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Hopper-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | tau_type: iqn 19 | zf_lr: 0.0003 20 | version: normal-iqn-neutral 21 | -------------------------------------------------------------------------------- /configs/td4-normal-iqn-neutral/humanoid.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Humanoid-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | 
replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | tau_type: iqn 19 | zf_lr: 0.0003 20 | version: normal-iqn-neutral 21 | -------------------------------------------------------------------------------- /configs/td4-normal-iqn-neutral/walker2d.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Walker2d-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | tau_type: iqn 19 | zf_lr: 0.0003 20 | version: normal-iqn-neutral 21 | -------------------------------------------------------------------------------- /configs/td4-normal-iqn-neutral/halfcheetah.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: HalfCheetah-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | tau_type: iqn 19 | zf_lr: 0.0003 20 | version: normal-iqn-neutral 21 | -------------------------------------------------------------------------------- /configs/td4-normal-iqn-neutral/bipedalwalkerhardcore.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 2000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 3000 8 | num_trains_per_train_loop: 3000 9 | env: BipedalWalkerHardcore-v3 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | tau_type: iqn 19 | zf_lr: 0.0003 20 | version: normal-iqn-neutral 21 | -------------------------------------------------------------------------------- /configs/sac-normal/ant.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Ant-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | alpha: 0.2 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | qf_lr: 0.0003 19 | soft_target_tau: 0.005 20 | target_update_period: 1 21 | use_automatic_entropy_tuning: false 22 | version: normal 23 | -------------------------------------------------------------------------------- /configs/sac-normal/hopper.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 
1000 8 | num_trains_per_train_loop: 1000 9 | env: Hopper-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | alpha: 0.2 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | qf_lr: 0.0003 19 | soft_target_tau: 0.005 20 | target_update_period: 1 21 | use_automatic_entropy_tuning: false 22 | version: normal 23 | -------------------------------------------------------------------------------- /configs/sac-normal/humanoid.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Humanoid-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | alpha: 0.05 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | qf_lr: 0.0003 19 | soft_target_tau: 0.005 20 | target_update_period: 1 21 | use_automatic_entropy_tuning: false 22 | version: normal 23 | -------------------------------------------------------------------------------- /configs/sac-normal/walker2d.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Walker2d-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | alpha: 0.2 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | qf_lr: 0.0003 19 | soft_target_tau: 0.005 20 | target_update_period: 1 21 | use_automatic_entropy_tuning: false 22 | version: normal 23 | -------------------------------------------------------------------------------- /rlkit/policies/argmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch argmax policy 3 | """ 4 | import numpy as np 5 | from torch import nn 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.policies.base import Policy 9 | 10 | 11 | class ArgmaxDiscretePolicy(nn.Module, Policy): 12 | def __init__(self, qf): 13 | super().__init__() 14 | self.qf = qf 15 | 16 | def get_action(self, obs): 17 | obs = np.expand_dims(obs, axis=0) 18 | obs = ptu.from_numpy(obs).float() 19 | q_values = self.qf(obs).squeeze(0) 20 | q_values_np = ptu.get_numpy(q_values) 21 | return q_values_np.argmax(), {} 22 | -------------------------------------------------------------------------------- /configs/sac-normal/halfcheetah.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: HalfCheetah-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | alpha: 0.2 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | qf_lr: 0.0003 19 | soft_target_tau: 0.005 20 | target_update_period: 1 21 | use_automatic_entropy_tuning: false 22 | version: normal 23 | -------------------------------------------------------------------------------- 
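The YAML files above (and the matching ones for the other algorithms) are plain dictionaries of hyperparameters. As a usage illustration, here is a minimal sketch of how such a file could be loaded with the repository's `load_config` helper from `utils.py`; the `run_experiment` function is a hypothetical placeholder for what the real entry points (`sac.py`, `td3.py`, `dsac.py`, `td4.py`) do with these keys, since their code is not shown here.
```
# Sketch only: load a config such as configs/sac-normal/halfcheetah.yaml and
# inspect its sections. `run_experiment` is a hypothetical stand-in for the
# real entry points (sac.py / td3.py / dsac.py / td4.py).
from utils import load_config


def run_experiment(variant):
    # A real entry point would build the envs, networks, trainer and RL
    # algorithm from these values; here we only show the expected structure.
    print("env:", variant["env"])
    print("algorithm_kwargs:", variant["algorithm_kwargs"])
    print("trainer_kwargs:", variant["trainer_kwargs"])


if __name__ == "__main__":
    variant = load_config("configs/sac-normal/halfcheetah.yaml")
    run_experiment(variant)
```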
/rlkit/samplers/data_collector/__init__.py: -------------------------------------------------------------------------------- 1 | from rlkit.samplers.data_collector.base import ( 2 | DataCollector, 3 | PathCollector, 4 | StepCollector, 5 | ) 6 | from rlkit.samplers.data_collector.path_collector import ( 7 | MdpPathCollector, 8 | EvalPathCollector, 9 | GoalConditionedPathCollector, 10 | ) 11 | from rlkit.samplers.data_collector.step_collector import ( 12 | MdpStepCollector, 13 | GoalConditionedStepCollector, 14 | ) 15 | from rlkit.samplers.data_collector.vec_step_collector import VecMdpStepCollector 16 | 17 | from rlkit.samplers.data_collector.vec_path_collector import VecMdpPathCollector 18 | -------------------------------------------------------------------------------- /configs/dsac-normal-iqn-neutral/ant.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Ant-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | alpha: 0.2 17 | discount: 0.99 18 | policy_lr: 0.0003 19 | soft_target_tau: 0.005 20 | tau_type: iqn 21 | use_automatic_entropy_tuning: false 22 | zf_lr: 0.0003 23 | version: normal-iqn-neutral 24 | -------------------------------------------------------------------------------- /configs/sac-normal/bipedalwalkerhardcore.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 2000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 3000 8 | num_trains_per_train_loop: 3000 9 | env: BipedalWalkerHardcore-v3 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | replay_buffer_size: 1000000 14 | trainer_kwargs: 15 | alpha: 0.002 16 | discount: 0.99 17 | policy_lr: 0.0003 18 | qf_lr: 0.0003 19 | soft_target_tau: 0.005 20 | target_update_period: 1 21 | use_automatic_entropy_tuning: false 22 | version: normal 23 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from rlkit.exploration_strategies.base import RawExplorationStrategy 4 | 5 | 6 | class EpsilonGreedy(RawExplorationStrategy): 7 | """ 8 | Take a random discrete action with some probability. 
9 | """ 10 | def __init__(self, action_space, prob_random_action=0.1): 11 | self.prob_random_action = prob_random_action 12 | self.action_space = action_space 13 | 14 | def get_action_from_raw_action(self, action, **kwargs): 15 | if random.random() <= self.prob_random_action: 16 | return self.action_space.sample() 17 | return action 18 | -------------------------------------------------------------------------------- /configs/dsac-normal-iqn-neutral/hopper.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Hopper-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | alpha: 0.2 17 | discount: 0.99 18 | policy_lr: 0.0003 19 | soft_target_tau: 0.005 20 | tau_type: iqn 21 | use_automatic_entropy_tuning: false 22 | zf_lr: 0.0003 23 | version: normal-iqn-neutral 24 | -------------------------------------------------------------------------------- /configs/dsac-normal-iqn-neutral/humanoid.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Humanoid-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | alpha: 0.05 17 | discount: 0.99 18 | policy_lr: 0.0003 19 | soft_target_tau: 0.005 20 | tau_type: iqn 21 | use_automatic_entropy_tuning: false 22 | zf_lr: 0.0003 23 | version: normal-iqn-neutral 24 | -------------------------------------------------------------------------------- /configs/dsac-normal-iqn-neutral/walker2d.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: Walker2d-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | alpha: 0.2 17 | discount: 0.99 18 | policy_lr: 0.0003 19 | soft_target_tau: 0.005 20 | tau_type: iqn 21 | use_automatic_entropy_tuning: false 22 | zf_lr: 0.0003 23 | version: normal-iqn-neutral 24 | -------------------------------------------------------------------------------- /configs/dsac-normal-iqn-neutral/halfcheetah.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 1000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 1000 8 | num_trains_per_train_loop: 1000 9 | env: HalfCheetah-v2 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | alpha: 0.2 17 | discount: 0.99 18 | policy_lr: 0.0003 19 | soft_target_tau: 0.005 20 | tau_type: iqn 21 | use_automatic_entropy_tuning: false 22 | 
zf_lr: 0.0003 23 | version: normal-iqn-neutral 24 | -------------------------------------------------------------------------------- /configs/dsac-normal-iqn-neutral/bipedalwalkerhardcore.yaml: -------------------------------------------------------------------------------- 1 | algorithm_kwargs: 2 | batch_size: 256 3 | max_path_length: 2000 4 | min_num_steps_before_training: 10000 5 | num_epochs: 1000 6 | num_eval_paths_per_epoch: 10 7 | num_expl_steps_per_train_loop: 3000 8 | num_trains_per_train_loop: 3000 9 | env: BipedalWalkerHardcore-v3 10 | eval_env_num: 10 11 | expl_env_num: 10 12 | layer_size: 256 13 | num_quantiles: 32 14 | replay_buffer_size: 1000000 15 | trainer_kwargs: 16 | alpha: 0.002 17 | discount: 0.99 18 | policy_lr: 0.0003 19 | soft_target_tau: 0.005 20 | tau_type: iqn 21 | use_automatic_entropy_tuning: false 22 | zf_lr: 0.0003 23 | version: normal-iqn-neutral 24 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class DataCollector(object, metaclass=abc.ABCMeta): 5 | def end_epoch(self, epoch): 6 | pass 7 | 8 | def get_diagnostics(self): 9 | return {} 10 | 11 | def get_snapshot(self): 12 | return {} 13 | 14 | @abc.abstractmethod 15 | def get_epoch_paths(self): 16 | pass 17 | 18 | 19 | class PathCollector(DataCollector, metaclass=abc.ABCMeta): 20 | @abc.abstractmethod 21 | def collect_new_paths( 22 | self, 23 | max_path_length, 24 | num_steps, 25 | discard_incomplete_paths, 26 | ): 27 | pass 28 | 29 | 30 | class StepCollector(DataCollector, metaclass=abc.ABCMeta): 31 | @abc.abstractmethod 32 | def collect_new_steps( 33 | self, 34 | max_path_length, 35 | num_steps, 36 | discard_incomplete_paths, 37 | ): 38 | pass 39 | -------------------------------------------------------------------------------- /rlkit/torch/dsac/utils.py: -------------------------------------------------------------------------------- 1 | class LinearSchedule(object): 2 | 3 | def __init__(self, schedule_timesteps, initial=1., final=0.): 4 | """Linear interpolation between initial_p and final_p over 5 | schedule_timesteps. After this many timesteps pass final_p is 6 | returned. 7 | 8 | Parameters 9 | ---------- 10 | schedule_timesteps: int 11 | Number of timesteps for which to linearly anneal initial_p 12 | to final_p 13 | initial_p: float 14 | initial output value 15 | final_p: float 16 | final output value 17 | """ 18 | self.schedule_timesteps = schedule_timesteps 19 | self.final = final 20 | self.initial = initial 21 | 22 | def __call__(self, t): 23 | """See Schedule.value""" 24 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 25 | return self.initial + fraction * (self.final - self.initial) 26 | -------------------------------------------------------------------------------- /rlkit/torch/td4/utils.py: -------------------------------------------------------------------------------- 1 | class LinearSchedule(object): 2 | 3 | def __init__(self, schedule_timesteps, initial=1., final=0.): 4 | """Linear interpolation between initial_p and final_p over 5 | schedule_timesteps. After this many timesteps pass final_p is 6 | returned. 
7 | 8 | Parameters 9 | ---------- 10 | schedule_timesteps: int 11 | Number of timesteps for which to linearly anneal initial_p 12 | to final_p 13 | initial_p: float 14 | initial output value 15 | final_p: float 16 | final output value 17 | """ 18 | self.schedule_timesteps = schedule_timesteps 19 | self.final = final 20 | self.initial = initial 21 | 22 | def __call__(self, t): 23 | """See Schedule.value""" 24 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 25 | return self.initial + fraction * (self.final - self.initial) 26 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_strategy.py: -------------------------------------------------------------------------------- 1 | from rlkit.exploration_strategies.base import RawExplorationStrategy 2 | import numpy as np 3 | 4 | 5 | class GaussianStrategy(RawExplorationStrategy): 6 | """ 7 | This strategy adds Gaussian noise to the action taken by the deterministic policy. 8 | 9 | Based on the rllab implementation. 10 | """ 11 | 12 | def __init__(self, action_space, max_sigma=1.0, min_sigma=None, decay_period=1000000): 13 | assert len(action_space.shape) == 1 14 | self._max_sigma = max_sigma 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._min_sigma = min_sigma 18 | self._decay_period = decay_period 19 | self._action_space = action_space 20 | 21 | def get_action_from_raw_action(self, action, t=None, **kwargs): 22 | sigma = (self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period)) 23 | return np.clip( 24 | action + np.random.normal(size=action.shape) * sigma, 25 | self._action_space.low, 26 | self._action_space.high, 27 | ) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xiaoteng Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /rlkit/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cloudpickle 4 | from gym.spaces import Box, Discrete, Tuple 5 | 6 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 7 | 8 | 9 | def get_asset_full_path(file_name): 10 | return os.path.join(ENV_ASSET_DIR, file_name) 11 | 12 | 13 | def get_dim(space): 14 | if isinstance(space, Box): 15 | return space.low.size 16 | elif isinstance(space, Discrete): 17 | return space.n 18 | elif isinstance(space, Tuple): 19 | return sum(get_dim(subspace) for subspace in space.spaces) 20 | elif hasattr(space, 'flat_dim'): 21 | return space.flat_dim 22 | else: 23 | raise TypeError("Unknown space: {}".format(space)) 24 | 25 | 26 | def mode(env, mode_type): 27 | try: 28 | getattr(env, mode_type)() 29 | except AttributeError: 30 | pass 31 | 32 | 33 | class CloudpickleWrapper(object): 34 | """A cloudpickle wrapper used in :class:`~tianshou.env.SubprocVectorEnv`""" 35 | 36 | def __init__(self, data): 37 | self.data = data 38 | 39 | def __getstate__(self): 40 | return cloudpickle.dumps(self.data) 41 | 42 | def __setstate__(self, data): 43 | self.data = cloudpickle.loads(data) 44 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/gaussian_and_epsilon_strategy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from rlkit.exploration_strategies.base import RawExplorationStrategy 3 | import numpy as np 4 | 5 | 6 | class GaussianAndEpislonStrategy(RawExplorationStrategy): 7 | """ 8 | With probability epsilon, take a completely random action. 9 | with probability 1-epsilon, add Gaussian noise to the action taken by a 10 | deterministic policy. 11 | """ 12 | def __init__(self, action_space, epsilon, max_sigma=1.0, min_sigma=None, 13 | decay_period=1000000): 14 | assert len(action_space.shape) == 1 15 | if min_sigma is None: 16 | min_sigma = max_sigma 17 | self._max_sigma = max_sigma 18 | self._epsilon = epsilon 19 | self._min_sigma = min_sigma 20 | self._decay_period = decay_period 21 | self._action_space = action_space 22 | 23 | def get_action_from_raw_action(self, action, t=None, **kwargs): 24 | if random.random() < self._epsilon: 25 | return self._action_space.sample() 26 | else: 27 | sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(1.0, t * 1.0 / self._decay_period) 28 | return np.clip( 29 | action + np.random.normal(size=len(action)) * sigma, 30 | self._action_space.low, 31 | self._action_space.high, 32 | ) 33 | -------------------------------------------------------------------------------- /rlkit/data_management/path_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PathBuilder(dict): 5 | """ 6 | Usage: 7 | ``` 8 | path_builder = PathBuilder() 9 | path.add_sample( 10 | observations=1, 11 | actions=2, 12 | next_observations=3, 13 | ... 14 | ) 15 | path.add_sample( 16 | observations=4, 17 | actions=5, 18 | next_observations=6, 19 | ... 20 | ) 21 | 22 | path = path_builder.get_all_stacked() 23 | 24 | path['observations'] 25 | # output: [1, 4] 26 | path['actions'] 27 | # output: [2, 5] 28 | ``` 29 | 30 | Note that the key should be "actions" and not "action" since the 31 | resulting dictionary will have those keys. 
32 | """ 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self._path_length = 0 37 | 38 | def add_all(self, **key_to_value): 39 | for k, v in key_to_value.items(): 40 | if k not in self: 41 | self[k] = [v] 42 | else: 43 | self[k].append(v) 44 | self._path_length += 1 45 | 46 | def get_all_stacked(self): 47 | output_dict = dict() 48 | for k, v in self.items(): 49 | output_dict[k] = stack_list(v) 50 | return output_dict 51 | 52 | def __len__(self): 53 | return self._path_length 54 | 55 | 56 | def stack_list(lst): 57 | if isinstance(lst[0], dict): 58 | return lst 59 | else: 60 | return np.array(lst) 61 | -------------------------------------------------------------------------------- /rlkit/torch/td3/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 5 | """ 6 | import torch 7 | from torch import nn as nn 8 | from torch.nn import functional as F 9 | 10 | from rlkit.torch import pytorch_util as ptu 11 | from rlkit.torch.networks import Mlp 12 | 13 | 14 | def identity(x): 15 | return x 16 | 17 | 18 | class TD3Mlp(nn.Module): 19 | 20 | def __init__( 21 | self, 22 | hidden_sizes, 23 | obs_dim, 24 | action_dim, 25 | init_w=3e-3, 26 | hidden_activation=F.relu, 27 | output_activation=identity, 28 | hidden_init=ptu.fanin_init, 29 | b_init_value=0.1, 30 | layer_norm=False, 31 | layer_norm_kwargs=None, 32 | ): 33 | super().__init__() 34 | self.fc1 = Mlp( 35 | input_size=obs_dim + action_dim, 36 | hidden_sizes=[], 37 | output_size=hidden_sizes[0], 38 | output_activation=hidden_activation, 39 | layer_norm=layer_norm, 40 | ) 41 | self.fc2 = Mlp( 42 | input_size=action_dim + hidden_sizes[0], 43 | hidden_sizes=hidden_sizes[1:], 44 | output_size=1, 45 | output_activation=output_activation, 46 | layer_norm=layer_norm, 47 | ) 48 | 49 | def forward(self, state, action): 50 | h = self.fc1(torch.cat([state, action], dim=1)) 51 | h = torch.cat([h, action], dim=1) 52 | output = self.fc2(h) 53 | return output 54 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rlkit.policies.base import ExplorationPolicy 4 | 5 | 6 | class ExplorationStrategy(object, metaclass=abc.ABCMeta): 7 | 8 | @abc.abstractmethod 9 | def get_action(self, t, observation, policy, **kwargs): 10 | pass 11 | 12 | def reset(self): 13 | pass 14 | 15 | 16 | class RawExplorationStrategy(ExplorationStrategy, metaclass=abc.ABCMeta): 17 | 18 | @abc.abstractmethod 19 | def get_action_from_raw_action(self, action, **kwargs): 20 | pass 21 | 22 | def get_action(self, t, policy, *args, **kwargs): 23 | action, agent_info = policy.get_action(*args, **kwargs) 24 | return self.get_action_from_raw_action(action, t=t), agent_info 25 | 26 | def get_actions(self, t, policy, *args, **kwargs): 27 | action = policy.get_actions(*args, **kwargs) 28 | return self.get_action_from_raw_action(action, t=t) 29 | 30 | def reset(self): 31 | pass 32 | 33 | 34 | class PolicyWrappedWithExplorationStrategy(ExplorationPolicy): 35 | 36 | def __init__( 37 | self, 38 | exploration_strategy: ExplorationStrategy, 39 | policy, 40 | ): 41 | self.es = exploration_strategy 42 | self.policy = policy 43 | self.t = 0 44 | 45 | def set_num_steps_total(self, t): 46 | self.t = t 47 | 48 | def get_action(self, *args, **kwargs): 49 | return self.es.get_action(self.t, 
self.policy, *args, **kwargs) 50 | 51 | def get_actions(self, *args, **kwargs): 52 | return self.es.get_actions(self.t, self.policy, *args, **kwargs) 53 | 54 | def reset(self): 55 | self.es.reset() 56 | self.policy.reset() 57 | -------------------------------------------------------------------------------- /rlkit/data_management/env_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from gym.spaces import Discrete 2 | 3 | from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer 4 | from rlkit.envs.env_utils import get_dim 5 | import numpy as np 6 | 7 | 8 | class EnvReplayBuffer(SimpleReplayBuffer): 9 | 10 | def __init__(self, max_replay_buffer_size, env, env_info_sizes=None): 11 | """ 12 | :param max_replay_buffer_size: 13 | :param env: 14 | """ 15 | self.env = env 16 | self._ob_space = env.observation_space 17 | self._action_space = env.action_space 18 | 19 | if env_info_sizes is None: 20 | if hasattr(env, 'info_sizes'): 21 | env_info_sizes = env.info_sizes 22 | else: 23 | env_info_sizes = dict() 24 | 25 | super().__init__(max_replay_buffer_size=max_replay_buffer_size, 26 | observation_dim=get_dim(self._ob_space), 27 | action_dim=get_dim(self._action_space), 28 | env_info_sizes=env_info_sizes) 29 | 30 | def add_sample(self, observation, action, reward, terminal, next_observation, **kwargs): 31 | if isinstance(self._action_space, Discrete): 32 | new_action = np.zeros(self._action_dim) 33 | new_action[action] = 1 34 | else: 35 | new_action = action 36 | return super().add_sample(observation=observation, 37 | action=new_action, 38 | reward=reward, 39 | next_observation=next_observation, 40 | terminal=terminal, 41 | **kwargs) 42 | -------------------------------------------------------------------------------- /rlkit/torch/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from rlkit.torch import pytorch_util as ptu 5 | 6 | 7 | def eval_np(module, *args, **kwargs): 8 | """ 9 | Eval this module with a numpy interface 10 | 11 | Same as a call to __call__ except all Variable input/outputs are 12 | replaced with numpy equivalents. 13 | 14 | Assumes the output is either a single object or a tuple of objects. 15 | """ 16 | torch_args = tuple(torch_ify(x) for x in args) 17 | torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()} 18 | outputs = module(*torch_args, **torch_kwargs) 19 | if isinstance(outputs, tuple): 20 | return tuple(np_ify(x) for x in outputs) 21 | else: 22 | return np_ify(outputs) 23 | 24 | 25 | def torch_ify(np_array_or_other): 26 | if isinstance(np_array_or_other, np.ndarray): 27 | return ptu.from_numpy(np_array_or_other) 28 | else: 29 | return np_array_or_other 30 | 31 | 32 | def np_ify(tensor_or_other): 33 | if isinstance(tensor_or_other, torch.autograd.Variable): 34 | return ptu.get_numpy(tensor_or_other) 35 | else: 36 | return tensor_or_other 37 | 38 | 39 | def _elem_or_tuple_to_variable(elem_or_tuple): 40 | if isinstance(elem_or_tuple, tuple): 41 | return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple) 42 | return ptu.from_numpy(elem_or_tuple).float() 43 | 44 | 45 | def _filter_batch(np_batch): 46 | for k, v in np_batch.items(): 47 | if v.dtype == np.bool: 48 | yield k, v.astype(int) 49 | else: 50 | yield k, v 51 | 52 | 53 | def np_to_pytorch_batch(np_batch): 54 | return { 55 | k: _elem_or_tuple_to_variable(x) 56 | for k, x in _filter_batch(np_batch) 57 | if x.dtype != np.dtype('O') # ignore object (e.g. 
dictionaries) 58 | } 59 | -------------------------------------------------------------------------------- /rlkit/torch/torch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | from typing import Iterable 5 | from torch import nn as nn 6 | 7 | from rlkit.core.batch_rl_algorithm import BatchRLAlgorithm 8 | from rlkit.core.online_rl_algorithm import OnlineRLAlgorithm 9 | from rlkit.core.vec_online_rl_algorithm import VecOnlineRLAlgorithm 10 | from rlkit.core.trainer import Trainer 11 | 12 | 13 | class TorchOnlineRLAlgorithm(OnlineRLAlgorithm): 14 | 15 | def to(self, device): 16 | for net in self.trainer.networks: 17 | net.to(device) 18 | 19 | def training_mode(self, mode): 20 | for net in self.trainer.networks: 21 | net.train(mode) 22 | 23 | 24 | class TorchBatchRLAlgorithm(BatchRLAlgorithm): 25 | 26 | def to(self, device): 27 | for net in self.trainer.networks: 28 | net.to(device) 29 | 30 | def training_mode(self, mode): 31 | for net in self.trainer.networks: 32 | net.train(mode) 33 | 34 | 35 | class TorchVecOnlineRLAlgorithm(VecOnlineRLAlgorithm): 36 | 37 | def to(self, device): 38 | for net in self.trainer.networks: 39 | net.to(device) 40 | 41 | def training_mode(self, mode): 42 | for net in self.trainer.networks: 43 | net.train(mode) 44 | 45 | 46 | class TorchTrainer(Trainer, metaclass=abc.ABCMeta): 47 | 48 | def __init__(self): 49 | self._num_train_steps = 0 50 | 51 | def train(self, batch): 52 | self._num_train_steps += 1 53 | self.train_from_torch(batch) 54 | 55 | def get_diagnostics(self): 56 | return OrderedDict([ 57 | ('num train calls', self._num_train_steps), 58 | ]) 59 | 60 | @abc.abstractmethod 61 | def train_from_torch(self, batch): 62 | pass 63 | 64 | @property 65 | @abc.abstractmethod 66 | def networks(self) -> Iterable[nn.Module]: 67 | pass 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DSAC 2 | Implementation of Distributional Soft Actor Critic (DSAC). 3 | This repository is based on [RLkit](https://github.com/vitchyr/rlkit), a reinforcement learning framework implemented by PyTorch. 4 | The core algorithm of DSAC is in `rlkit/torch/dsac/` 5 | 6 | ## Requirements 7 | - python 3.6+ 8 | - pytorch 1.0+ 9 | - gym[all] 0.15+ 10 | - scipy 1.0+ 11 | - numpy 12 | - matplotlib 13 | - gtimer 14 | - pyyaml 15 | 16 | ## Usage 17 | You can write your experiment settings in YAML and run with 18 | ``` 19 | python dsac.py --config your_config.yaml --gpu 0 --seed 0 20 | ``` 21 | To run our implementation of SAC/TD3/TD4, please replace dsac.py with sac.py/td3.py/td4.py. Set `--gpu -1`, your program will run on CPU. 22 | 23 | The experimental configurations of the paper are in `config/`. 
A typical configuration in YAML is given as follows: 24 | ``` 25 | env: Hopper-v2 26 | version: normal-iqn-neutral # version for logging 27 | eval_env_num: 10 # number of parallel environments for evaluation 28 | expl_env_num: 10 # number of parallel environments for exploration 29 | layer_size: 256 # hidden size of networks 30 | num_quantiles: 32 31 | replay_buffer_size: 1000000 32 | algorithm_kwargs: 33 | batch_size: 256 34 | max_path_length: 1000 35 | min_num_steps_before_training: 10000 36 | num_epochs: 1000 37 | num_eval_paths_per_epoch: 10 38 | num_expl_steps_per_train_loop: 1000 39 | num_trains_per_train_loop: 1000 40 | trainer_kwargs: 41 | alpha: 0.2 42 | discount: 0.99 43 | policy_lr: 0.0003 44 | zf_lr: 0.0003 45 | soft_target_tau: 0.005 46 | tau_type: iqn # quantile fraction generation method, choices: fix, iqn, fqf 47 | use_automatic_entropy_tuning: false 48 | ``` 49 | 50 | Learning under risk measures is available for DSAC and TD4. We provide six risk measures: `neutral`, `std`, `VaR`, `cpw`, `wang`, `cvar`. You can change the risk preference by adding two additional items to your YAML config: 51 | ``` 52 | ... 53 | 54 | trainer_kwargs: 55 | ... 56 | risk_type: std 57 | risk_param: 0.1 58 | ``` 59 | 60 | 61 | -------------------------------------------------------------------------------- /rlkit/exploration_strategies/ou_strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as nr 3 | 4 | from rlkit.exploration_strategies.base import RawExplorationStrategy 5 | 6 | 7 | class OUStrategy(RawExplorationStrategy): 8 | """ 9 | This strategy implements the Ornstein-Uhlenbeck process, which adds 10 | time-correlated noise to the actions taken by the deterministic policy. 11 | The OU process satisfies the following stochastic differential equation: 12 | dxt = theta*(mu - xt)*dt + sigma*dWt 13 | where Wt denotes the Wiener process 14 | 15 | Based on the rllab implementation.
16 | """ 17 | 18 | def __init__( 19 | self, 20 | action_space, 21 | mu=0, 22 | theta=0.15, 23 | max_sigma=0.3, 24 | min_sigma=None, 25 | decay_period=100000, 26 | ): 27 | if min_sigma is None: 28 | min_sigma = max_sigma 29 | self.mu = mu 30 | self.theta = theta 31 | self.sigma = max_sigma 32 | self._max_sigma = max_sigma 33 | if min_sigma is None: 34 | min_sigma = max_sigma 35 | self._min_sigma = min_sigma 36 | self._decay_period = decay_period 37 | self.dim = np.prod(action_space.low.shape) 38 | self.low = action_space.low 39 | self.high = action_space.high 40 | self.state = np.ones(self.dim) * self.mu 41 | self.reset() 42 | 43 | def reset(self): 44 | self.state = np.ones(self.dim) * self.mu 45 | 46 | def evolve_state(self): 47 | x = self.state 48 | dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x)) 49 | self.state = x + dx 50 | return self.state 51 | 52 | def get_action_from_raw_action(self, action, t=0, **kwargs): 53 | ou_state = self.evolve_state() 54 | self.sigma = ( 55 | self._max_sigma 56 | - (self._max_sigma - self._min_sigma) 57 | * min(1.0, t * 1.0 / self._decay_period) 58 | ) 59 | return np.clip(action + ou_state, self.low, self.high) 60 | -------------------------------------------------------------------------------- /rlkit/torch/dsac/risk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import rlkit.torch.pytorch_util as ptu 5 | 6 | 7 | def normal_cdf(value, loc=0., scale=1.): 8 | return 0.5 * (1 + torch.erf((value - loc) / scale / np.sqrt(2))) 9 | 10 | 11 | def normal_icdf(value, loc=0., scale=1.): 12 | return loc + scale * torch.erfinv(2 * value - 1) * np.sqrt(2) 13 | 14 | 15 | def normal_pdf(value, loc=0., scale=1.): 16 | return torch.exp(-(value - loc)**2 / (2 * scale**2)) / scale / np.sqrt(2 * np.pi) 17 | 18 | 19 | def distortion_fn(tau, mode="neutral", param=0.): 20 | # Risk distortion function 21 | tau = tau.clamp(0., 1.) 22 | if param >= 0: 23 | if mode == "neutral": 24 | tau_ = tau 25 | elif mode == "wang": 26 | tau_ = normal_cdf(normal_icdf(tau) + param) 27 | elif mode == "cvar": 28 | tau_ = (1. / param) * tau 29 | elif mode == "cpw": 30 | tau_ = tau**param / (tau**param + (1. - tau)**param)**(1. / param) 31 | return tau_.clamp(0., 1.) 32 | else: 33 | return 1 - distortion_fn(1 - tau, mode, -param) 34 | 35 | 36 | def distortion_de(tau, mode="neutral", param=0., eps=1e-8): 37 | # Derivative of Risk distortion function 38 | tau = tau.clamp(0., 1.) 39 | if param >= 0: 40 | if mode == "neutral": 41 | tau_ = ptu.one_like(tau) 42 | elif mode == "wang": 43 | tau_ = normal_pdf(normal_icdf(tau) + param) / (normal_pdf(normal_icdf(tau)) + eps) 44 | elif mode == "cvar": 45 | tau_ = (1. / param) * (tau < param) 46 | elif mode == "cpw": 47 | g = tau**param 48 | h = (tau**param + (1 - tau)**param)**(1 / param) 49 | g_ = param * tau**(param - 1) 50 | h_ = (tau**param + (1 - tau)**param)**(1 / param - 1) * (tau**(param - 1) - (1 - tau)**(param - 1)) 51 | tau_ = (g_ * h - g * h_) / (h**2 + eps) 52 | return tau_.clamp(0., 5.) 53 | 54 | else: 55 | return distortion_de(1 - tau, mode, -param) 56 | -------------------------------------------------------------------------------- /rlkit/util/ml_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility functions for machine learning. 
3 | """ 4 | import abc 5 | import math 6 | import numpy as np 7 | 8 | 9 | class ScalarSchedule(object, metaclass=abc.ABCMeta): 10 | @abc.abstractmethod 11 | def get_value(self, t): 12 | pass 13 | 14 | 15 | class ConstantSchedule(ScalarSchedule): 16 | def __init__(self, value): 17 | self._value = value 18 | 19 | def get_value(self, t): 20 | return self._value 21 | 22 | 23 | class LinearSchedule(ScalarSchedule): 24 | """ 25 | Linearly interpolate and then stop at a final value. 26 | """ 27 | def __init__( 28 | self, 29 | init_value, 30 | final_value, 31 | ramp_duration, 32 | ): 33 | self._init_value = init_value 34 | self._final_value = final_value 35 | self._ramp_duration = ramp_duration 36 | 37 | def get_value(self, t): 38 | return ( 39 | self._init_value 40 | + (self._final_value - self._init_value) 41 | * min(1.0, t * 1.0 / self._ramp_duration) 42 | ) 43 | 44 | 45 | class IntLinearSchedule(LinearSchedule): 46 | """ 47 | Same as RampUpSchedule but round output to an int 48 | """ 49 | def get_value(self, t): 50 | return int(super().get_value(t)) 51 | 52 | 53 | class PiecewiseLinearSchedule(ScalarSchedule): 54 | """ 55 | Given a list of (x, t) value-time pairs, return value x at time t, 56 | and linearly interpolate between the two 57 | """ 58 | def __init__( 59 | self, 60 | x_values, 61 | y_values, 62 | ): 63 | self._x_values = x_values 64 | self._y_values = y_values 65 | 66 | def get_value(self, t): 67 | return np.interp(t, self._x_values, self._y_values) 68 | 69 | 70 | class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule): 71 | def get_value(self, t): 72 | return int(super().get_value(t)) 73 | 74 | 75 | def none_to_infty(bounds): 76 | if bounds is None: 77 | bounds = -math.inf, math.inf 78 | lb, ub = bounds 79 | if lb is None: 80 | lb = -math.inf 81 | if ub is None: 82 | ub = math.inf 83 | return lb, ub 84 | -------------------------------------------------------------------------------- /rlkit/core/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | class Serializable(object): 12 | 13 | def __init__(self, *args, **kwargs): 14 | self.__args = args 15 | self.__kwargs = kwargs 16 | 17 | def quick_init(self, locals_): 18 | if getattr(self, "_serializable_initialized", False): 19 | return 20 | if sys.version_info >= (3, 0): 21 | spec = inspect.getfullargspec(self.__init__) 22 | # Exclude the first "self" parameter 23 | if spec.varkw: 24 | kwargs = locals_[spec.varkw].copy() 25 | else: 26 | kwargs = dict() 27 | if spec.kwonlyargs: 28 | for key in spec.kwonlyargs: 29 | kwargs[key] = locals_[key] 30 | else: 31 | spec = inspect.getargspec(self.__init__) 32 | if spec.keywords: 33 | kwargs = locals_[spec.keywords] 34 | else: 35 | kwargs = dict() 36 | if spec.varargs: 37 | varargs = locals_[spec.varargs] 38 | else: 39 | varargs = tuple() 40 | in_order_args = [locals_[arg] for arg in spec.args][1:] 41 | self.__args = tuple(in_order_args) + varargs 42 | self.__kwargs = kwargs 43 | setattr(self, "_serializable_initialized", True) 44 | 45 | def __getstate__(self): 46 | return {"__args": self.__args, "__kwargs": self.__kwargs} 47 | 48 | def __setstate__(self, d): 49 | # convert all __args to keyword-based arguments 50 | if sys.version_info >= (3, 0): 51 | spec = inspect.getfullargspec(self.__init__) 52 | else: 53 | spec = inspect.getargspec(self.__init__) 54 | in_order_args = spec.args[1:] 55 | out = 
type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 56 | self.__dict__.update(out.__dict__) 57 | 58 | @classmethod 59 | def clone(cls, obj, **kwargs): 60 | assert isinstance(obj, Serializable) 61 | d = obj.__getstate__() 62 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 63 | out = type(obj).__new__(type(obj)) 64 | out.__setstate__(d) 65 | return out 66 | -------------------------------------------------------------------------------- /rlkit/torch/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, Sampler 4 | 5 | # TODO: move this to more reasonable place 6 | from rlkit.data_management.obs_dict_replay_buffer import normalize_image 7 | 8 | 9 | class ImageDataset(Dataset): 10 | 11 | def __init__(self, images, should_normalize=True): 12 | super().__init__() 13 | self.dataset = images 14 | self.dataset_len = len(self.dataset) 15 | assert should_normalize == (images.dtype == np.uint8) 16 | self.should_normalize = should_normalize 17 | 18 | def __len__(self): 19 | return self.dataset_len 20 | 21 | def __getitem__(self, idxs): 22 | samples = self.dataset[idxs, :] 23 | if self.should_normalize: 24 | samples = normalize_image(samples) 25 | return np.float32(samples) 26 | 27 | 28 | class InfiniteRandomSampler(Sampler): 29 | 30 | def __init__(self, data_source): 31 | self.data_source = data_source 32 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __next__(self): 38 | try: 39 | idx = next(self.iter) 40 | except StopIteration: 41 | self.iter = iter(torch.randperm(len(self.data_source)).tolist()) 42 | idx = next(self.iter) 43 | return idx 44 | 45 | def __len__(self): 46 | return 2 ** 62 47 | 48 | 49 | class InfiniteWeightedRandomSampler(Sampler): 50 | 51 | def __init__(self, data_source, weights): 52 | assert len(data_source) == len(weights) 53 | assert len(weights.shape) == 1 54 | self.data_source = data_source 55 | # Always use CPU 56 | self._weights = torch.from_numpy(weights) 57 | self.iter = self._create_iterator() 58 | 59 | def update_weights(self, weights): 60 | self._weights = weights 61 | self.iter = self._create_iterator() 62 | 63 | def _create_iterator(self): 64 | return iter( 65 | torch.multinomial( 66 | self._weights, len(self._weights), replacement=True 67 | ).tolist() 68 | ) 69 | 70 | def __iter__(self): 71 | return self 72 | 73 | def __next__(self): 74 | try: 75 | idx = next(self.iter) 76 | except StopIteration: 77 | self.iter = self._create_iterator() 78 | idx = next(self.iter) 79 | return idx 80 | 81 | def __len__(self): 82 | return 2 ** 62 83 | -------------------------------------------------------------------------------- /rlkit/torch/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Distribution, Normal 3 | import rlkit.torch.pytorch_util as ptu 4 | 5 | 6 | class TanhNormal(Distribution): 7 | """ 8 | Represent distribution of X where 9 | X ~ tanh(Z) 10 | Z ~ N(mean, std) 11 | 12 | Note: this is not very numerically stable. 13 | """ 14 | def __init__(self, normal_mean, normal_std, epsilon=1e-6): 15 | """ 16 | :param normal_mean: Mean of the normal distribution 17 | :param normal_std: Std of the normal distribution 18 | :param epsilon: Numerical stability epsilon when computing log-prob. 
19 | """ 20 | self.normal_mean = normal_mean 21 | self.normal_std = normal_std 22 | self.normal = Normal(normal_mean, normal_std) 23 | self.epsilon = epsilon 24 | 25 | def sample_n(self, n, return_pre_tanh_value=False): 26 | z = self.normal.sample_n(n) 27 | if return_pre_tanh_value: 28 | return torch.tanh(z), z 29 | else: 30 | return torch.tanh(z) 31 | 32 | def log_prob(self, value, pre_tanh_value=None): 33 | """ 34 | 35 | :param value: some value, x 36 | :param pre_tanh_value: arctanh(x) 37 | :return: 38 | """ 39 | if pre_tanh_value is None: 40 | pre_tanh_value = torch.log( 41 | (1+value) / (1-value) 42 | ) / 2 43 | return self.normal.log_prob(pre_tanh_value) - torch.log( 44 | 1 - value * value + self.epsilon 45 | ) 46 | 47 | def sample(self, return_pretanh_value=False): 48 | """ 49 | Gradients will and should *not* pass through this operation. 50 | 51 | See https://github.com/pytorch/pytorch/issues/4620 for discussion. 52 | """ 53 | z = self.normal.sample().detach() 54 | 55 | if return_pretanh_value: 56 | return torch.tanh(z), z 57 | else: 58 | return torch.tanh(z) 59 | 60 | def rsample(self, return_pretanh_value=False): 61 | """ 62 | Sampling in the reparameterization case. 63 | """ 64 | z = ( 65 | self.normal_mean + 66 | self.normal_std * 67 | Normal( 68 | ptu.zeros(self.normal_mean.size()), 69 | ptu.ones(self.normal_std.size()) 70 | ).sample() 71 | ) 72 | z.requires_grad_() 73 | 74 | if return_pretanh_value: 75 | return torch.tanh(z), z 76 | else: 77 | return torch.tanh(z) 78 | -------------------------------------------------------------------------------- /rlkit/torch/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rlkit.torch.pytorch_util as ptu 3 | import numpy as np 4 | 5 | from rlkit.data_management.normalizer import Normalizer, FixedNormalizer 6 | 7 | 8 | class TorchNormalizer(Normalizer): 9 | """ 10 | Update with np array, but de/normalize pytorch Tensors. 11 | """ 12 | def normalize(self, v, clip_range=None): 13 | if not self.synchronized: 14 | self.synchronize() 15 | if clip_range is None: 16 | clip_range = self.default_clip_range 17 | mean = ptu.from_numpy(self.mean) 18 | std = ptu.from_numpy(self.std) 19 | if v.dim() == 2: 20 | # Unsqueeze along the batch use automatic broadcasting 21 | mean = mean.unsqueeze(0) 22 | std = std.unsqueeze(0) 23 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 24 | 25 | def denormalize(self, v): 26 | if not self.synchronized: 27 | self.synchronize() 28 | mean = ptu.from_numpy(self.mean) 29 | std = ptu.from_numpy(self.std) 30 | if v.dim() == 2: 31 | mean = mean.unsqueeze(0) 32 | std = std.unsqueeze(0) 33 | return mean + v * std 34 | 35 | 36 | class TorchFixedNormalizer(FixedNormalizer): 37 | def normalize(self, v, clip_range=None): 38 | if clip_range is None: 39 | clip_range = self.default_clip_range 40 | mean = ptu.from_numpy(self.mean) 41 | std = ptu.from_numpy(self.std) 42 | if v.dim() == 2: 43 | # Unsqueeze along the batch use automatic broadcasting 44 | mean = mean.unsqueeze(0) 45 | std = std.unsqueeze(0) 46 | return torch.clamp((v - mean) / std, -clip_range, clip_range) 47 | 48 | def normalize_scale(self, v): 49 | """ 50 | Only normalize the scale. Do not subtract the mean. 
51 | """ 52 | std = ptu.from_numpy(self.std) 53 | if v.dim() == 2: 54 | std = std.unsqueeze(0) 55 | return v / std 56 | 57 | def denormalize(self, v): 58 | mean = ptu.from_numpy(self.mean) 59 | std = ptu.from_numpy(self.std) 60 | if v.dim() == 2: 61 | mean = mean.unsqueeze(0) 62 | std = std.unsqueeze(0) 63 | return mean + v * std 64 | 65 | def denormalize_scale(self, v): 66 | """ 67 | Only denormalize the scale. Do not add the mean. 68 | """ 69 | std = ptu.from_numpy(self.std) 70 | if v.dim() == 2: 71 | std = std.unsqueeze(0) 72 | return v * std 73 | -------------------------------------------------------------------------------- /rlkit/torch/dsac/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 5 | """ 6 | import numpy as np 7 | import torch 8 | from torch import nn as nn 9 | from torch.nn import functional as F 10 | 11 | from rlkit.torch import pytorch_util as ptu 12 | 13 | 14 | def softmax(x): 15 | return F.softmax(x, dim=-1) 16 | 17 | 18 | class QuantileMlp(nn.Module): 19 | 20 | def __init__( 21 | self, 22 | hidden_sizes, 23 | output_size, 24 | input_size, 25 | embedding_size=64, 26 | num_quantiles=32, 27 | layer_norm=True, 28 | **kwargs, 29 | ): 30 | super().__init__() 31 | self.layer_norm = layer_norm 32 | # hidden_sizes[:-2] MLP base 33 | # hidden_sizes[-2] before merge 34 | # hidden_sizes[-1] before output 35 | 36 | self.base_fc = [] 37 | last_size = input_size 38 | for next_size in hidden_sizes[:-1]: 39 | self.base_fc += [ 40 | nn.Linear(last_size, next_size), 41 | nn.LayerNorm(next_size) if layer_norm else nn.Identity(), 42 | nn.ReLU(inplace=True), 43 | ] 44 | last_size = next_size 45 | self.base_fc = nn.Sequential(*self.base_fc) 46 | self.num_quantiles = num_quantiles 47 | self.embedding_size = embedding_size 48 | self.tau_fc = nn.Sequential( 49 | nn.Linear(embedding_size, last_size), 50 | nn.LayerNorm(last_size) if layer_norm else nn.Identity(), 51 | nn.Sigmoid(), 52 | ) 53 | self.merge_fc = nn.Sequential( 54 | nn.Linear(last_size, hidden_sizes[-1]), 55 | nn.LayerNorm(hidden_sizes[-1]) if layer_norm else nn.Identity(), 56 | nn.ReLU(inplace=True), 57 | ) 58 | self.last_fc = nn.Linear(hidden_sizes[-1], 1) 59 | self.const_vec = ptu.from_numpy(np.arange(1, 1 + self.embedding_size)) 60 | 61 | def forward(self, state, action, tau): 62 | """ 63 | Calculate Quantile Value in Batch 64 | tau: quantile fractions, (N, T) 65 | """ 66 | h = torch.cat([state, action], dim=1) 67 | h = self.base_fc(h) # (N, C) 68 | 69 | x = torch.cos(tau.unsqueeze(-1) * self.const_vec * np.pi) # (N, T, E) 70 | x = self.tau_fc(x) # (N, T, C) 71 | 72 | h = torch.mul(x, h.unsqueeze(-2)) # (N, T, C) 73 | h = self.merge_fc(h) # (N, T, C) 74 | output = self.last_fc(h).squeeze(-1) # (N, T) 75 | return output 76 | -------------------------------------------------------------------------------- /rlkit/data_management/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayBuffer(object, metaclass=abc.ABCMeta): 5 | """ 6 | A class used to save and replay data. 7 | """ 8 | 9 | @abc.abstractmethod 10 | def add_sample(self, observation, action, reward, next_observation, terminal, **kwargs): 11 | """ 12 | Add a transition tuple. 
13 | """ 14 | pass 15 | 16 | @abc.abstractmethod 17 | def terminate_episode(self): 18 | """ 19 | Let the replay buffer know that the episode has terminated in case some 20 | special book-keeping has to happen. 21 | :return: 22 | """ 23 | pass 24 | 25 | @abc.abstractmethod 26 | def num_steps_can_sample(self, **kwargs): 27 | """ 28 | :return: # of unique items that can be sampled. 29 | """ 30 | pass 31 | 32 | def add_path(self, path): 33 | """ 34 | Add a path to the replay buffer. 35 | 36 | This default implementation naively goes through every step, but you 37 | may want to optimize this. 38 | 39 | NOTE: You should NOT call "terminate_episode" after calling add_path. 40 | It's assumed that this function handles the episode termination. 41 | 42 | :param path: Dict like one outputted by rlkit.samplers.util.rollout 43 | """ 44 | for i, ( 45 | obs, 46 | action, 47 | reward, 48 | next_obs, 49 | terminal, 50 | agent_info, 51 | env_info, 52 | ) in enumerate( 53 | zip( 54 | path["observations"], 55 | path["actions"], 56 | path["rewards"], 57 | path["next_observations"], 58 | path["terminals"], 59 | path["agent_infos"], 60 | path["env_infos"], 61 | )): 62 | self.add_sample( 63 | observation=obs, 64 | action=action, 65 | reward=reward, 66 | next_observation=next_obs, 67 | terminal=terminal, 68 | agent_info=agent_info, 69 | env_info=env_info, 70 | ) 71 | self.terminate_episode() 72 | 73 | def add_paths(self, paths): 74 | for path in paths: 75 | self.add_path(path) 76 | 77 | @abc.abstractmethod 78 | def random_batch(self, batch_size): 79 | """ 80 | Return a batch of size `batch_size`. 81 | :param batch_size: 82 | :return: 83 | """ 84 | pass 85 | 86 | def get_diagnostics(self): 87 | return {} 88 | 89 | def get_snapshot(self): 90 | return {} 91 | 92 | def end_epoch(self, epoch): 93 | return 94 | -------------------------------------------------------------------------------- /rlkit/util/io.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | import pickle 4 | 5 | import boto3 6 | 7 | from rlkit.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH 8 | import os 9 | 10 | PICKLE = 'pickle' 11 | NUMPY = 'numpy' 12 | JOBLIB = 'joblib' 13 | 14 | 15 | def local_path_from_s3_or_local_path(filename): 16 | relative_filename = os.path.join(LOCAL_LOG_DIR, filename) 17 | if os.path.isfile(filename): 18 | return filename 19 | elif os.path.isfile(relative_filename): 20 | return relative_filename 21 | else: 22 | return sync_down(filename) 23 | 24 | 25 | def sync_down(path, check_exists=True): 26 | is_docker = os.path.isfile("/.dockerenv") 27 | if is_docker: 28 | local_path = "/tmp/%s" % (path) 29 | else: 30 | local_path = "%s/%s" % (LOCAL_LOG_DIR, path) 31 | 32 | if check_exists and os.path.isfile(local_path): 33 | return local_path 34 | 35 | local_dir = os.path.dirname(local_path) 36 | os.makedirs(local_dir, exist_ok=True) 37 | 38 | if is_docker: 39 | from doodad.ec2.autoconfig import AUTOCONFIG 40 | os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key() 41 | os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret() 42 | 43 | full_s3_path = os.path.join(AWS_S3_PATH, path) 44 | bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path) 45 | try: 46 | bucket = boto3.resource('s3').Bucket(bucket_name) 47 | bucket.download_file(bucket_relative_path, local_path) 48 | except Exception as e: 49 | local_path = None 50 | print("Failed to sync! 
path: ", path) 51 | print("Exception: ", e) 52 | return local_path 53 | 54 | 55 | def split_s3_full_path(s3_path): 56 | """ 57 | Split "s3://foo/bar/baz" into "foo" and "bar/baz" 58 | """ 59 | bucket_name_and_directories = s3_path.split('//')[1] 60 | bucket_name, *directories = bucket_name_and_directories.split('/') 61 | directory_path = '/'.join(directories) 62 | return bucket_name, directory_path 63 | 64 | 65 | def load_local_or_remote_file(filepath, file_type=None): 66 | local_path = local_path_from_s3_or_local_path(filepath) 67 | if file_type is None: 68 | extension = local_path.split('.')[-1] 69 | if extension == 'npy': 70 | file_type = NUMPY 71 | else: 72 | file_type = PICKLE 73 | else: 74 | file_type = PICKLE 75 | if file_type == NUMPY: 76 | object = np.load(open(local_path, "rb")) 77 | elif file_type == JOBLIB: 78 | object = joblib.load(local_path) 79 | else: 80 | object = pickle.load(open(local_path, "rb")) 81 | print("loaded", local_path) 82 | return object 83 | 84 | 85 | if __name__ == "__main__": 86 | p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl") 87 | print("got", p) -------------------------------------------------------------------------------- /rlkit/core/batch_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from rlkit.core.rl_algorithm import BaseRLAlgorithm 5 | from rlkit.data_management.replay_buffer import ReplayBuffer 6 | from rlkit.samplers.data_collector import PathCollector 7 | 8 | 9 | class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 10 | def __init__( 11 | self, 12 | trainer, 13 | exploration_env, 14 | evaluation_env, 15 | exploration_data_collector: PathCollector, 16 | evaluation_data_collector: PathCollector, 17 | replay_buffer: ReplayBuffer, 18 | batch_size, 19 | max_path_length, 20 | num_epochs, 21 | num_eval_steps_per_epoch, 22 | num_expl_steps_per_train_loop, 23 | num_trains_per_train_loop, 24 | num_train_loops_per_epoch=1, 25 | min_num_steps_before_training=0, 26 | ): 27 | super().__init__( 28 | trainer, 29 | exploration_env, 30 | evaluation_env, 31 | exploration_data_collector, 32 | evaluation_data_collector, 33 | replay_buffer, 34 | ) 35 | self.batch_size = batch_size 36 | self.max_path_length = max_path_length 37 | self.num_epochs = num_epochs 38 | self.num_eval_steps_per_epoch = num_eval_steps_per_epoch 39 | self.num_trains_per_train_loop = num_trains_per_train_loop 40 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 41 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 42 | self.min_num_steps_before_training = min_num_steps_before_training 43 | 44 | def _train(self): 45 | if self.min_num_steps_before_training > 0: 46 | init_expl_paths = self.expl_data_collector.collect_new_paths( 47 | self.max_path_length, 48 | self.min_num_steps_before_training, 49 | discard_incomplete_paths=False, 50 | ) 51 | self.replay_buffer.add_paths(init_expl_paths) 52 | self.expl_data_collector.end_epoch(-1) 53 | 54 | for epoch in gt.timed_for( 55 | range(self._start_epoch, self.num_epochs), 56 | save_itrs=True, 57 | ): 58 | self.eval_data_collector.collect_new_paths( 59 | self.max_path_length, 60 | self.num_eval_steps_per_epoch, 61 | discard_incomplete_paths=True, 62 | ) 63 | gt.stamp('evaluation sampling') 64 | 65 | for _ in range(self.num_train_loops_per_epoch): 66 | new_expl_paths = self.expl_data_collector.collect_new_paths( 67 | self.max_path_length, 68 | self.num_expl_steps_per_train_loop, 69 | 
discard_incomplete_paths=False, 70 | ) 71 | gt.stamp('exploration sampling', unique=False) 72 | 73 | self.replay_buffer.add_paths(new_expl_paths) 74 | gt.stamp('data storing', unique=False) 75 | 76 | self.training_mode(True) 77 | for _ in range(self.num_trains_per_train_loop): 78 | train_data = self.replay_buffer.random_batch( 79 | self.batch_size) 80 | self.trainer.train(train_data) 81 | gt.stamp('training', unique=False) 82 | self.training_mode(False) 83 | 84 | self._end_epoch(epoch) 85 | -------------------------------------------------------------------------------- /rlkit/data_management/simple_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | 5 | from rlkit.data_management.replay_buffer import ReplayBuffer 6 | 7 | 8 | class SimpleReplayBuffer(ReplayBuffer): 9 | 10 | def __init__( 11 | self, 12 | max_replay_buffer_size, 13 | observation_dim, 14 | action_dim, 15 | env_info_sizes, 16 | ): 17 | self._observation_dim = observation_dim 18 | self._action_dim = action_dim 19 | self._max_replay_buffer_size = max_replay_buffer_size 20 | self._observations = np.zeros((max_replay_buffer_size, observation_dim)) 21 | # It's a bit memory inefficient to save the observations twice, 22 | # but it makes the code *much* easier since you no longer have to 23 | # worry about termination conditions. 24 | self._next_obs = np.zeros((max_replay_buffer_size, observation_dim)) 25 | self._actions = np.zeros((max_replay_buffer_size, action_dim)) 26 | # Make everything a 2D np array to make it easier for other code to 27 | # reason about the shape of the data 28 | self._rewards = np.zeros((max_replay_buffer_size, 1)) 29 | # self._terminals[i] = a terminal was received at time i 30 | self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8') 31 | # Define self._env_infos[key][i] to be the return value of env_info[key] 32 | # at time i 33 | self._env_infos = {} 34 | for key, size in env_info_sizes.items(): 35 | self._env_infos[key] = np.zeros((max_replay_buffer_size, size)) 36 | self._env_info_keys = env_info_sizes.keys() 37 | 38 | self._top = 0 39 | self._size = 0 40 | 41 | def add_sample(self, observation, action, reward, next_observation, 42 | terminal, env_info, **kwargs): 43 | self._observations[self._top] = observation 44 | self._actions[self._top] = action 45 | self._rewards[self._top] = reward 46 | self._terminals[self._top] = terminal 47 | self._next_obs[self._top] = next_observation 48 | 49 | for key in self._env_info_keys: 50 | self._env_infos[key][self._top] = env_info[key] 51 | self._advance() 52 | 53 | def terminate_episode(self): 54 | pass 55 | 56 | def _advance(self): 57 | self._top = (self._top + 1) % self._max_replay_buffer_size 58 | if self._size < self._max_replay_buffer_size: 59 | self._size += 1 60 | 61 | def random_batch(self, batch_size): 62 | indices = np.random.randint(0, self._size, batch_size) 63 | batch = dict( 64 | observations=self._observations[indices], 65 | actions=self._actions[indices], 66 | rewards=self._rewards[indices], 67 | terminals=self._terminals[indices], 68 | next_observations=self._next_obs[indices], 69 | ) 70 | for key in self._env_info_keys: 71 | assert key not in batch.keys() 72 | batch[key] = self._env_infos[key][indices] 73 | return batch 74 | 75 | def rebuild_env_info_dict(self, idx): 76 | return { 77 | key: self._env_infos[key][idx] 78 | for key in self._env_info_keys 79 | } 80 | 81 | def batch_env_info_dict(self, indices): 82 | 
return { 83 | key: self._env_infos[key][indices] 84 | for key in self._env_info_keys 85 | } 86 | 87 | def num_steps_can_sample(self): 88 | return self._size 89 | 90 | def get_diagnostics(self): 91 | return OrderedDict([ 92 | ('size', self._size) 93 | ]) 94 | -------------------------------------------------------------------------------- /sac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | import rlkit.torch.pytorch_util as ptu 6 | import yaml 7 | from rlkit.data_management.torch_replay_buffer import TorchReplayBuffer 8 | from rlkit.envs import make_env 9 | from rlkit.envs.vecenv import SubprocVectorEnv, VectorEnv 10 | from rlkit.launchers.launcher_util import set_seed, setup_logger 11 | from rlkit.samplers.data_collector import (VecMdpPathCollector, VecMdpStepCollector) 12 | from rlkit.torch.networks import FlattenMlp 13 | from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy 14 | from rlkit.torch.sac.sac import SACTrainer 15 | from rlkit.torch.torch_rl_algorithm import TorchVecOnlineRLAlgorithm 16 | 17 | torch.set_num_threads(4) 18 | torch.set_num_interop_threads(4) 19 | 20 | 21 | def experiment(variant): 22 | dummy_env = make_env(variant['env']) 23 | obs_dim = dummy_env.observation_space.low.size 24 | action_dim = dummy_env.action_space.low.size 25 | expl_env = VectorEnv([lambda: make_env(variant['env']) for _ in range(variant['expl_env_num'])]) 26 | expl_env.seed(variant["seed"]) 27 | expl_env.action_space.seed(variant["seed"]) 28 | eval_env = SubprocVectorEnv([lambda: make_env(variant['env']) for _ in range(variant['eval_env_num'])]) 29 | eval_env.seed(variant["seed"]) 30 | 31 | M = variant['layer_size'] 32 | qf1 = FlattenMlp( 33 | input_size=obs_dim + action_dim, 34 | output_size=1, 35 | hidden_sizes=[M, M], 36 | ) 37 | qf2 = FlattenMlp( 38 | input_size=obs_dim + action_dim, 39 | output_size=1, 40 | hidden_sizes=[M, M], 41 | ) 42 | target_qf1 = FlattenMlp( 43 | input_size=obs_dim + action_dim, 44 | output_size=1, 45 | hidden_sizes=[M, M], 46 | ) 47 | target_qf2 = FlattenMlp( 48 | input_size=obs_dim + action_dim, 49 | output_size=1, 50 | hidden_sizes=[M, M], 51 | ) 52 | policy = TanhGaussianPolicy( 53 | obs_dim=obs_dim, 54 | action_dim=action_dim, 55 | hidden_sizes=[M, M], 56 | ) 57 | eval_policy = MakeDeterministic(policy) 58 | eval_path_collector = VecMdpPathCollector( 59 | eval_env, 60 | eval_policy, 61 | ) 62 | expl_path_collector = VecMdpStepCollector( 63 | expl_env, 64 | policy, 65 | ) 66 | replay_buffer = TorchReplayBuffer( 67 | variant['replay_buffer_size'], 68 | dummy_env, 69 | ) 70 | trainer = SACTrainer( 71 | env=eval_env, 72 | policy=policy, 73 | qf1=qf1, 74 | qf2=qf2, 75 | target_qf1=target_qf1, 76 | target_qf2=target_qf2, 77 | **variant['trainer_kwargs'], 78 | ) 79 | algorithm = TorchVecOnlineRLAlgorithm( 80 | trainer=trainer, 81 | exploration_env=expl_env, 82 | evaluation_env=eval_env, 83 | exploration_data_collector=expl_path_collector, 84 | evaluation_data_collector=eval_path_collector, 85 | replay_buffer=replay_buffer, 86 | **variant['algorithm_kwargs'], 87 | ) 88 | algorithm.to(ptu.device) 89 | algorithm.train() 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser(description='Soft Actor Critic') 94 | parser.add_argument('--config', type=str, default="configs/lunarlander.yaml") 95 | parser.add_argument('--gpu', type=int, default=0, help="using cpu with -1") 96 | parser.add_argument('--seed', type=int, default=0) 97 | args = 
parser.parse_args() 98 | with open(args.config, 'r', encoding="utf-8") as f: 99 | variant = yaml.load(f, Loader=yaml.FullLoader) 100 | variant["seed"] = args.seed 101 | log_prefix = "_".join(["sac", variant["env"][:-3].lower(), str(variant["version"])]) 102 | setup_logger(log_prefix, variant=variant, seed=args.seed) 103 | if args.gpu >= 0: 104 | ptu.set_gpu_mode(True, args.gpu) 105 | set_seed(args.seed) 106 | experiment(variant) 107 | -------------------------------------------------------------------------------- /rlkit/torch/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | General networks for pytorch. 3 | 4 | Algorithm-specific networks should go else-where. 5 | """ 6 | import torch 7 | from torch import nn as nn 8 | from torch.nn import functional as F 9 | 10 | from rlkit.policies.base import Policy 11 | from rlkit.torch import pytorch_util as ptu 12 | from rlkit.torch.core import eval_np 13 | from rlkit.torch.data_management.normalizer import TorchFixedNormalizer 14 | 15 | 16 | def identity(x): 17 | return x 18 | 19 | 20 | class Mlp(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | hidden_sizes, 25 | output_size, 26 | input_size, 27 | init_w=3e-3, 28 | hidden_activation=F.relu, 29 | output_activation=identity, 30 | hidden_init=ptu.fanin_init, 31 | b_init_value=0.1, 32 | layer_norm=False, 33 | layer_norm_kwargs=None, 34 | ): 35 | super().__init__() 36 | 37 | if layer_norm_kwargs is None: 38 | layer_norm_kwargs = dict() 39 | 40 | self.input_size = input_size 41 | self.output_size = output_size 42 | self.hidden_activation = hidden_activation 43 | self.output_activation = output_activation 44 | self.layer_norm = layer_norm 45 | self.fcs = [] 46 | self.layer_norms = [] 47 | in_size = input_size 48 | 49 | for i, next_size in enumerate(hidden_sizes): 50 | fc = nn.Linear(in_size, next_size) 51 | in_size = next_size 52 | hidden_init(fc.weight) 53 | fc.bias.data.fill_(b_init_value) 54 | self.__setattr__("fc{}".format(i), fc) 55 | self.fcs.append(fc) 56 | 57 | if self.layer_norm: 58 | ln = nn.LayerNorm(next_size) 59 | self.__setattr__("layer_norm{}".format(i), ln) 60 | self.layer_norms.append(ln) 61 | 62 | self.last_fc = nn.Linear(in_size, output_size) 63 | self.last_fc.weight.data.uniform_(-init_w, init_w) 64 | self.last_fc.bias.data.uniform_(-init_w, init_w) 65 | 66 | def forward(self, input, return_preactivations=False): 67 | h = input 68 | for i, fc in enumerate(self.fcs): 69 | h = fc(h) 70 | if self.layer_norm and i < len(self.fcs) - 1: 71 | h = self.layer_norms[i](h) 72 | h = self.hidden_activation(h) 73 | preactivation = self.last_fc(h) 74 | output = self.output_activation(preactivation) 75 | if return_preactivations: 76 | return output, preactivation 77 | else: 78 | return output 79 | 80 | 81 | class FlattenMlp(Mlp): 82 | """ 83 | Flatten inputs along dimension 1 and then pass through MLP. 84 | """ 85 | 86 | def forward(self, *inputs, **kwargs): 87 | flat_inputs = torch.cat(inputs, dim=1) 88 | return super().forward(flat_inputs, **kwargs) 89 | 90 | 91 | class MlpPolicy(Mlp, Policy): 92 | """ 93 | A simpler interface for creating policies. 
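# Usage sketch for FlattenMlp above as a Q-function: observation and action
# batches are concatenated along dim 1 before going through the MLP, which is
# how sac.py above builds its critics. Layer and batch sizes are illustrative.
import torch

from rlkit.torch.networks import FlattenMlp

obs_dim, action_dim = 11, 3
qf = FlattenMlp(
    input_size=obs_dim + action_dim,
    output_size=1,
    hidden_sizes=[256, 256],
)
obs = torch.randn(128, obs_dim)
actions = torch.rand(128, action_dim) * 2 - 1
q_values = qf(obs, actions)              # shape (128, 1)
assert q_values.shape == (128, 1)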
94 | """ 95 | 96 | def __init__(self, *args, obs_normalizer: TorchFixedNormalizer = None, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.obs_normalizer = obs_normalizer 99 | 100 | def forward(self, obs, **kwargs): 101 | if self.obs_normalizer: 102 | obs = self.obs_normalizer.normalize(obs) 103 | return super().forward(obs, **kwargs) 104 | 105 | def get_action(self, obs_np): 106 | actions = self.get_actions(obs_np[None]) 107 | return actions[0, :], {} 108 | 109 | def get_actions(self, obs): 110 | return eval_np(self, obs) 111 | 112 | 113 | class TanhMlpPolicy(MlpPolicy): 114 | """ 115 | A helper class since most policies have a tanh output activation. 116 | """ 117 | 118 | def __init__(self, *args, **kwargs): 119 | super().__init__(*args, output_activation=torch.tanh, **kwargs) 120 | -------------------------------------------------------------------------------- /rlkit/data_management/normalizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on code from Marcin Andrychowicz 3 | """ 4 | import numpy as np 5 | 6 | 7 | class Normalizer(object): 8 | def __init__( 9 | self, 10 | size, 11 | eps=1e-8, 12 | default_clip_range=np.inf, 13 | mean=0, 14 | std=1, 15 | ): 16 | self.size = size 17 | self.eps = eps 18 | self.default_clip_range = default_clip_range 19 | self.sum = np.zeros(self.size, np.float32) 20 | self.sumsq = np.zeros(self.size, np.float32) 21 | self.count = np.ones(1, np.float32) 22 | self.mean = mean + np.zeros(self.size, np.float32) 23 | self.std = std * np.ones(self.size, np.float32) 24 | self.synchronized = True 25 | 26 | def update(self, v): 27 | if v.ndim == 1: 28 | v = np.expand_dims(v, 0) 29 | assert v.ndim == 2 30 | assert v.shape[1] == self.size 31 | self.sum += v.sum(axis=0) 32 | self.sumsq += (np.square(v)).sum(axis=0) 33 | self.count[0] += v.shape[0] 34 | self.synchronized = False 35 | 36 | def normalize(self, v, clip_range=None): 37 | if not self.synchronized: 38 | self.synchronize() 39 | if clip_range is None: 40 | clip_range = self.default_clip_range 41 | mean, std = self.mean, self.std 42 | if v.ndim == 2: 43 | mean = mean.reshape(1, -1) 44 | std = std.reshape(1, -1) 45 | return np.clip((v - mean) / std, -clip_range, clip_range) 46 | 47 | def denormalize(self, v): 48 | if not self.synchronized: 49 | self.synchronize() 50 | mean, std = self.mean, self.std 51 | if v.ndim == 2: 52 | mean = mean.reshape(1, -1) 53 | std = std.reshape(1, -1) 54 | return mean + v * std 55 | 56 | def synchronize(self): 57 | self.mean[...] = self.sum / self.count[0] 58 | self.std[...] 
= np.sqrt( 59 | np.maximum( 60 | np.square(self.eps), 61 | self.sumsq / self.count[0] - np.square(self.mean) 62 | ) 63 | ) 64 | self.synchronized = True 65 | 66 | 67 | class IdentityNormalizer(object): 68 | def __init__(self, *args, **kwargs): 69 | pass 70 | 71 | def update(self, v): 72 | pass 73 | 74 | def normalize(self, v, clip_range=None): 75 | return v 76 | 77 | def denormalize(self, v): 78 | return v 79 | 80 | 81 | class FixedNormalizer(object): 82 | def __init__( 83 | self, 84 | size, 85 | default_clip_range=np.inf, 86 | mean=0, 87 | std=1, 88 | eps=1e-8, 89 | ): 90 | assert std > 0 91 | std = std + eps 92 | self.size = size 93 | self.default_clip_range = default_clip_range 94 | self.mean = mean + np.zeros(self.size, np.float32) 95 | self.std = std + np.zeros(self.size, np.float32) 96 | self.eps = eps 97 | 98 | def set_mean(self, mean): 99 | self.mean = mean + np.zeros(self.size, np.float32) 100 | 101 | def set_std(self, std): 102 | std = std + self.eps 103 | self.std = std + np.zeros(self.size, np.float32) 104 | 105 | def normalize(self, v, clip_range=None): 106 | if clip_range is None: 107 | clip_range = self.default_clip_range 108 | mean, std = self.mean, self.std 109 | if v.ndim == 2: 110 | mean = mean.reshape(1, -1) 111 | std = std.reshape(1, -1) 112 | return np.clip((v - mean) / std, -clip_range, clip_range) 113 | 114 | def denormalize(self, v): 115 | mean, std = self.mean, self.std 116 | if v.ndim == 2: 117 | mean = mean.reshape(1, -1) 118 | std = std.reshape(1, -1) 119 | return mean + v * std 120 | 121 | def copy_stats(self, other): 122 | self.set_mean(other.mean) 123 | self.set_std(other.std) 124 | -------------------------------------------------------------------------------- /rlkit/core/online_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from rlkit.core.rl_algorithm import BaseRLAlgorithm 5 | from rlkit.data_management.replay_buffer import ReplayBuffer 6 | from rlkit.samplers.data_collector import ( 7 | PathCollector, 8 | StepCollector, 9 | ) 10 | 11 | 12 | class OnlineRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 13 | def __init__( 14 | self, 15 | trainer, 16 | exploration_env, 17 | evaluation_env, 18 | exploration_data_collector: StepCollector, 19 | evaluation_data_collector: PathCollector, 20 | replay_buffer: ReplayBuffer, 21 | batch_size, 22 | max_path_length, 23 | num_epochs, 24 | num_eval_paths_per_epoch, 25 | num_expl_steps_per_train_loop, 26 | num_trains_per_train_loop, 27 | num_train_loops_per_epoch=1, 28 | min_num_steps_before_training=0, 29 | ): 30 | super().__init__( 31 | trainer, 32 | exploration_env, 33 | evaluation_env, 34 | exploration_data_collector, 35 | evaluation_data_collector, 36 | replay_buffer, 37 | ) 38 | self.batch_size = batch_size 39 | self.max_path_length = max_path_length 40 | self.num_epochs = num_epochs 41 | self.num_eval_paths_per_epoch = num_eval_paths_per_epoch 42 | self.num_trains_per_train_loop = num_trains_per_train_loop 43 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 44 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 45 | self.min_num_steps_before_training = min_num_steps_before_training 46 | 47 | assert self.num_trains_per_train_loop >= self.num_expl_steps_per_train_loop, \ 48 | 'Online training presumes num_trains_per_train_loop >= num_expl_steps_per_train_loop' 49 | 50 | def _train(self): 51 | self.training_mode(False) 52 | if self.min_num_steps_before_training > 0: 53 | 
self.expl_data_collector.collect_new_steps( 54 | self.max_path_length, 55 | self.min_num_steps_before_training, 56 | discard_incomplete_paths=False, 57 | ) 58 | init_expl_paths = self.expl_data_collector.get_epoch_paths() 59 | self.replay_buffer.add_paths(init_expl_paths) 60 | self.expl_data_collector.end_epoch(-1) 61 | 62 | gt.stamp('initial exploration', unique=True) 63 | 64 | num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop 65 | for epoch in gt.timed_for( 66 | range(self._start_epoch, self.num_epochs), 67 | save_itrs=True, 68 | ): 69 | self.eval_data_collector.collect_new_paths( 70 | self.max_path_length, 71 | self.num_eval_paths_per_epoch, 72 | ) 73 | gt.stamp('evaluation sampling') 74 | 75 | for _ in range(self.num_train_loops_per_epoch): 76 | for _ in range(self.num_expl_steps_per_train_loop): 77 | self.expl_data_collector.collect_new_steps( 78 | self.max_path_length, 79 | 1, # num steps 80 | discard_incomplete_paths=False, 81 | ) 82 | gt.stamp('exploration sampling', unique=False) 83 | 84 | self.training_mode(True) 85 | for _ in range(num_trains_per_expl_step): 86 | train_data = self.replay_buffer.random_batch( 87 | self.batch_size) 88 | self.trainer.train(train_data) 89 | gt.stamp('training', unique=False) 90 | self.training_mode(False) 91 | 92 | new_expl_paths = self.expl_data_collector.get_epoch_paths() 93 | self.replay_buffer.add_paths(new_expl_paths) 94 | gt.stamp('data storing', unique=False) 95 | 96 | self._end_epoch(epoch) 97 | -------------------------------------------------------------------------------- /rlkit/samplers/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rollout(env, agent, max_path_length=np.inf, render=False): 5 | """ 6 | The following value for the following keys will be a 2D array, with the 7 | first dimension corresponding to the time dimension. 
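# Worked example of the update-to-data ratio enforced in
# OnlineRLAlgorithm._train above: the loop takes one environment step at a
# time, then runs num_trains_per_train_loop // num_expl_steps_per_train_loop
# gradient updates, so the illustrative settings below give 4 updates per
# environment step.
num_expl_steps_per_train_loop = 1000
num_trains_per_train_loop = 4000
assert num_trains_per_train_loop >= num_expl_steps_per_train_loop
num_trains_per_expl_step = (
    num_trains_per_train_loop // num_expl_steps_per_train_loop
)
print(num_trains_per_expl_step)          # 4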
8 | - observations 9 | - actions 10 | - rewards 11 | - next_observations 12 | - terminals 13 | 14 | The next two elements will be lists of dictionaries, with the index into 15 | the list being the index into the time 16 | - agent_infos 17 | - env_infos 18 | 19 | :param env: 20 | :param agent: 21 | :param max_path_length: 22 | :param render: 23 | :return: 24 | """ 25 | observations = [] 26 | actions = [] 27 | rewards = [] 28 | terminals = [] 29 | agent_infos = [] 30 | env_infos = [] 31 | o = env.reset() 32 | next_o = None 33 | path_length = 0 34 | if render: 35 | env.render() 36 | while path_length < max_path_length: 37 | a, agent_info = agent.get_action(o) 38 | next_o, r, d, env_info = env.step(a) 39 | observations.append(o) 40 | rewards.append(r) 41 | terminals.append(d) 42 | actions.append(a) 43 | agent_infos.append(agent_info) 44 | env_infos.append(env_info) 45 | path_length += 1 46 | if d: 47 | break 48 | o = next_o 49 | if render: 50 | env.render() 51 | 52 | actions = np.array(actions) 53 | if len(actions.shape) == 1: 54 | actions = np.expand_dims(actions, 1) 55 | observations = np.array(observations) 56 | if len(observations.shape) == 1: 57 | observations = np.expand_dims(observations, 1) 58 | next_o = np.array([next_o]) 59 | next_observations = np.vstack( 60 | ( 61 | observations[1:, :], 62 | np.expand_dims(next_o, 0) 63 | ) 64 | ) 65 | return dict( 66 | observations=observations, 67 | actions=actions, 68 | rewards=np.array(rewards).reshape(-1, 1), 69 | next_observations=next_observations, 70 | terminals=np.array(terminals).reshape(-1, 1), 71 | agent_infos=agent_infos, 72 | env_infos=env_infos, 73 | ) 74 | 75 | 76 | def split_paths(paths): 77 | """ 78 | Stack multiples obs/actions/etc. from different paths 79 | :param paths: List of paths, where one path is something returned from 80 | the rollout functino above. 81 | :return: Tuple. Every element will have shape batch_size X DIM, including 82 | the rewards and terminal flags. 
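# Sketch of collecting one episode with rollout() above and reporting its
# undiscounted return. The Pendulum-v1 id and the UniformPolicy class are
# placeholders, and the example assumes the classic Gym interface this code
# targets (reset() -> obs, step() -> (obs, reward, done, info)).
import gym

from rlkit.samplers.util import rollout


class UniformPolicy:
    """Minimal agent exposing get_action(obs) -> (action, agent_info)."""

    def __init__(self, action_space):
        self.action_space = action_space

    def get_action(self, obs):
        return self.action_space.sample(), {}


env = gym.make("Pendulum-v1")
path = rollout(env, UniformPolicy(env.action_space), max_path_length=200)
print(path["observations"].shape)        # (T, obs_dim)
print(float(path["rewards"].sum()))      # undiscounted episode return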
83 | """ 84 | rewards = [path["rewards"].reshape(-1, 1) for path in paths] 85 | terminals = [path["terminals"].reshape(-1, 1) for path in paths] 86 | actions = [path["actions"] for path in paths] 87 | obs = [path["observations"] for path in paths] 88 | next_obs = [path["next_observations"] for path in paths] 89 | rewards = np.vstack(rewards) 90 | terminals = np.vstack(terminals) 91 | obs = np.vstack(obs) 92 | actions = np.vstack(actions) 93 | next_obs = np.vstack(next_obs) 94 | assert len(rewards.shape) == 2 95 | assert len(terminals.shape) == 2 96 | assert len(obs.shape) == 2 97 | assert len(actions.shape) == 2 98 | assert len(next_obs.shape) == 2 99 | return rewards, terminals, obs, actions, next_obs 100 | 101 | 102 | def split_paths_to_dict(paths): 103 | rewards, terminals, obs, actions, next_obs = split_paths(paths) 104 | return dict( 105 | rewards=rewards, 106 | terminals=terminals, 107 | observations=obs, 108 | actions=actions, 109 | next_observations=next_obs, 110 | ) 111 | 112 | 113 | def get_stat_in_paths(paths, dict_name, scalar_name): 114 | if len(paths) == 0: 115 | return np.array([[]]) 116 | 117 | if type(paths[0][dict_name]) == dict: 118 | # Support rllab interface 119 | return [path[dict_name][scalar_name] for path in paths] 120 | 121 | return [ 122 | [info[scalar_name] for info in path[dict_name]] 123 | for path in paths 124 | ] -------------------------------------------------------------------------------- /rlkit/util/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | 5 | import numpy as np 6 | import scipy.misc 7 | import skvideo.io 8 | 9 | from rlkit.envs.vae_wrapper import VAEWrappedEnv 10 | 11 | 12 | def dump_video( 13 | env, 14 | policy, 15 | filename, 16 | rollout_function, 17 | rows=3, 18 | columns=6, 19 | pad_length=0, 20 | pad_color=255, 21 | do_timer=True, 22 | horizon=100, 23 | dirname_to_save_images=None, 24 | subdirname="rollouts", 25 | imsize=84, 26 | num_channels=3, 27 | ): 28 | frames = [] 29 | H = 3 * imsize 30 | W = imsize 31 | N = rows * columns 32 | for i in range(N): 33 | start = time.time() 34 | path = rollout_function( 35 | env, 36 | policy, 37 | max_path_length=horizon, 38 | render=False, 39 | ) 40 | is_vae_env = isinstance(env, VAEWrappedEnv) 41 | l = [] 42 | for d in path['full_observations']: 43 | if is_vae_env: 44 | recon = np.clip(env._reconstruct_img(d['image_observation']), 0, 45 | 1) 46 | else: 47 | recon = d['image_observation'] 48 | l.append( 49 | get_image( 50 | d['image_desired_goal'], 51 | d['image_observation'], 52 | recon, 53 | pad_length=pad_length, 54 | pad_color=pad_color, 55 | imsize=imsize, 56 | ) 57 | ) 58 | frames += l 59 | 60 | if dirname_to_save_images: 61 | rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i)) 62 | os.makedirs(rollout_dir, exist_ok=True) 63 | rollout_frames = frames[-101:] 64 | goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0) 65 | scipy.misc.imsave(rollout_dir + "/goal.png", goal_img) 66 | goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0) 67 | scipy.misc.imsave(rollout_dir + "/z_goal.png", goal_img) 68 | for j in range(0, 101, 1): 69 | img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0) 70 | scipy.misc.imsave(rollout_dir + "/" + str(j) + ".png", img) 71 | if do_timer: 72 | print(i, time.time() - start) 73 | 74 | frames = np.array(frames, dtype=np.uint8) 75 | path_length = frames.size // ( 76 | N * (H + 2 * pad_length) * (W + 2 * pad_length) * num_channels 77 | ) 78 | 
frames = np.array(frames, dtype=np.uint8).reshape( 79 | (N, path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels) 80 | ) 81 | f1 = [] 82 | for k1 in range(columns): 83 | f2 = [] 84 | for k2 in range(rows): 85 | k = k1 * rows + k2 86 | f2.append(frames[k:k + 1, :, :, :, :].reshape( 87 | (path_length, H + 2 * pad_length, W + 2 * pad_length, 88 | num_channels) 89 | )) 90 | f1.append(np.concatenate(f2, axis=1)) 91 | outputdata = np.concatenate(f1, axis=2) 92 | skvideo.io.vwrite(filename, outputdata) 93 | print("Saved video to ", filename) 94 | 95 | 96 | def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255): 97 | if len(goal.shape) == 1: 98 | goal = goal.reshape(-1, imsize, imsize).transpose() 99 | obs = obs.reshape(-1, imsize, imsize).transpose() 100 | recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose() 101 | img = np.concatenate((goal, obs, recon_obs)) 102 | img = np.uint8(255 * img) 103 | if pad_length > 0: 104 | img = add_border(img, pad_length, pad_color) 105 | return img 106 | 107 | 108 | def add_border(img, pad_length, pad_color, imsize=84): 109 | H = 3 * imsize 110 | W = imsize 111 | img = img.reshape((3 * imsize, imsize, -1)) 112 | img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]), 113 | dtype=np.uint8) * pad_color 114 | img2[pad_length:-pad_length, pad_length:-pad_length, :] = img 115 | return img2 116 | -------------------------------------------------------------------------------- /td3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | import rlkit.torch.pytorch_util as ptu 6 | import yaml 7 | from rlkit.data_management.torch_replay_buffer import TorchReplayBuffer 8 | from rlkit.envs import make_env 9 | from rlkit.envs.vecenv import SubprocVectorEnv, VectorEnv 10 | from rlkit.exploration_strategies.base import \ 11 | PolicyWrappedWithExplorationStrategy 12 | from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy 13 | from rlkit.launchers.launcher_util import set_seed, setup_logger 14 | from rlkit.samplers.data_collector import (VecMdpPathCollector, VecMdpStepCollector) 15 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 16 | from rlkit.torch.td3.td3 import TD3Trainer 17 | from rlkit.torch.torch_rl_algorithm import TorchVecOnlineRLAlgorithm 18 | 19 | torch.set_num_threads(4) 20 | torch.set_num_interop_threads(4) 21 | 22 | 23 | def experiment(variant): 24 | dummy_env = make_env(variant['env']) 25 | obs_dim = dummy_env.observation_space.low.size 26 | action_dim = dummy_env.action_space.low.size 27 | expl_env = VectorEnv([lambda: make_env(variant['env']) for _ in range(variant['expl_env_num'])]) 28 | expl_env.seed(variant["seed"]) 29 | expl_env.action_space.seed(variant["seed"]) 30 | eval_env = SubprocVectorEnv([lambda: make_env(variant['env']) for _ in range(variant['eval_env_num'])]) 31 | eval_env.seed(variant["seed"]) 32 | 33 | M = variant['layer_size'] 34 | qf1 = FlattenMlp( 35 | input_size=obs_dim + action_dim, 36 | output_size=1, 37 | hidden_sizes=[M, M], 38 | ) 39 | qf2 = FlattenMlp( 40 | input_size=obs_dim + action_dim, 41 | output_size=1, 42 | hidden_sizes=[M, M], 43 | ) 44 | target_qf1 = FlattenMlp( 45 | input_size=obs_dim + action_dim, 46 | output_size=1, 47 | hidden_sizes=[M, M], 48 | ) 49 | target_qf2 = FlattenMlp( 50 | input_size=obs_dim + action_dim, 51 | output_size=1, 52 | hidden_sizes=[M, M], 53 | ) 54 | policy = TanhMlpPolicy( 55 | input_size=obs_dim, 56 | output_size=action_dim, 57 | 
hidden_sizes=[M, M], 58 | ) 59 | target_policy = TanhMlpPolicy( 60 | input_size=obs_dim, 61 | output_size=action_dim, 62 | hidden_sizes=[M, M], 63 | ) 64 | es = GaussianStrategy( 65 | action_space=dummy_env.action_space, 66 | max_sigma=0.1, 67 | min_sigma=0.1, # Constant sigma 68 | ) 69 | exploration_policy = PolicyWrappedWithExplorationStrategy( 70 | exploration_strategy=es, 71 | policy=policy, 72 | ) 73 | eval_path_collector = VecMdpPathCollector( 74 | eval_env, 75 | policy, 76 | ) 77 | expl_path_collector = VecMdpStepCollector( 78 | expl_env, 79 | exploration_policy, 80 | ) 81 | replay_buffer = TorchReplayBuffer( 82 | variant['replay_buffer_size'], 83 | dummy_env, 84 | ) 85 | trainer = TD3Trainer( 86 | policy=policy, 87 | target_policy=target_policy, 88 | qf1=qf1, 89 | qf2=qf2, 90 | target_qf1=target_qf1, 91 | target_qf2=target_qf2, 92 | **variant['trainer_kwargs'], 93 | ) 94 | algorithm = TorchVecOnlineRLAlgorithm( 95 | trainer=trainer, 96 | exploration_env=expl_env, 97 | evaluation_env=eval_env, 98 | exploration_data_collector=expl_path_collector, 99 | evaluation_data_collector=eval_path_collector, 100 | replay_buffer=replay_buffer, 101 | **variant['algorithm_kwargs'], 102 | ) 103 | algorithm.to(ptu.device) 104 | algorithm.train() 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser(description='TD3') 109 | parser.add_argument('--config', type=str, default="configs/lunarlander.yaml") 110 | parser.add_argument('--gpu', type=int, default=0, help="using cpu with -1") 111 | parser.add_argument('--seed', type=int, default=0) 112 | args = parser.parse_args() 113 | with open(args.config, 'r', encoding="utf-8") as f: 114 | variant = yaml.load(f, Loader=yaml.FullLoader) 115 | variant["seed"] = args.seed 116 | log_prefix = "_".join(["td3", variant["env"][:-3].lower(), str(variant["version"])]) 117 | setup_logger(log_prefix, variant=variant, seed=args.seed) 118 | if args.gpu >= 0: 119 | ptu.set_gpu_mode(True, args.gpu) 120 | set_seed(args.seed) 121 | experiment(variant) 122 | -------------------------------------------------------------------------------- /rlkit/torch/sac/policies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | 5 | from rlkit.policies.base import ExplorationPolicy, Policy 6 | from rlkit.torch.core import eval_np 7 | from rlkit.torch.distributions import TanhNormal 8 | from rlkit.torch.networks import Mlp 9 | 10 | LOG_SIG_MAX = 2 11 | LOG_SIG_MIN = -20 12 | 13 | 14 | class TanhGaussianPolicy(Mlp, ExplorationPolicy): 15 | """ 16 | Usage: 17 | 18 | ``` 19 | policy = TanhGaussianPolicy(...) 20 | action, mean, log_std, _ = policy(obs) 21 | action, mean, log_std, _ = policy(obs, deterministic=True) 22 | action, mean, log_std, log_prob = policy(obs, return_log_prob=True) 23 | ``` 24 | 25 | Here, mean and log_std are the mean and log_std of the Gaussian that is 26 | sampled from. 27 | 28 | If deterministic is True, action = tanh(mean). 29 | If return_log_prob is False (default), log_prob = None 30 | This is done because computing the log_prob can be a bit expensive. 
31 | """ 32 | 33 | def __init__(self, hidden_sizes, obs_dim, action_dim, std=None, init_w=1e-3, **kwargs): 34 | super().__init__(hidden_sizes, input_size=obs_dim, output_size=action_dim, init_w=init_w, **kwargs) 35 | self.log_std = None 36 | self.std = std 37 | if std is None: 38 | last_hidden_size = obs_dim 39 | if len(hidden_sizes) > 0: 40 | last_hidden_size = hidden_sizes[-1] 41 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim) 42 | self.last_fc_log_std.weight.data.uniform_(-init_w, init_w) 43 | self.last_fc_log_std.bias.data.uniform_(-init_w, init_w) 44 | else: 45 | self.log_std = np.log(std) 46 | assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX 47 | 48 | def get_action(self, obs_np, deterministic=False): 49 | actions = self.get_actions(obs_np[None], deterministic=deterministic) 50 | return actions[0, :], {} 51 | 52 | def get_actions(self, obs_np, deterministic=False): 53 | return eval_np(self, obs_np, deterministic=deterministic)[0] 54 | 55 | def forward( 56 | self, 57 | obs, 58 | reparameterize=True, 59 | deterministic=False, 60 | return_log_prob=False, 61 | ): 62 | """ 63 | :param obs: Observation 64 | :param deterministic: If True, do not sample 65 | :param return_log_prob: If True, return a sample and its log probability 66 | """ 67 | h = obs 68 | for i, fc in enumerate(self.fcs): 69 | h = self.hidden_activation(fc(h)) 70 | mean = self.last_fc(h) 71 | if self.std is None: 72 | log_std = self.last_fc_log_std(h) 73 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 74 | std = torch.exp(log_std) 75 | else: 76 | std = self.std 77 | log_std = self.log_std 78 | 79 | log_prob = None 80 | entropy = None 81 | mean_action_log_prob = None 82 | pre_tanh_value = None 83 | if deterministic: 84 | action = torch.tanh(mean) 85 | else: 86 | tanh_normal = TanhNormal(mean, std) 87 | if return_log_prob: 88 | if reparameterize is True: 89 | action, pre_tanh_value = tanh_normal.rsample(return_pretanh_value=True) 90 | else: 91 | action, pre_tanh_value = tanh_normal.sample(return_pretanh_value=True) 92 | log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value) 93 | log_prob = log_prob.sum(dim=1, keepdim=True) 94 | else: 95 | if reparameterize is True: 96 | action = tanh_normal.rsample() 97 | else: 98 | action = tanh_normal.sample() 99 | 100 | return ( 101 | action, 102 | mean, 103 | log_std, 104 | log_prob, 105 | entropy, 106 | std, 107 | mean_action_log_prob, 108 | pre_tanh_value, 109 | ) 110 | 111 | 112 | class MakeDeterministic(nn.Module, Policy): 113 | 114 | def __init__(self, stochastic_policy): 115 | super().__init__() 116 | self.stochastic_policy = stochastic_policy 117 | 118 | def get_action(self, observation): 119 | return self.stochastic_policy.get_action(observation, deterministic=True) 120 | 121 | def get_actions(self, observation): 122 | return self.stochastic_policy.get_actions(observation, deterministic=True) 123 | -------------------------------------------------------------------------------- /rlkit/torch/dsac/policies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | 5 | from rlkit.policies.base import ExplorationPolicy, Policy 6 | from rlkit.torch.core import eval_np 7 | from rlkit.torch.distributions import TanhNormal 8 | from rlkit.torch.networks import Mlp 9 | 10 | LOG_SIG_MAX = 2 11 | LOG_SIG_MIN = -20 12 | 13 | 14 | class TanhGaussianPolicy(Mlp, ExplorationPolicy): 15 | """ 16 | Usage: 17 | 18 | ``` 19 | policy = TanhGaussianPolicy(...) 
20 | action, mean, log_std, _ = policy(obs) 21 | action, mean, log_std, _ = policy(obs, deterministic=True) 22 | action, mean, log_std, log_prob = policy(obs, return_log_prob=True) 23 | ``` 24 | 25 | Here, mean and log_std are the mean and log_std of the Gaussian that is 26 | sampled from. 27 | 28 | If deterministic is True, action = tanh(mean). 29 | If return_log_prob is False (default), log_prob = None 30 | This is done because computing the log_prob can be a bit expensive. 31 | """ 32 | 33 | def __init__(self, hidden_sizes, obs_dim, action_dim, std=None, init_w=1e-3, **kwargs): 34 | super().__init__(hidden_sizes, input_size=obs_dim, output_size=action_dim, init_w=init_w, **kwargs) 35 | self.log_std = None 36 | self.std = std 37 | if std is None: 38 | last_hidden_size = obs_dim 39 | if len(hidden_sizes) > 0: 40 | last_hidden_size = hidden_sizes[-1] 41 | self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim) 42 | self.last_fc_log_std.weight.data.uniform_(-init_w, init_w) 43 | self.last_fc_log_std.bias.data.uniform_(-init_w, init_w) 44 | else: 45 | self.log_std = np.log(std) 46 | assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX 47 | 48 | @torch.no_grad() 49 | def get_action(self, obs_np, deterministic=False): 50 | actions = self.get_actions(obs_np[None], deterministic=deterministic) 51 | return actions[0, :], {} 52 | 53 | @torch.no_grad() 54 | def get_actions(self, obs_np, deterministic=False): 55 | return eval_np(self, obs_np, deterministic=deterministic)[0] 56 | 57 | def forward( 58 | self, 59 | obs, 60 | reparameterize=True, 61 | deterministic=False, 62 | return_log_prob=False, 63 | ): 64 | """ 65 | :param obs: Observation 66 | :param deterministic: If True, do not sample 67 | :param return_log_prob: If True, return a sample and its log probability 68 | """ 69 | h = obs 70 | for i, fc in enumerate(self.fcs): 71 | h = self.hidden_activation(fc(h)) 72 | mean = self.last_fc(h) 73 | if self.std is None: 74 | log_std = self.last_fc_log_std(h) 75 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 76 | std = torch.exp(log_std) 77 | else: 78 | std = self.std 79 | log_std = self.log_std 80 | 81 | log_prob = None 82 | entropy = None 83 | mean_action_log_prob = None 84 | pre_tanh_value = None 85 | if deterministic: 86 | action = torch.tanh(mean) 87 | else: 88 | tanh_normal = TanhNormal(mean, std) 89 | if return_log_prob: 90 | if reparameterize is True: 91 | action, pre_tanh_value = tanh_normal.rsample(return_pretanh_value=True) 92 | else: 93 | action, pre_tanh_value = tanh_normal.sample(return_pretanh_value=True) 94 | log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value) 95 | log_prob = log_prob.sum(dim=1, keepdim=True) 96 | else: 97 | if reparameterize is True: 98 | action = tanh_normal.rsample() 99 | else: 100 | action = tanh_normal.sample() 101 | 102 | return ( 103 | action, 104 | mean, 105 | log_std, 106 | log_prob, 107 | entropy, 108 | std, 109 | mean_action_log_prob, 110 | pre_tanh_value, 111 | ) 112 | 113 | 114 | class MakeDeterministic(Policy): 115 | 116 | def __init__(self, stochastic_policy): 117 | self.stochastic_policy = stochastic_policy 118 | 119 | def get_action(self, observation): 120 | return self.stochastic_policy.get_action(observation, deterministic=True) 121 | 122 | def get_actions(self, observations): 123 | return self.stochastic_policy.get_actions(observations, deterministic=True) 124 | -------------------------------------------------------------------------------- /rlkit/samplers/rollout_functions.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def multitask_rollout( 5 | env, 6 | agent, 7 | max_path_length=np.inf, 8 | render=False, 9 | render_kwargs=None, 10 | observation_key=None, 11 | desired_goal_key=None, 12 | get_action_kwargs=None, 13 | return_dict_obs=False, 14 | ): 15 | if render_kwargs is None: 16 | render_kwargs = {} 17 | if get_action_kwargs is None: 18 | get_action_kwargs = {} 19 | dict_obs = [] 20 | dict_next_obs = [] 21 | observations = [] 22 | actions = [] 23 | rewards = [] 24 | terminals = [] 25 | agent_infos = [] 26 | env_infos = [] 27 | next_observations = [] 28 | path_length = 0 29 | agent.reset() 30 | o = env.reset() 31 | if render: 32 | env.render(**render_kwargs) 33 | goal = o[desired_goal_key] 34 | while path_length < max_path_length: 35 | dict_obs.append(o) 36 | if observation_key: 37 | o = o[observation_key] 38 | new_obs = np.hstack((o, goal)) 39 | a, agent_info = agent.get_action(new_obs, **get_action_kwargs) 40 | next_o, r, d, env_info = env.step(a) 41 | if render: 42 | env.render(**render_kwargs) 43 | observations.append(o) 44 | rewards.append(r) 45 | terminals.append(d) 46 | actions.append(a) 47 | next_observations.append(next_o) 48 | dict_next_obs.append(next_o) 49 | agent_infos.append(agent_info) 50 | env_infos.append(env_info) 51 | path_length += 1 52 | if d: 53 | break 54 | o = next_o 55 | actions = np.array(actions) 56 | if len(actions.shape) == 1: 57 | actions = np.expand_dims(actions, 1) 58 | observations = np.array(observations) 59 | next_observations = np.array(next_observations) 60 | if return_dict_obs: 61 | observations = dict_obs 62 | next_observations = dict_next_obs 63 | return dict( 64 | observations=observations, 65 | actions=actions, 66 | rewards=np.array(rewards).reshape(-1, 1), 67 | next_observations=next_observations, 68 | terminals=np.array(terminals).reshape(-1, 1), 69 | agent_infos=agent_infos, 70 | env_infos=env_infos, 71 | goals=np.repeat(goal[None], path_length, 0), 72 | full_observations=dict_obs, 73 | ) 74 | 75 | 76 | def rollout( 77 | env, 78 | agent, 79 | max_path_length=np.inf, 80 | render=False, 81 | render_kwargs=None, 82 | ): 83 | """ 84 | The following value for the following keys will be a 2D array, with the 85 | first dimension corresponding to the time dimension. 
86 | - observations 87 | - actions 88 | - rewards 89 | - next_observations 90 | - terminals 91 | 92 | The next two elements will be lists of dictionaries, with the index into 93 | the list being the index into the time 94 | - agent_infos 95 | - env_infos 96 | """ 97 | if render_kwargs is None: 98 | render_kwargs = {} 99 | observations = [] 100 | actions = [] 101 | rewards = [] 102 | terminals = [] 103 | agent_infos = [] 104 | env_infos = [] 105 | o = env.reset() 106 | agent.reset() 107 | next_o = None 108 | path_length = 0 109 | if render: 110 | env.render(**render_kwargs) 111 | while path_length < max_path_length: 112 | a, agent_info = agent.get_action(o) 113 | next_o, r, d, env_info = env.step(a) 114 | observations.append(o) 115 | rewards.append(r) 116 | terminals.append(d) 117 | actions.append(a) 118 | agent_infos.append(agent_info) 119 | env_infos.append(env_info) 120 | path_length += 1 121 | if d: 122 | break 123 | o = next_o 124 | if render: 125 | env.render(**render_kwargs) 126 | 127 | actions = np.array(actions) 128 | if len(actions.shape) == 1: 129 | actions = np.expand_dims(actions, 1) 130 | observations = np.array(observations) 131 | if len(observations.shape) == 1: 132 | observations = np.expand_dims(observations, 1) 133 | next_o = np.array([next_o]) 134 | next_observations = np.vstack((observations[1:, :], np.expand_dims(next_o, 0))) 135 | return dict( 136 | observations=observations, 137 | actions=actions, 138 | rewards=np.array(rewards).reshape(-1, 1), 139 | next_observations=next_observations, 140 | terminals=np.array(terminals).reshape(-1, 1), 141 | agent_infos=agent_infos, 142 | env_infos=env_infos, 143 | ) 144 | -------------------------------------------------------------------------------- /rlkit/launchers/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy this file to config.py and modify as needed. 3 | """ 4 | import os 5 | from os.path import join 6 | import rlkit 7 | 8 | """ 9 | `doodad.mount.MountLocal` by default ignores directories called "data" 10 | If you're going to rename this directory and use EC2, then change 11 | `doodad.mount.MountLocal.filter_dir` 12 | """ 13 | # The directory of the project, not source 14 | rlkit_project_dir = join(os.path.dirname(rlkit.__file__), os.pardir) 15 | LOCAL_LOG_DIR = join(rlkit_project_dir, 'data') 16 | 17 | """ 18 | ******************************************************************************** 19 | ******************************************************************************** 20 | ******************************************************************************** 21 | 22 | You probably don't need to set all of the configurations below this line, 23 | unless you use AWS, GCP, Slurm, and/or Slurm on a remote server. I recommend 24 | ignoring most of these things and only using them on an as-needed basis. 25 | 26 | ******************************************************************************** 27 | ******************************************************************************** 28 | ******************************************************************************** 29 | """ 30 | 31 | """ 32 | General doodad settings. 
33 | """ 34 | CODE_DIRS_TO_MOUNT = [ 35 | rlkit_project_dir, 36 | # '/home/user/python/module/one', Add more paths as needed 37 | ] 38 | 39 | HOME = os.getenv('HOME') if os.getenv('HOME') is not None else os.getenv("USERPROFILE") 40 | 41 | DIR_AND_MOUNT_POINT_MAPPINGS = [ 42 | dict( 43 | local_dir=join(HOME, '.mujoco/'), 44 | mount_point='/root/.mujoco', 45 | ), 46 | ] 47 | RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 48 | join(rlkit_project_dir, 'scripts', 'run_experiment_from_doodad.py') 49 | # '/home/user/path/to/rlkit/scripts/run_experiment_from_doodad.py' 50 | ) 51 | """ 52 | AWS Settings 53 | """ 54 | # If not set, default will be chosen by doodad 55 | # AWS_S3_PATH = 's3://bucket/directory 56 | 57 | # The docker image is looked up on dockerhub.com. 58 | DOODAD_DOCKER_IMAGE = "TODO" 59 | INSTANCE_TYPE = 'c4.large' 60 | SPOT_PRICE = 0.03 61 | 62 | GPU_DOODAD_DOCKER_IMAGE = 'TODO' 63 | GPU_INSTANCE_TYPE = 'g2.2xlarge' 64 | GPU_SPOT_PRICE = 0.5 65 | 66 | # You can use AMI images with the docker images already installed. 67 | REGION_TO_GPU_AWS_IMAGE_ID = { 68 | 'us-west-1': "TODO", 69 | 'us-east-1': "TODO", 70 | } 71 | 72 | REGION_TO_GPU_AWS_AVAIL_ZONE = { 73 | 'us-east-1': "us-east-1b", 74 | } 75 | 76 | # This really shouldn't matter and in theory could be whatever 77 | OUTPUT_DIR_FOR_DOODAD_TARGET = '/tmp/doodad-output/' 78 | 79 | 80 | """ 81 | Slurm Settings 82 | """ 83 | SINGULARITY_IMAGE = '/home/PATH/TO/IMAGE.img' 84 | # This assumes you saved mujoco to $HOME/.mujoco 85 | SINGULARITY_PRE_CMDS = [ 86 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mjpro150/bin' 87 | ] 88 | SLURM_CPU_CONFIG = dict( 89 | account_name='TODO', 90 | partition='savio', 91 | nodes=1, 92 | n_tasks=1, 93 | n_gpus=1, 94 | ) 95 | SLURM_GPU_CONFIG = dict( 96 | account_name='TODO', 97 | partition='savio2_1080ti', 98 | nodes=1, 99 | n_tasks=1, 100 | n_gpus=1, 101 | ) 102 | 103 | 104 | """ 105 | Slurm Script Settings 106 | 107 | These are basically the same settings as above, but for the remote machine 108 | where you will be running the generated script. 
109 | """ 110 | SSS_CODE_DIRS_TO_MOUNT = [ 111 | ] 112 | SSS_DIR_AND_MOUNT_POINT_MAPPINGS = [ 113 | dict( 114 | local_dir='/global/home/users/USERNAME/.mujoco', 115 | mount_point='/root/.mujoco', 116 | ), 117 | ] 118 | SSS_LOG_DIR = '/global/scratch/USERNAME/doodad-log' 119 | 120 | SSS_IMAGE = '/global/scratch/USERNAME/TODO.img' 121 | SSS_RUN_DOODAD_EXPERIMENT_SCRIPT_PATH = ( 122 | '/global/home/users/USERNAME/path/to/rlkit/scripts' 123 | '/run_experiment_from_doodad.py' 124 | ) 125 | SSS_PRE_CMDS = [ 126 | 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/global/home/users/USERNAME' 127 | '/.mujoco/mjpro150/bin' 128 | ] 129 | 130 | """ 131 | GCP Settings 132 | """ 133 | GCP_IMAGE_NAME = 'TODO' 134 | GCP_GPU_IMAGE_NAME = 'TODO' 135 | GCP_BUCKET_NAME = 'TODO' 136 | 137 | GCP_DEFAULT_KWARGS = dict( 138 | zone='us-west2-c', 139 | instance_type='n1-standard-4', 140 | image_project='TODO', 141 | terminate=True, 142 | preemptible=True, 143 | gpu_kwargs=dict( 144 | gpu_model='nvidia-tesla-p4', 145 | num_gpu=1, 146 | ) 147 | ) 148 | 149 | try: 150 | from rlkit.launchers.conf_private import * 151 | except ImportError: 152 | print("No personal conf_private.py found.") 153 | -------------------------------------------------------------------------------- /rlkit/torch/pytorch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def soft_update_from_to(source, target, tau): 6 | for target_param, param in zip(target.parameters(), source.parameters()): 7 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 8 | 9 | 10 | def copy_model_params_from_to(source, target): 11 | for target_param, param in zip(target.parameters(), source.parameters()): 12 | target_param.data.copy_(param.data) 13 | 14 | 15 | def fanin_init(tensor): 16 | size = tensor.size() 17 | if len(size) == 2: 18 | fan_in = size[0] 19 | elif len(size) > 2: 20 | fan_in = np.prod(size[1:]) 21 | else: 22 | raise Exception("Shape must be have dimension at least 2.") 23 | bound = 1. / np.sqrt(fan_in) 24 | return tensor.data.uniform_(-bound, bound) 25 | 26 | 27 | def fanin_init_weights_like(tensor): 28 | size = tensor.size() 29 | if len(size) == 2: 30 | fan_in = size[0] 31 | elif len(size) > 2: 32 | fan_in = np.prod(size[1:]) 33 | else: 34 | raise Exception("Shape must be have dimension at least 2.") 35 | bound = 1. 
/ np.sqrt(fan_in) 36 | new_tensor = FloatTensor(tensor.size()) 37 | new_tensor.uniform_(-bound, bound) 38 | return new_tensor 39 | 40 | 41 | """ 42 | GPU wrappers 43 | """ 44 | 45 | _use_gpu = False 46 | device = None 47 | _gpu_id = 0 48 | 49 | 50 | def set_gpu_mode(mode, gpu_id=0): 51 | global _use_gpu 52 | global device 53 | global _gpu_id 54 | _gpu_id = gpu_id 55 | _use_gpu = mode 56 | device = torch.device("cuda:" + str(gpu_id) if _use_gpu else "cpu") 57 | torch.cuda.set_device(device) 58 | 59 | 60 | def gpu_enabled(): 61 | return _use_gpu 62 | 63 | 64 | def set_device(gpu_id): 65 | torch.cuda.set_device(gpu_id) 66 | 67 | 68 | # noinspection PyPep8Naming 69 | def FloatTensor(*args, torch_device=None, **kwargs): 70 | if torch_device is None: 71 | torch_device = device 72 | return torch.FloatTensor(*args, **kwargs, device=torch_device) 73 | 74 | 75 | def from_numpy(*args, **kwargs): 76 | return torch.from_numpy(*args, **kwargs).float().to(device) 77 | 78 | 79 | def get_numpy(tensor): 80 | return tensor.to('cpu').detach().numpy() 81 | 82 | 83 | def zeros(*sizes, torch_device=None, **kwargs): 84 | if torch_device is None: 85 | torch_device = device 86 | return torch.zeros(*sizes, **kwargs, device=torch_device) 87 | 88 | 89 | def ones(*sizes, torch_device=None, **kwargs): 90 | if torch_device is None: 91 | torch_device = device 92 | return torch.ones(*sizes, **kwargs, device=torch_device) 93 | 94 | 95 | def ones_like(*args, torch_device=None, **kwargs): 96 | if torch_device is None: 97 | torch_device = device 98 | return torch.ones_like(*args, **kwargs, device=torch_device) 99 | 100 | 101 | def randn(*args, torch_device=None, **kwargs): 102 | if torch_device is None: 103 | torch_device = device 104 | return torch.randn(*args, **kwargs, device=torch_device) 105 | 106 | 107 | def rand(*args, torch_device=None, **kwargs): 108 | if torch_device is None: 109 | torch_device = device 110 | return torch.rand(*args, **kwargs, device=torch_device) 111 | 112 | 113 | def zeros_like(*args, torch_device=None, **kwargs): 114 | if torch_device is None: 115 | torch_device = device 116 | return torch.zeros_like(*args, **kwargs, device=torch_device) 117 | 118 | 119 | def tensor(*args, torch_device=None, **kwargs): 120 | if torch_device is None: 121 | torch_device = device 122 | return torch.tensor(*args, **kwargs, device=torch_device) 123 | 124 | 125 | def normal(*args, **kwargs): 126 | return torch.normal(*args, **kwargs).to(device) 127 | 128 | 129 | def fast_clip_grad_norm(parameters, max_norm): 130 | r"""Clips gradient norm of an iterable of parameters. 131 | Only support norm_type = 2 132 | max_norm = 0, skip the total norm calculation and return 0 133 | https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ 134 | Returns: 135 | Total norm of the parameters (viewed as a single vector). 
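Example (illustrative sketch; `policy` stands in for any torch module):
    total_norm = fast_clip_grad_norm(policy.parameters(), max_norm=1.0)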
136 | """ 137 | max_norm = float(max_norm) 138 | if abs(max_norm) < 1e-6: # max_norm = 0 139 | return 0 140 | else: 141 | if isinstance(parameters, torch.Tensor): 142 | parameters = [parameters] 143 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 144 | total_norm = torch.stack([(p.grad.detach().pow(2)).sum() for p in parameters]).sum().sqrt().item() 145 | clip_coef = max_norm / (total_norm + 1e-6) 146 | if clip_coef < 1: 147 | for p in parameters: 148 | p.grad.detach().mul_(clip_coef) 149 | return total_norm 150 | -------------------------------------------------------------------------------- /rlkit/core/eval_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common evaluation utilities. 3 | """ 4 | 5 | from collections import OrderedDict 6 | from numbers import Number 7 | 8 | import numpy as np 9 | 10 | import rlkit.pythonplusplus as ppp 11 | 12 | 13 | def get_generic_path_information(paths, stat_prefix=''): 14 | """ 15 | Get an OrderedDict with a bunch of statistic names and values. 16 | """ 17 | statistics = OrderedDict() 18 | returns = [sum(path["rewards"]) for path in paths] 19 | 20 | rewards = np.vstack([path["rewards"] for path in paths]) 21 | statistics.update(create_stats_ordered_dict('Rewards', rewards, 22 | stat_prefix=stat_prefix)) 23 | statistics.update(create_stats_ordered_dict('Returns', returns, 24 | stat_prefix=stat_prefix)) 25 | actions = [path["actions"] for path in paths] 26 | if len(actions[0].shape) == 1: 27 | actions = np.hstack([path["actions"] for path in paths]) 28 | else: 29 | actions = np.vstack([path["actions"] for path in paths]) 30 | statistics.update(create_stats_ordered_dict( 31 | 'Actions', actions, stat_prefix=stat_prefix 32 | )) 33 | statistics['Num Paths'] = len(paths) 34 | statistics[stat_prefix + 'Average Returns'] = get_average_returns(paths) 35 | 36 | # rlkit path info 37 | for info_key in ['env_infos', 'agent_infos']: 38 | if info_key in paths[0]: 39 | all_env_infos = [ 40 | ppp.list_of_dicts__to__dict_of_lists(p[info_key]) 41 | for p in paths 42 | ] 43 | for k in all_env_infos[0].keys(): 44 | # final_ks = np.array([info[k][-1] for info in all_env_infos]) 45 | # first_ks = np.array([info[k][0] for info in all_env_infos]) 46 | # all_ks = np.concatenate([info[k] for info in all_env_infos]) 47 | 48 | # statistics.update(create_stats_ordered_dict( 49 | # stat_prefix + k, 50 | # final_ks, 51 | # stat_prefix='{}/final/'.format(info_key), 52 | # )) 53 | # statistics.update(create_stats_ordered_dict( 54 | # stat_prefix + k, 55 | # first_ks, 56 | # stat_prefix='{}/initial/'.format(info_key), 57 | # )) 58 | # statistics.update(create_stats_ordered_dict( 59 | # stat_prefix + k, 60 | # all_ks, 61 | # stat_prefix='{}/'.format(info_key), 62 | # )) 63 | sum_ks = [np.sum(info[k]) for info in all_env_infos] 64 | average_ks = [np.mean(info[k]) for info in all_env_infos] 65 | statistics.update(create_stats_ordered_dict( 66 | stat_prefix + k, 67 | sum_ks, 68 | stat_prefix='{}/sum/'.format(info_key), 69 | )) 70 | statistics.update(create_stats_ordered_dict( 71 | stat_prefix + k, 72 | average_ks, 73 | stat_prefix='{}/average/'.format(info_key), 74 | )) 75 | 76 | return statistics 77 | 78 | 79 | def get_average_returns(paths): 80 | returns = [sum(path["rewards"]) for path in paths] 81 | return np.mean(returns) 82 | 83 | 84 | def create_stats_ordered_dict( 85 | name, 86 | data, 87 | stat_prefix=None, 88 | always_show_all_stats=True, 89 | exclude_max_min=False, 90 | ): 91 | if stat_prefix is not 
None: 92 | name = "{}{}".format(stat_prefix, name) 93 | if isinstance(data, Number): 94 | return OrderedDict({name: data}) 95 | 96 | if len(data) == 0: 97 | return OrderedDict() 98 | 99 | if isinstance(data, tuple): 100 | ordered_dict = OrderedDict() 101 | for number, d in enumerate(data): 102 | sub_dict = create_stats_ordered_dict( 103 | "{0}_{1}".format(name, number), 104 | d, 105 | ) 106 | ordered_dict.update(sub_dict) 107 | return ordered_dict 108 | 109 | if isinstance(data, list): 110 | try: 111 | iter(data[0]) 112 | except TypeError: 113 | pass 114 | else: 115 | data = np.concatenate(data) 116 | 117 | if (isinstance(data, np.ndarray) and data.size == 1 118 | and not always_show_all_stats): 119 | return OrderedDict({name: float(data)}) 120 | 121 | stats = OrderedDict([ 122 | (name + ' Mean', np.mean(data)), 123 | (name + ' Std', np.std(data)), 124 | ]) 125 | if not exclude_max_min: 126 | stats[name + ' Max'] = np.max(data) 127 | stats[name + ' Min'] = np.min(data) 128 | return stats 129 | -------------------------------------------------------------------------------- /rlkit/data_management/torch_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import rlkit.torch.pytorch_util as ptu 7 | from rlkit.envs.env_utils import get_dim 8 | from rlkit.data_management.replay_buffer import ReplayBuffer 9 | 10 | 11 | class TorchReplayBuffer(ReplayBuffer): 12 | 13 | def __init__(self, max_replay_buffer_size, env, env_info_sizes=None): 14 | observation_dim = get_dim(env.observation_space) 15 | action_dim = get_dim(env.action_space) 16 | 17 | if env_info_sizes is None: 18 | if hasattr(env, 'info_sizes'): 19 | env_info_sizes = env.info_sizes 20 | else: 21 | env_info_sizes = dict() 22 | 23 | self._max_replay_buffer_size = max_replay_buffer_size 24 | self._observations = torch.zeros((max_replay_buffer_size, observation_dim), dtype=torch.float).pin_memory() 25 | # It's a bit memory inefficient to save the observations twice, 26 | # but it makes the code *much* easier since you no longer have to 27 | # worry about termination conditions. 
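# All of these buffers live in pinned (page-locked) host memory so that
# preload() below can copy sampled batches to the GPU asynchronously
# (non_blocking=True).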
28 | self._next_obs = torch.zeros((max_replay_buffer_size, observation_dim), dtype=torch.float).pin_memory() 29 | self._actions = torch.zeros((max_replay_buffer_size, action_dim), dtype=torch.float).pin_memory() 30 | # Make everything a 2D np array to make it easier for other code to 31 | # reason about the shape of the data 32 | self._rewards = torch.zeros((max_replay_buffer_size, 1), dtype=torch.float).pin_memory() 33 | # self._terminals[i] = a terminal was received at time i 34 | self._terminals = torch.zeros((max_replay_buffer_size, 1), dtype=torch.float).pin_memory() 35 | # Define self._env_infos[key][i] to be the return value of env_info[key] 36 | # at time i 37 | self._env_infos = {} 38 | for key, size in env_info_sizes.items(): 39 | self._env_infos[key] = torch.zeros((max_replay_buffer_size, size), dtype=torch.float).pin_memory() 40 | self._env_info_keys = env_info_sizes.keys() 41 | 42 | self._top = 0 43 | self._size = 0 44 | 45 | if ptu.gpu_enabled(): 46 | # self.stream = torch.cuda.Stream(ptu.device) 47 | self.batch = None 48 | 49 | def add_sample(self, observation, action, reward, next_observation, terminal, env_info, **kwargs): 50 | self._observations[self._top] = torch.from_numpy(observation) 51 | self._actions[self._top] = torch.from_numpy(action) 52 | self._rewards[self._top] = torch.from_numpy(reward) 53 | self._terminals[self._top] = torch.from_numpy(terminal) 54 | self._next_obs[self._top] = torch.from_numpy(next_observation) 55 | 56 | for key in self._env_info_keys: 57 | self._env_infos[key][self._top] = torch.from_numpy(env_info[key]) 58 | self._advance() 59 | 60 | def terminate_episode(self): 61 | pass 62 | 63 | def _advance(self): 64 | self._top = (self._top + 1) % self._max_replay_buffer_size 65 | if self._size < self._max_replay_buffer_size: 66 | self._size += 1 67 | 68 | def random_batch(self, batch_size): 69 | indices = np.random.randint(0, self._size, batch_size) 70 | batch = dict( 71 | observations=self._observations[indices], 72 | actions=self._actions[indices], 73 | rewards=self._rewards[indices], 74 | terminals=self._terminals[indices], 75 | next_observations=self._next_obs[indices], 76 | ) 77 | for key in self._env_info_keys: 78 | assert key not in batch.keys() 79 | batch[key] = self._env_infos[key][indices] 80 | return batch 81 | 82 | def preload(self, batch_size): 83 | try: 84 | self.batch = self.random_batch(batch_size) 85 | except StopIteration: 86 | self.batch = None 87 | return 88 | if ptu.gpu_enabled(): 89 | # with torch.cuda.stream(self.stream): 90 | for k in self.batch: 91 | self.batch[k] = self.batch[k].to(device=ptu.device, non_blocking=True) 92 | 93 | def next_batch(self, batch_size): 94 | # torch.cuda.current_stream(ptu.device).wait_stream(self.stream) 95 | if self.batch is None: 96 | self.preload(batch_size) 97 | batch = self.batch 98 | self.preload(batch_size) 99 | return batch 100 | 101 | def rebuild_env_info_dict(self, idx): 102 | return {key: self._env_infos[key][idx] for key in self._env_info_keys} 103 | 104 | def batch_env_info_dict(self, indices): 105 | return {key: self._env_infos[key][indices] for key in self._env_info_keys} 106 | 107 | def num_steps_can_sample(self): 108 | return self._size 109 | 110 | def get_diagnostics(self): 111 | return OrderedDict([('size', self._size)]) 112 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/vec_path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, 
deque 2 | 3 | import numpy as np 4 | 5 | from rlkit.core.eval_util import create_stats_ordered_dict 6 | from rlkit.data_management.path_builder import PathBuilder 7 | from rlkit.envs.vecenv import BaseVectorEnv 8 | from rlkit.samplers.data_collector.base import DataCollector 9 | 10 | 11 | class VecMdpPathCollector(DataCollector): 12 | 13 | def __init__( 14 | self, 15 | env: BaseVectorEnv, 16 | policy, 17 | max_num_epoch_paths_saved=None, 18 | render=False, 19 | render_kwargs=None, 20 | ): 21 | if render_kwargs is None: 22 | render_kwargs = {} 23 | self._env = env 24 | self._env_num = self._env.env_num 25 | self._policy = policy 26 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 27 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 28 | self._render = render 29 | self._render_kwargs = render_kwargs 30 | 31 | self._num_steps_total = 0 32 | self._num_paths_total = 0 33 | self._obs = None # cache variable 34 | 35 | def get_epoch_paths(self): 36 | return self._epoch_paths 37 | 38 | def end_epoch(self, epoch): 39 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 40 | self._obs = None 41 | 42 | def get_diagnostics(self): 43 | path_lens = [len(path['actions']) for path in self._epoch_paths] 44 | stats = OrderedDict([ 45 | ('num steps total', self._num_steps_total), 46 | ('num paths total', self._num_paths_total), 47 | ]) 48 | stats.update(create_stats_ordered_dict( 49 | "path length", 50 | path_lens, 51 | always_show_all_stats=True, 52 | )) 53 | return stats 54 | 55 | def get_snapshot(self): 56 | return dict( 57 | env=self._env, 58 | policy=self._policy, 59 | ) 60 | 61 | def collect_new_paths( 62 | self, 63 | max_path_length, 64 | num_paths, 65 | discard_incomplete_paths, 66 | ): 67 | if self._obs is None: 68 | self._start_new_rollout() 69 | 70 | num_paths_collected = 0 71 | while num_paths_collected < num_paths: 72 | 73 | actions = self._policy.get_actions(self._obs) 74 | next_obs, rewards, terminals, env_infos = self._env.step(actions) 75 | 76 | if self._render: 77 | self._env.render(**self._render_kwargs) 78 | 79 | # unzip vectorized data 80 | for env_idx, ( 81 | path_builder, 82 | next_ob, 83 | action, 84 | reward, 85 | terminal, 86 | env_info, 87 | ) in enumerate(zip( 88 | self._current_path_builders, 89 | next_obs, 90 | actions, 91 | rewards, 92 | terminals, 93 | env_infos, 94 | )): 95 | obs = self._obs[env_idx].copy() 96 | terminal = np.array([terminal]) 97 | reward = np.array([reward]) 98 | # store path obs 99 | path_builder.add_all( 100 | observations=obs, 101 | actions=action, 102 | rewards=reward, 103 | next_observations=next_ob, 104 | terminals=terminal, 105 | agent_infos={}, # policy.get_actions doesn't return agent_info 106 | env_infos=env_info, 107 | ) 108 | self._obs[env_idx] = next_ob 109 | if terminal or len(path_builder) >= max_path_length: 110 | self._handle_rollout_ending(path_builder, max_path_length, discard_incomplete_paths) 111 | self._start_new_rollout(env_idx) 112 | num_paths_collected += 1 113 | 114 | def _start_new_rollout(self, env_idx=None): 115 | if env_idx is None: 116 | self._current_path_builders = [PathBuilder() for _ in range(self._env_num)] 117 | self._obs = self._env.reset() 118 | else: 119 | 120 | self._current_path_builders[env_idx] = PathBuilder() 121 | self._obs[env_idx] = self._env.reset(env_idx)[env_idx] 122 | 123 | def _handle_rollout_ending(self, path_builder, max_path_length, discard_incomplete_paths): 124 | if len(path_builder) > 0: 125 | path = path_builder.get_all_stacked() 126 | path_len = 
len(path['actions']) 127 | if (path_len != max_path_length and not path['terminals'][-1] and discard_incomplete_paths): 128 | return 129 | self._epoch_paths.append(path) 130 | self._num_paths_total += 1 131 | self._num_steps_total += path_len 132 | -------------------------------------------------------------------------------- /dsac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | import rlkit.torch.pytorch_util as ptu 6 | import yaml 7 | from rlkit.data_management.torch_replay_buffer import TorchReplayBuffer 8 | from rlkit.envs import make_env 9 | from rlkit.envs.vecenv import SubprocVectorEnv, VectorEnv 10 | from rlkit.launchers.launcher_util import set_seed, setup_logger 11 | from rlkit.samplers.data_collector import (VecMdpPathCollector, VecMdpStepCollector) 12 | from rlkit.torch.dsac.dsac import DSACTrainer 13 | from rlkit.torch.dsac.networks import QuantileMlp, softmax 14 | from rlkit.torch.networks import FlattenMlp 15 | from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy 16 | from rlkit.torch.torch_rl_algorithm import TorchVecOnlineRLAlgorithm 17 | 18 | torch.set_num_threads(4) 19 | torch.set_num_interop_threads(4) 20 | 21 | 22 | def experiment(variant): 23 | dummy_env = make_env(variant['env']) 24 | obs_dim = dummy_env.observation_space.low.size 25 | action_dim = dummy_env.action_space.low.size 26 | expl_env = VectorEnv([lambda: make_env(variant['env']) for _ in range(variant['expl_env_num'])]) 27 | expl_env.seed(variant["seed"]) 28 | expl_env.action_space.seed(variant["seed"]) 29 | eval_env = SubprocVectorEnv([lambda: make_env(variant['env']) for _ in range(variant['eval_env_num'])]) 30 | eval_env.seed(variant["seed"]) 31 | 32 | M = variant['layer_size'] 33 | num_quantiles = variant['num_quantiles'] 34 | 35 | zf1 = QuantileMlp( 36 | input_size=obs_dim + action_dim, 37 | output_size=1, 38 | num_quantiles=num_quantiles, 39 | hidden_sizes=[M, M], 40 | ) 41 | zf2 = QuantileMlp( 42 | input_size=obs_dim + action_dim, 43 | output_size=1, 44 | num_quantiles=num_quantiles, 45 | hidden_sizes=[M, M], 46 | ) 47 | target_zf1 = QuantileMlp( 48 | input_size=obs_dim + action_dim, 49 | output_size=1, 50 | num_quantiles=num_quantiles, 51 | hidden_sizes=[M, M], 52 | ) 53 | target_zf2 = QuantileMlp( 54 | input_size=obs_dim + action_dim, 55 | output_size=1, 56 | num_quantiles=num_quantiles, 57 | hidden_sizes=[M, M], 58 | ) 59 | policy = TanhGaussianPolicy( 60 | obs_dim=obs_dim, 61 | action_dim=action_dim, 62 | hidden_sizes=[M, M], 63 | ) 64 | eval_policy = MakeDeterministic(policy) 65 | target_policy = TanhGaussianPolicy( 66 | obs_dim=obs_dim, 67 | action_dim=action_dim, 68 | hidden_sizes=[M, M], 69 | ) 70 | # fraction proposal network 71 | fp = target_fp = None 72 | if variant['trainer_kwargs'].get('tau_type') == 'fqf': 73 | fp = FlattenMlp( 74 | input_size=obs_dim + action_dim, 75 | output_size=num_quantiles, 76 | hidden_sizes=[M // 2, M // 2], 77 | output_activation=softmax, 78 | ) 79 | target_fp = FlattenMlp( 80 | input_size=obs_dim + action_dim, 81 | output_size=num_quantiles, 82 | hidden_sizes=[M // 2, M // 2], 83 | output_activation=softmax, 84 | ) 85 | eval_path_collector = VecMdpPathCollector( 86 | eval_env, 87 | eval_policy, 88 | ) 89 | expl_path_collector = VecMdpStepCollector( 90 | expl_env, 91 | policy, 92 | ) 93 | replay_buffer = TorchReplayBuffer( 94 | variant['replay_buffer_size'], 95 | dummy_env, 96 | ) 97 | trainer = DSACTrainer( 98 | env=dummy_env, 99 | policy=policy, 100 | 
target_policy=target_policy, 101 | zf1=zf1, 102 | zf2=zf2, 103 | target_zf1=target_zf1, 104 | target_zf2=target_zf2, 105 | fp=fp, 106 | target_fp=target_fp, 107 | num_quantiles=num_quantiles, 108 | **variant['trainer_kwargs'], 109 | ) 110 | algorithm = TorchVecOnlineRLAlgorithm( 111 | trainer=trainer, 112 | exploration_env=expl_env, 113 | evaluation_env=eval_env, 114 | exploration_data_collector=expl_path_collector, 115 | evaluation_data_collector=eval_path_collector, 116 | replay_buffer=replay_buffer, 117 | **variant['algorithm_kwargs'], 118 | ) 119 | algorithm.to(ptu.device) 120 | algorithm.train() 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser(description='Distributional Soft Actor Critic') 125 | parser.add_argument('--config', type=str, default="configs/lunarlander.yaml") 126 | parser.add_argument('--gpu', type=int, default=0, help="using cpu with -1") 127 | parser.add_argument('--seed', type=int, default=0) 128 | args = parser.parse_args() 129 | with open(args.config, 'r', encoding="utf-8") as f: 130 | variant = yaml.load(f, Loader=yaml.FullLoader) 131 | variant["seed"] = args.seed 132 | log_prefix = "_".join(["dsac", variant["env"][:-3].lower(), str(variant["version"])]) 133 | setup_logger(log_prefix, variant=variant, seed=args.seed) 134 | if args.gpu >= 0: 135 | ptu.set_gpu_mode(True, args.gpu) 136 | set_seed(args.seed) 137 | experiment(variant) 138 | -------------------------------------------------------------------------------- /rlkit/core/rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import OrderedDict 3 | 4 | import gtimer as gt 5 | 6 | from rlkit.core import logger, eval_util 7 | from rlkit.data_management.replay_buffer import ReplayBuffer 8 | from rlkit.samplers.data_collector import DataCollector 9 | 10 | 11 | def _get_epoch_timings(): 12 | times_itrs = gt.get_times().stamps.itrs 13 | times = OrderedDict() 14 | epoch_time = 0 15 | for key in sorted(times_itrs): 16 | time = times_itrs[key][-1] 17 | epoch_time += time 18 | times['time/{} (s)'.format(key)] = time 19 | times['time/epoch (s)'] = epoch_time 20 | times['time/total (s)'] = gt.get_times().total 21 | return times 22 | 23 | 24 | class BaseRLAlgorithm(object, metaclass=abc.ABCMeta): 25 | def __init__( 26 | self, 27 | trainer, 28 | exploration_env, 29 | evaluation_env, 30 | exploration_data_collector: DataCollector, 31 | evaluation_data_collector: DataCollector, 32 | replay_buffer: ReplayBuffer, 33 | ): 34 | self.trainer = trainer 35 | self.expl_env = exploration_env 36 | self.eval_env = evaluation_env 37 | self.expl_data_collector = exploration_data_collector 38 | self.eval_data_collector = evaluation_data_collector 39 | self.replay_buffer = replay_buffer 40 | self._start_epoch = 0 41 | 42 | self.post_epoch_funcs = [] 43 | 44 | def train(self, start_epoch=0): 45 | self._start_epoch = start_epoch 46 | self._train() 47 | 48 | def _train(self): 49 | """ 50 | Train model. 
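Concrete algorithms (e.g. VecOnlineRLAlgorithm) override this with the
full training loop.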
51 | """ 52 | raise NotImplementedError('_train must implemented by inherited class') 53 | 54 | def _end_epoch(self, epoch): 55 | snapshot = self._get_snapshot() 56 | logger.save_itr_params(epoch, snapshot) 57 | gt.stamp('saving') 58 | self._log_stats(epoch) 59 | 60 | self.expl_data_collector.end_epoch(epoch) 61 | self.eval_data_collector.end_epoch(epoch) 62 | self.replay_buffer.end_epoch(epoch) 63 | self.trainer.end_epoch(epoch) 64 | 65 | for post_epoch_func in self.post_epoch_funcs: 66 | post_epoch_func(self, epoch) 67 | 68 | def _get_snapshot(self): 69 | snapshot = {} 70 | for k, v in self.trainer.get_snapshot().items(): 71 | snapshot['trainer/' + k] = v 72 | for k, v in self.expl_data_collector.get_snapshot().items(): 73 | snapshot['exploration/' + k] = v 74 | for k, v in self.eval_data_collector.get_snapshot().items(): 75 | snapshot['evaluation/' + k] = v 76 | for k, v in self.replay_buffer.get_snapshot().items(): 77 | snapshot['replay_buffer/' + k] = v 78 | return snapshot 79 | 80 | def _log_stats(self, epoch): 81 | logger.log("Epoch {} finished".format(epoch), with_timestamp=True) 82 | 83 | """ 84 | Replay Buffer 85 | """ 86 | logger.record_dict( 87 | self.replay_buffer.get_diagnostics(), 88 | prefix='replay_buffer/' 89 | ) 90 | 91 | """ 92 | Trainer 93 | """ 94 | logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/') 95 | 96 | """ 97 | Exploration 98 | """ 99 | logger.record_dict( 100 | self.expl_data_collector.get_diagnostics(), 101 | prefix='exploration/' 102 | ) 103 | expl_paths = self.expl_data_collector.get_epoch_paths() 104 | if hasattr(self.expl_env, 'get_diagnostics'): 105 | logger.record_dict( 106 | self.expl_env.get_diagnostics(expl_paths), 107 | prefix='exploration/', 108 | ) 109 | logger.record_dict( 110 | eval_util.get_generic_path_information(expl_paths), 111 | prefix="exploration/", 112 | ) 113 | """ 114 | Evaluation 115 | """ 116 | logger.record_dict( 117 | self.eval_data_collector.get_diagnostics(), 118 | prefix='evaluation/', 119 | ) 120 | eval_paths = self.eval_data_collector.get_epoch_paths() 121 | if hasattr(self.eval_env, 'get_diagnostics'): 122 | logger.record_dict( 123 | self.eval_env.get_diagnostics(eval_paths), 124 | prefix='evaluation/', 125 | ) 126 | logger.record_dict( 127 | eval_util.get_generic_path_information(eval_paths), 128 | prefix="evaluation/", 129 | ) 130 | 131 | """ 132 | Misc 133 | """ 134 | gt.stamp('logging') 135 | logger.record_dict(_get_epoch_timings()) 136 | logger.record_tabular('Epoch', epoch) 137 | logger.dump_tabular(with_prefix=False, with_timestamp=False) 138 | 139 | @abc.abstractmethod 140 | def training_mode(self, mode): 141 | """ 142 | Set training mode to `mode`. 143 | :param mode: If True, training will happen (e.g. set the dropout 144 | probabilities to not all ones). 
145 | """ 146 | pass 147 | -------------------------------------------------------------------------------- /td4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | import rlkit.torch.pytorch_util as ptu 6 | import yaml 7 | from rlkit.data_management.torch_replay_buffer import TorchReplayBuffer 8 | from rlkit.envs import make_env 9 | from rlkit.envs.vecenv import SubprocVectorEnv, VectorEnv 10 | from rlkit.exploration_strategies.base import \ 11 | PolicyWrappedWithExplorationStrategy 12 | from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy 13 | from rlkit.launchers.launcher_util import set_seed, setup_logger 14 | from rlkit.samplers.data_collector import (VecMdpPathCollector, VecMdpStepCollector) 15 | from rlkit.torch.dsac.networks import QuantileMlp, softmax 16 | from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy 17 | from rlkit.torch.td4.td4 import TD4Trainer 18 | from rlkit.torch.torch_rl_algorithm import TorchVecOnlineRLAlgorithm 19 | 20 | torch.set_num_threads(4) 21 | torch.set_num_interop_threads(4) 22 | 23 | 24 | def experiment(variant): 25 | dummy_env = make_env(variant['env']) 26 | obs_dim = dummy_env.observation_space.low.size 27 | action_dim = dummy_env.action_space.low.size 28 | expl_env = VectorEnv([lambda: make_env(variant['env']) for _ in range(variant['expl_env_num'])]) 29 | expl_env.seed(variant["seed"]) 30 | expl_env.action_space.seed(variant["seed"]) 31 | eval_env = SubprocVectorEnv([lambda: make_env(variant['env']) for _ in range(variant['eval_env_num'])]) 32 | eval_env.seed(variant["seed"]) 33 | 34 | M = variant['layer_size'] 35 | num_quantiles = variant['num_quantiles'] 36 | 37 | zf1 = QuantileMlp( 38 | input_size=obs_dim + action_dim, 39 | output_size=1, 40 | num_quantiles=num_quantiles, 41 | hidden_sizes=[M, M], 42 | ) 43 | zf2 = QuantileMlp( 44 | input_size=obs_dim + action_dim, 45 | output_size=1, 46 | num_quantiles=num_quantiles, 47 | hidden_sizes=[M, M], 48 | ) 49 | target_zf1 = QuantileMlp( 50 | input_size=obs_dim + action_dim, 51 | output_size=1, 52 | num_quantiles=num_quantiles, 53 | hidden_sizes=[M, M], 54 | ) 55 | target_zf2 = QuantileMlp( 56 | input_size=obs_dim + action_dim, 57 | output_size=1, 58 | num_quantiles=num_quantiles, 59 | hidden_sizes=[M, M], 60 | ) 61 | policy = TanhMlpPolicy( 62 | input_size=obs_dim, 63 | output_size=action_dim, 64 | hidden_sizes=[M, M], 65 | ) 66 | target_policy = TanhMlpPolicy( 67 | input_size=obs_dim, 68 | output_size=action_dim, 69 | hidden_sizes=[M, M], 70 | ) 71 | es = GaussianStrategy( 72 | action_space=dummy_env.action_space, 73 | max_sigma=0.1, 74 | min_sigma=0.1, # Constant sigma 75 | ) 76 | exploration_policy = PolicyWrappedWithExplorationStrategy( 77 | exploration_strategy=es, 78 | policy=policy, 79 | ) 80 | # fraction proposal network 81 | fp = target_fp = None 82 | if variant['trainer_kwargs'].get('risk_type') == 'fqf': 83 | fp = FlattenMlp( 84 | input_size=obs_dim + action_dim, 85 | output_size=num_quantiles, 86 | hidden_sizes=[M // 2, M // 2], 87 | output_activation=softmax, 88 | ) 89 | target_fp = FlattenMlp( 90 | input_size=obs_dim + action_dim, 91 | output_size=num_quantiles, 92 | hidden_sizes=[M // 2, M // 2], 93 | output_activation=softmax, 94 | ) 95 | eval_path_collector = VecMdpPathCollector( 96 | eval_env, 97 | policy, 98 | ) 99 | expl_path_collector = VecMdpStepCollector( 100 | expl_env, 101 | exploration_policy, 102 | ) 103 | replay_buffer = TorchReplayBuffer( 104 | 
variant['replay_buffer_size'], 105 | dummy_env, 106 | ) 107 | trainer = TD4Trainer( 108 | policy=policy, 109 | target_policy=target_policy, 110 | zf1=zf1, 111 | zf2=zf2, 112 | target_zf1=target_zf1, 113 | target_zf2=target_zf2, 114 | fp=fp, 115 | target_fp=target_fp, 116 | num_quantiles=num_quantiles, 117 | **variant['trainer_kwargs'], 118 | ) 119 | algorithm = TorchVecOnlineRLAlgorithm( 120 | trainer=trainer, 121 | exploration_env=expl_env, 122 | evaluation_env=eval_env, 123 | exploration_data_collector=expl_path_collector, 124 | evaluation_data_collector=eval_path_collector, 125 | replay_buffer=replay_buffer, 126 | **variant['algorithm_kwargs'], 127 | ) 128 | algorithm.to(ptu.device) 129 | algorithm.train() 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser(description='TD4') 134 | parser.add_argument('--config', type=str, default="configs/lunarlander.yaml") 135 | parser.add_argument('--gpu', type=int, default=0, help="using cpu with -1") 136 | parser.add_argument('--seed', type=int, default=0) 137 | args = parser.parse_args() 138 | with open(args.config, 'r', encoding="utf-8") as f: 139 | variant = yaml.load(f, Loader=yaml.FullLoader) 140 | variant["seed"] = args.seed 141 | log_prefix = "_".join(["td4", variant["env"][:-3].lower(), str(variant["version"])]) 142 | setup_logger(log_prefix, variant=variant, seed=args.seed) 143 | if args.gpu >= 0: 144 | ptu.set_gpu_mode(True, args.gpu) 145 | set_seed(args.seed) 146 | experiment(variant) 147 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/vec_step_collector.py: -------------------------------------------------------------------------------- 1 | from collections import deque, OrderedDict 2 | 3 | import numpy as np 4 | 5 | # from rlkit.core.eval_util import create_stats_ordered_dict 6 | from rlkit.data_management.path_builder import PathBuilder 7 | from rlkit.samplers.data_collector.base import DataCollector 8 | from rlkit.envs.vecenv import BaseVectorEnv 9 | 10 | 11 | class VecMdpStepCollector(DataCollector): 12 | 13 | def __init__( 14 | self, 15 | env: BaseVectorEnv, 16 | policy, 17 | max_num_epoch_paths_saved=None, 18 | render=False, 19 | render_kwargs=None, 20 | ): 21 | if render_kwargs is None: 22 | render_kwargs = {} 23 | self._env = env 24 | self._env_num = self._env.env_num 25 | self._policy = policy 26 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 27 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 28 | self._render = render 29 | self._render_kwargs = render_kwargs 30 | 31 | self._num_steps_total = 0 32 | self._num_paths_total = 0 33 | self._obs = None # cache variable 34 | 35 | def get_epoch_paths(self): 36 | return self._epoch_paths 37 | 38 | def end_epoch(self, epoch): 39 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 40 | 41 | def reset(self): 42 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 43 | self._obs = None 44 | 45 | def get_diagnostics(self): 46 | stats = OrderedDict([ 47 | ('num steps total', self._num_steps_total), 48 | ('num paths total', self._num_paths_total), 49 | ]) 50 | # path_lens = [len(path['actions']) for path in self._epoch_paths] 51 | # stats.update(create_stats_ordered_dict( 52 | # "path length", 53 | # path_lens, 54 | # always_show_all_stats=True, 55 | # )) 56 | return stats 57 | 58 | def get_snapshot(self): 59 | return dict( 60 | env=self._env, 61 | policy=self._policy, 62 | ) 63 | 64 | def collect_new_steps( 65 | self, 66 | 
max_path_length, 67 | num_steps, 68 | discard_incomplete_paths, 69 | random=False, 70 | ): 71 | steps_collector = PathBuilder() 72 | for _ in range(num_steps): 73 | self.collect_one_step( 74 | max_path_length, 75 | discard_incomplete_paths, 76 | steps_collector, 77 | random, 78 | ) 79 | return [steps_collector.get_all_stacked()] 80 | 81 | def collect_one_step( 82 | self, 83 | max_path_length, 84 | discard_incomplete_paths, 85 | steps_collector: PathBuilder = None, 86 | random=False, 87 | ): 88 | if self._obs is None: 89 | self._start_new_rollout() 90 | if random: 91 | actions = [self._env.action_space.sample() for _ in range(self._env_num)] 92 | else: 93 | actions = self._policy.get_actions(self._obs) 94 | next_obs, rewards, terminals, env_infos = self._env.step(actions) 95 | 96 | if self._render: 97 | self._env.render(**self._render_kwargs) 98 | 99 | # unzip vectorized data 100 | for env_idx, ( 101 | path_builder, 102 | next_ob, 103 | action, 104 | reward, 105 | terminal, 106 | env_info, 107 | ) in enumerate(zip( 108 | self._current_path_builders, 109 | next_obs, 110 | actions, 111 | rewards, 112 | terminals, 113 | env_infos, 114 | )): 115 | obs = self._obs[env_idx].copy() 116 | terminal = np.array([terminal]) 117 | reward = np.array([reward]) 118 | # store path obs 119 | path_builder.add_all( 120 | observations=obs, 121 | actions=action, 122 | rewards=reward, 123 | next_observations=next_ob, 124 | terminals=terminal, 125 | agent_infos={}, # policy.get_actions doesn't return agent_info 126 | env_infos=env_info, 127 | ) 128 | if steps_collector is not None: 129 | steps_collector.add_all( 130 | observations=obs, 131 | actions=action, 132 | rewards=reward, 133 | next_observations=next_ob, 134 | terminals=terminal, 135 | agent_infos={}, # policy.get_actions doesn't return agent_info 136 | env_infos=env_info, 137 | ) 138 | self._obs[env_idx] = next_ob 139 | if terminal or len(path_builder) >= max_path_length: 140 | self._handle_rollout_ending(path_builder, max_path_length, discard_incomplete_paths) 141 | self._start_new_rollout(env_idx) 142 | 143 | def _start_new_rollout(self, env_idx=None): 144 | if env_idx is None: 145 | self._current_path_builders = [PathBuilder() for _ in range(self._env_num)] 146 | self._obs = self._env.reset() 147 | else: 148 | self._current_path_builders[env_idx] = PathBuilder() 149 | self._obs[env_idx] = self._env.reset(env_idx)[env_idx] 150 | 151 | def _handle_rollout_ending(self, path_builder, max_path_length, discard_incomplete_paths): 152 | if len(path_builder) > 0: 153 | path = path_builder.get_all_stacked() 154 | path_len = len(path['actions']) 155 | if (path_len != max_path_length and not path['terminals'][-1] and discard_incomplete_paths): 156 | return 157 | self._epoch_paths.append(path) 158 | self._num_paths_total += 1 159 | self._num_steps_total += path_len 160 | -------------------------------------------------------------------------------- /rlkit/core/vec_online_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import gtimer as gt 4 | from rlkit.core import eval_util, logger 5 | from rlkit.core.rl_algorithm import BaseRLAlgorithm, _get_epoch_timings 6 | from rlkit.data_management.torch_replay_buffer import TorchReplayBuffer 7 | from rlkit.samplers.data_collector import (VecMdpPathCollector, 8 | VecMdpStepCollector) 9 | 10 | 11 | class VecOnlineRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta): 12 | def __init__( 13 | self, 14 | trainer, 15 | exploration_env, 16 | evaluation_env, 17 | 
exploration_data_collector: VecMdpStepCollector, 18 | evaluation_data_collector: VecMdpPathCollector, 19 | replay_buffer: TorchReplayBuffer, 20 | batch_size, 21 | max_path_length, 22 | num_epochs, 23 | num_eval_paths_per_epoch, 24 | num_expl_steps_per_train_loop, 25 | num_trains_per_train_loop, 26 | num_train_loops_per_epoch=1, 27 | min_num_steps_before_training=0, 28 | ): 29 | super().__init__( 30 | trainer, 31 | exploration_env, 32 | evaluation_env, 33 | exploration_data_collector, 34 | evaluation_data_collector, 35 | replay_buffer, 36 | ) 37 | self.batch_size = batch_size 38 | self.max_path_length = max_path_length 39 | self.num_epochs = num_epochs 40 | self.num_eval_paths_per_epoch = num_eval_paths_per_epoch 41 | self.num_trains_per_train_loop = num_trains_per_train_loop 42 | self.num_train_loops_per_epoch = num_train_loops_per_epoch 43 | self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop 44 | self.min_num_steps_before_training = min_num_steps_before_training 45 | 46 | assert self.num_trains_per_train_loop >= self.num_expl_steps_per_train_loop, \ 47 | 'Online training presumes num_trains_per_train_loop >= num_expl_steps_per_train_loop' 48 | 49 | def _train(self): 50 | self.training_mode(False) 51 | if self.min_num_steps_before_training > 0: 52 | init_expl_paths = self.expl_data_collector.collect_new_steps( 53 | self.max_path_length, 54 | self.min_num_steps_before_training // self.expl_env.env_num, 55 | discard_incomplete_paths=False, 56 | random=True, # whether random sample from action_space 57 | ) 58 | self.replay_buffer.add_paths(init_expl_paths) 59 | self.expl_data_collector.end_epoch(-1) 60 | gt.stamp('initial exploration', unique=True) 61 | 62 | num_trains_per_expl_step = self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop 63 | num_trains_per_expl_step *= self.expl_env.env_num 64 | 65 | train_data = self.replay_buffer.next_batch(self.batch_size) 66 | for epoch in gt.timed_for( 67 | range(self._start_epoch, self.num_epochs), 68 | save_itrs=True, 69 | ): 70 | for _ in range(self.num_train_loops_per_epoch): 71 | for _ in range(self.num_expl_steps_per_train_loop // self.expl_env.env_num): 72 | new_expl_steps = self.expl_data_collector.collect_new_steps( 73 | self.max_path_length, 74 | 1, # num steps 75 | discard_incomplete_paths=False, 76 | ) 77 | gt.stamp('exploration sampling', unique=False) 78 | self.replay_buffer.add_paths(new_expl_steps) 79 | gt.stamp('data storing', unique=False) 80 | 81 | self.training_mode(True) 82 | for _ in range(num_trains_per_expl_step): 83 | self.trainer.train(train_data) 84 | gt.stamp('training', unique=False) 85 | train_data = self.replay_buffer.next_batch(self.batch_size) 86 | gt.stamp('data sampling', unique=False) 87 | self.training_mode(False) 88 | 89 | self.eval_data_collector.collect_new_paths( 90 | self.max_path_length, 91 | self.num_eval_paths_per_epoch, 92 | discard_incomplete_paths=True, 93 | ) 94 | gt.stamp('evaluation sampling') 95 | 96 | self._end_epoch(epoch) 97 | 98 | def _get_snapshot(self): 99 | snapshot = {} 100 | for k, v in self.trainer.get_snapshot().items(): 101 | snapshot['trainer/' + k] = v 102 | for k, v in self.replay_buffer.get_snapshot().items(): 103 | snapshot['replay_buffer/' + k] = v 104 | return snapshot 105 | 106 | def _log_stats(self, epoch): 107 | logger.log("Epoch {} finished".format(epoch), with_timestamp=True) 108 | 109 | """ 110 | Replay Buffer 111 | """ 112 | logger.record_dict( 113 | self.replay_buffer.get_diagnostics(), 114 | prefix='replay_buffer/' 115 | ) 116 | 117 | """ 118 | 
Trainer 119 | """ 120 | logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/') 121 | 122 | """ 123 | Exploration 124 | """ 125 | logger.record_dict( 126 | self.expl_data_collector.get_diagnostics(), 127 | prefix='exploration/' 128 | ) 129 | """ 130 | Evaluation 131 | """ 132 | logger.record_dict( 133 | self.eval_data_collector.get_diagnostics(), 134 | prefix='evaluation/', 135 | ) 136 | eval_paths = self.eval_data_collector.get_epoch_paths() 137 | logger.record_dict( 138 | eval_util.get_generic_path_information(eval_paths), 139 | prefix="evaluation/", 140 | ) 141 | 142 | """ 143 | Misc 144 | """ 145 | gt.stamp('logging') 146 | logger.record_dict(_get_epoch_timings()) 147 | logger.record_tabular('Epoch', epoch) 148 | logger.dump_tabular(with_prefix=False, with_timestamp=False) 149 | -------------------------------------------------------------------------------- /rlkit/samplers/data_collector/path_collector.py: -------------------------------------------------------------------------------- 1 | from collections import deque, OrderedDict 2 | 3 | from rlkit.core.eval_util import create_stats_ordered_dict 4 | from rlkit.samplers.rollout_functions import rollout, multitask_rollout 5 | from rlkit.samplers.data_collector.base import PathCollector 6 | 7 | 8 | class MdpPathCollector(PathCollector): 9 | 10 | def __init__( 11 | self, 12 | env, 13 | policy, 14 | max_num_epoch_paths_saved=None, 15 | render=False, 16 | render_kwargs=None, 17 | ): 18 | if render_kwargs is None: 19 | render_kwargs = {} 20 | self._env = env 21 | self._policy = policy 22 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 23 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 24 | self._render = render 25 | self._render_kwargs = render_kwargs 26 | 27 | self._num_steps_total = 0 28 | self._num_paths_total = 0 29 | 30 | def collect_new_paths( 31 | self, 32 | max_path_length, 33 | num_steps, 34 | discard_incomplete_paths, 35 | ): 36 | paths = [] 37 | num_steps_collected = 0 38 | while num_steps_collected < num_steps: 39 | max_path_length_this_loop = min( # Do not go over num_steps 40 | max_path_length, 41 | num_steps - num_steps_collected, 42 | ) 43 | path = rollout( 44 | self._env, 45 | self._policy, 46 | max_path_length=max_path_length_this_loop, 47 | ) 48 | path_len = len(path['actions']) 49 | if (path_len != max_path_length and not path['terminals'][-1] and discard_incomplete_paths): 50 | break 51 | num_steps_collected += path_len 52 | paths.append(path) 53 | self._num_paths_total += len(paths) 54 | self._num_steps_total += num_steps_collected 55 | self._epoch_paths.extend(paths) 56 | return paths 57 | 58 | def get_epoch_paths(self): 59 | return self._epoch_paths 60 | 61 | def end_epoch(self, epoch): 62 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 63 | 64 | def get_diagnostics(self): 65 | path_lens = [len(path['actions']) for path in self._epoch_paths] 66 | stats = OrderedDict([ 67 | ('num steps total', self._num_steps_total), 68 | ('num paths total', self._num_paths_total), 69 | ]) 70 | stats.update(create_stats_ordered_dict( 71 | "path length", 72 | path_lens, 73 | always_show_all_stats=True, 74 | )) 75 | return stats 76 | 77 | def get_snapshot(self): 78 | return dict( 79 | env=self._env, 80 | policy=self._policy, 81 | ) 82 | 83 | 84 | class EvalPathCollector(MdpPathCollector): 85 | 86 | def collect_new_paths( 87 | self, 88 | max_path_length, 89 | num_paths, 90 | ): 91 | paths = [] 92 | num_steps_collected = 0 93 | for _ in range(num_paths): 94 | path = 
rollout( 95 | self._env, 96 | self._policy, 97 | max_path_length=max_path_length, 98 | ) 99 | path_len = len(path['actions']) 100 | num_steps_collected += path_len 101 | paths.append(path) 102 | self._num_paths_total += len(paths) 103 | self._num_steps_total += num_steps_collected 104 | self._epoch_paths.extend(paths) 105 | return paths 106 | 107 | 108 | class GoalConditionedPathCollector(PathCollector): 109 | 110 | def __init__( 111 | self, 112 | env, 113 | policy, 114 | max_num_epoch_paths_saved=None, 115 | render=False, 116 | render_kwargs=None, 117 | observation_key='observation', 118 | desired_goal_key='desired_goal', 119 | ): 120 | if render_kwargs is None: 121 | render_kwargs = {} 122 | self._env = env 123 | self._policy = policy 124 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 125 | self._render = render 126 | self._render_kwargs = render_kwargs 127 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 128 | self._observation_key = observation_key 129 | self._desired_goal_key = desired_goal_key 130 | 131 | self._num_steps_total = 0 132 | self._num_paths_total = 0 133 | 134 | def collect_new_paths( 135 | self, 136 | max_path_length, 137 | num_steps, 138 | discard_incomplete_paths, 139 | ): 140 | paths = [] 141 | num_steps_collected = 0 142 | while num_steps_collected < num_steps: 143 | max_path_length_this_loop = min( # Do not go over num_steps 144 | max_path_length, 145 | num_steps - num_steps_collected, 146 | ) 147 | path = multitask_rollout( 148 | self._env, 149 | self._policy, 150 | max_path_length=max_path_length_this_loop, 151 | render=self._render, 152 | render_kwargs=self._render_kwargs, 153 | observation_key=self._observation_key, 154 | desired_goal_key=self._desired_goal_key, 155 | return_dict_obs=True, 156 | ) 157 | path_len = len(path['actions']) 158 | if (path_len != max_path_length and not path['terminals'][-1] and discard_incomplete_paths): 159 | break 160 | num_steps_collected += path_len 161 | paths.append(path) 162 | self._num_paths_total += len(paths) 163 | self._num_steps_total += num_steps_collected 164 | self._epoch_paths.extend(paths) 165 | return paths 166 | 167 | def get_epoch_paths(self): 168 | return self._epoch_paths 169 | 170 | def end_epoch(self, epoch): 171 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 172 | 173 | def get_diagnostics(self): 174 | path_lens = [len(path['actions']) for path in self._epoch_paths] 175 | stats = OrderedDict([ 176 | ('num steps total', self._num_steps_total), 177 | ('num paths total', self._num_paths_total), 178 | ]) 179 | stats.update(create_stats_ordered_dict( 180 | "path length", 181 | path_lens, 182 | always_show_all_stats=True, 183 | )) 184 | return stats 185 | 186 | def get_snapshot(self): 187 | return dict( 188 | env=self._env, 189 | policy=self._policy, 190 | observation_key=self._observation_key, 191 | desired_goal_key=self._desired_goal_key, 192 | ) 193 | -------------------------------------------------------------------------------- /rlkit/torch/td3/td3.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.optim as optim 5 | from torch import nn as nn 6 | 7 | import rlkit.torch.pytorch_util as ptu 8 | from rlkit.core.eval_util import create_stats_ordered_dict 9 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 10 | 11 | 12 | class TD3Trainer(TorchTrainer): 13 | """ 14 | Twin Delayed Deep Deterministic policy gradients 15 | """ 16 | 17 | def 
__init__( 18 | self, 19 | policy, 20 | target_policy, 21 | qf1, 22 | qf2, 23 | target_qf1, 24 | target_qf2, 25 | target_policy_noise=0.2, 26 | target_policy_noise_clip=0.5, 27 | discount=0.99, 28 | reward_scale=1.0, 29 | policy_lr=3e-4, 30 | qf_lr=3e-4, 31 | policy_and_target_update_period=2, 32 | tau=0.005, 33 | qf_criterion=None, 34 | optimizer_class=optim.Adam, 35 | max_action=1., 36 | clip_norm=0., 37 | ): 38 | super().__init__() 39 | if qf_criterion is None: 40 | qf_criterion = nn.MSELoss() 41 | self.qf1 = qf1 42 | self.qf2 = qf2 43 | self.policy = policy 44 | self.target_policy = target_policy 45 | self.target_qf1 = target_qf1 46 | self.target_qf2 = target_qf2 47 | self.target_policy_noise = target_policy_noise 48 | self.target_policy_noise_clip = target_policy_noise_clip 49 | 50 | self.discount = discount 51 | self.reward_scale = reward_scale 52 | self.max_action = max_action 53 | self.clip_norm = clip_norm 54 | 55 | self.policy_and_target_update_period = policy_and_target_update_period 56 | self.tau = tau 57 | self.qf_criterion = qf_criterion 58 | 59 | self.qf1_optimizer = optimizer_class( 60 | self.qf1.parameters(), 61 | lr=qf_lr, 62 | ) 63 | self.qf2_optimizer = optimizer_class( 64 | self.qf2.parameters(), 65 | lr=qf_lr, 66 | ) 67 | self.policy_optimizer = optimizer_class( 68 | self.policy.parameters(), 69 | lr=policy_lr, 70 | ) 71 | 72 | self.eval_statistics = OrderedDict() 73 | self._n_train_steps_total = 0 74 | self._need_to_update_eval_statistics = True 75 | 76 | def train_from_torch(self, batch): 77 | rewards = batch['rewards'] 78 | terminals = batch['terminals'] 79 | obs = batch['observations'] 80 | actions = batch['actions'] 81 | next_obs = batch['next_observations'] 82 | """ 83 | Update QF 84 | """ 85 | with torch.no_grad(): 86 | next_actions = self.target_policy(next_obs) 87 | noise = ptu.randn(next_actions.shape) * self.target_policy_noise 88 | noise = torch.clamp(noise, -self.target_policy_noise_clip, self.target_policy_noise_clip) 89 | noisy_next_actions = torch.clamp(next_actions + noise, -self.max_action, self.max_action) 90 | 91 | target_q1_values = self.target_qf1(next_obs, noisy_next_actions) 92 | target_q2_values = self.target_qf2(next_obs, noisy_next_actions) 93 | target_q_values = torch.min(target_q1_values, target_q2_values) 94 | q_target = self.reward_scale * rewards + (1. 
- terminals) * self.discount * target_q_values 95 | 96 | q1_pred = self.qf1(obs, actions) 97 | bellman_errors_1 = (q1_pred - q_target)**2 98 | qf1_loss = bellman_errors_1.mean() 99 | 100 | q2_pred = self.qf2(obs, actions) 101 | bellman_errors_2 = (q2_pred - q_target)**2 102 | qf2_loss = bellman_errors_2.mean() 103 | 104 | self.qf1_optimizer.zero_grad() 105 | qf1_loss.backward() 106 | self.qf1_optimizer.step() 107 | 108 | self.qf2_optimizer.zero_grad() 109 | qf2_loss.backward() 110 | self.qf2_optimizer.step() 111 | """ 112 | Update Policy 113 | """ 114 | 115 | policy_actions = policy_loss = None 116 | if self._n_train_steps_total % self.policy_and_target_update_period == 0: 117 | policy_actions = self.policy(obs) 118 | q_output = self.qf1(obs, policy_actions) 119 | policy_loss = -q_output.mean() 120 | 121 | self.policy_optimizer.zero_grad() 122 | policy_loss.backward() 123 | policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(), self.clip_norm) 124 | self.policy_optimizer.step() 125 | 126 | ptu.soft_update_from_to(self.policy, self.target_policy, self.tau) 127 | ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau) 128 | ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau) 129 | 130 | if self._need_to_update_eval_statistics: 131 | self._need_to_update_eval_statistics = False 132 | if policy_loss is None: 133 | policy_actions = self.policy(obs) 134 | q_output = self.qf1(obs, policy_actions) 135 | policy_loss = -q_output.mean() 136 | 137 | self.eval_statistics['QF1 Loss'] = qf1_loss.item() 138 | self.eval_statistics['QF2 Loss'] = qf2_loss.item() 139 | self.eval_statistics['Policy Loss'] = policy_loss.item() 140 | self.eval_statistics['Policy Grad'] = policy_grad 141 | self.eval_statistics.update(create_stats_ordered_dict( 142 | 'Q1 Predictions', 143 | ptu.get_numpy(q1_pred), 144 | )) 145 | self.eval_statistics.update(create_stats_ordered_dict( 146 | 'Q2 Predictions', 147 | ptu.get_numpy(q2_pred), 148 | )) 149 | self.eval_statistics.update(create_stats_ordered_dict( 150 | 'Q Targets', 151 | ptu.get_numpy(q_target), 152 | )) 153 | self.eval_statistics.update(create_stats_ordered_dict( 154 | 'Bellman Errors 1', 155 | ptu.get_numpy(bellman_errors_1), 156 | )) 157 | self.eval_statistics.update(create_stats_ordered_dict( 158 | 'Bellman Errors 2', 159 | ptu.get_numpy(bellman_errors_2), 160 | )) 161 | self.eval_statistics.update(create_stats_ordered_dict( 162 | 'Policy Action', 163 | ptu.get_numpy(policy_actions), 164 | )) 165 | self._n_train_steps_total += 1 166 | 167 | def get_diagnostics(self): 168 | return self.eval_statistics 169 | 170 | def end_epoch(self, epoch): 171 | self._need_to_update_eval_statistics = True 172 | 173 | @property 174 | def networks(self): 175 | return [ 176 | self.policy, 177 | self.qf1, 178 | self.qf2, 179 | self.target_policy, 180 | self.target_qf1, 181 | self.target_qf2, 182 | ] 183 | 184 | def get_snapshot(self): 185 | return dict( 186 | qf1=self.qf1.state_dict(), 187 | qf2=self.qf2.state_dict(), 188 | policy=self.policy.state_dict(), 189 | target_policy=self.target_policy.state_dict(), 190 | ) 191 | -------------------------------------------------------------------------------- /rlkit/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | from gym import Env, Wrapper 4 | from gym.spaces import Box 5 | from gym.spaces import Discrete 6 | from collections import deque 7 | 8 | 9 | class ProxyEnv(Env): 10 | 11 | def __init__(self, wrapped_env): 12 | 
self._wrapped_env = wrapped_env 13 | self.action_space = self._wrapped_env.action_space 14 | self.observation_space = self._wrapped_env.observation_space 15 | 16 | @property 17 | def wrapped_env(self): 18 | return self._wrapped_env 19 | 20 | def reset(self, **kwargs): 21 | return self._wrapped_env.reset(**kwargs) 22 | 23 | def step(self, action): 24 | return self._wrapped_env.step(action) 25 | 26 | def render(self, *args, **kwargs): 27 | return self._wrapped_env.render(*args, **kwargs) 28 | 29 | @property 30 | def horizon(self): 31 | return self._wrapped_env.horizon 32 | 33 | def terminate(self): 34 | if hasattr(self.wrapped_env, "terminate"): 35 | self.wrapped_env.terminate() 36 | 37 | def __getattr__(self, attr): 38 | if attr == '_wrapped_env': 39 | raise AttributeError() 40 | return getattr(self._wrapped_env, attr) 41 | 42 | def __getstate__(self): 43 | """ 44 | This is useful to override in case the wrapped env has some funky 45 | __getstate__ that doesn't play well with overriding __getattr__. 46 | 47 | The main problematic case is/was gym's EzPickle serialization scheme. 48 | :return: 49 | """ 50 | return self.__dict__ 51 | 52 | def __setstate__(self, state): 53 | self.__dict__.update(state) 54 | 55 | def __str__(self): 56 | return '{}({})'.format(type(self).__name__, self.wrapped_env) 57 | 58 | 59 | class HistoryEnv(ProxyEnv, Env): 60 | 61 | def __init__(self, wrapped_env, history_len): 62 | super().__init__(wrapped_env) 63 | self.history_len = history_len 64 | 65 | high = np.inf * np.ones(self.history_len * self.observation_space.low.size) 66 | low = -high 67 | self.observation_space = Box( 68 | low=low, 69 | high=high, 70 | ) 71 | self.history = deque(maxlen=self.history_len) 72 | 73 | def step(self, action): 74 | state, reward, done, info = super().step(action) 75 | self.history.append(state) 76 | flattened_history = self._get_history().flatten() 77 | return flattened_history, reward, done, info 78 | 79 | def reset(self, **kwargs): 80 | state = super().reset() 81 | self.history = deque(maxlen=self.history_len) 82 | self.history.append(state) 83 | flattened_history = self._get_history().flatten() 84 | return flattened_history 85 | 86 | def _get_history(self): 87 | observations = list(self.history) 88 | 89 | obs_count = len(observations) 90 | for _ in range(self.history_len - obs_count): 91 | dummy = np.zeros(self._wrapped_env.observation_space.low.size) 92 | observations.append(dummy) 93 | return np.c_[observations] 94 | 95 | 96 | class DiscretizeEnv(ProxyEnv, Env): 97 | 98 | def __init__(self, wrapped_env, num_bins): 99 | super().__init__(wrapped_env) 100 | low = self.wrapped_env.action_space.low 101 | high = self.wrapped_env.action_space.high 102 | action_ranges = [np.linspace(low[i], high[i], num_bins) for i in range(len(low))] 103 | self.idx_to_continuous_action = [np.array(x) for x in itertools.product(*action_ranges)] 104 | self.action_space = Discrete(len(self.idx_to_continuous_action)) 105 | 106 | def step(self, action): 107 | continuous_action = self.idx_to_continuous_action[action] 108 | return super().step(continuous_action) 109 | 110 | 111 | class NormalizedBoxEnv(Wrapper): 112 | """ 113 | Normalize action to in [-1, 1]. 114 | 115 | Optionally normalize observations and scale reward. 
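A policy action `a` in [-1, 1] is rescaled to `lb + (a + 1) * 0.5 * (ub - lb)`
of the wrapped action space before stepping (see `step` below).

Example (illustrative): `NormalizedBoxEnv(gym.make("Hopper-v2"), reward_scale=5.0)`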
116 | """ 117 | 118 | def __init__( 119 | self, 120 | env, 121 | reward_scale=1., 122 | obs_mean=None, 123 | obs_std=None, 124 | ): 125 | super().__init__(env) 126 | self._should_normalize = not (obs_mean is None and obs_std is None) 127 | if self._should_normalize: 128 | if obs_mean is None: 129 | obs_mean = np.zeros_like(env.observation_space.low) 130 | else: 131 | obs_mean = np.array(obs_mean) 132 | if obs_std is None: 133 | obs_std = np.ones_like(env.observation_space.low) 134 | else: 135 | obs_std = np.array(obs_std) 136 | self._reward_scale = reward_scale 137 | self._obs_mean = obs_mean 138 | self._obs_std = obs_std 139 | ub = np.ones(self.env.action_space.shape) 140 | self.action_space = Box(-1 * ub, ub) 141 | 142 | def estimate_obs_stats(self, obs_batch, override_values=False): 143 | if self._obs_mean is not None and not override_values: 144 | raise Exception("Observation mean and std already set. To " "override, set override_values to True.") 145 | self._obs_mean = np.mean(obs_batch, axis=0) 146 | self._obs_std = np.std(obs_batch, axis=0) 147 | self._should_normalize = True 148 | 149 | def _apply_normalize_obs(self, obs): 150 | return (obs - self._obs_mean) / (self._obs_std + 1e-8) 151 | 152 | def step(self, action): 153 | lb = self.env.action_space.low 154 | ub = self.env.action_space.high 155 | scaled_action = lb + (action + 1.) * 0.5 * (ub - lb) 156 | scaled_action = np.clip(scaled_action, lb, ub) 157 | 158 | wrapped_step = self.env.step(scaled_action) 159 | next_obs, reward, done, info = wrapped_step 160 | if self._should_normalize: 161 | next_obs = self._apply_normalize_obs(next_obs) 162 | return next_obs, reward * self._reward_scale, done, info 163 | 164 | def __str__(self): 165 | return "Normalized: %s" % self.env 166 | 167 | 168 | class CustomInfoEnv(Wrapper): 169 | 170 | def __init__(self, wrapped_env): 171 | 172 | env_id = wrapped_env.spec.id 173 | if env_id in [ 174 | "Walker2d-v2", # mujoco 175 | "Hopper-v2", 176 | "Ant-v2", 177 | "HalfCheetah-v2", 178 | "Humanoid-v2", 179 | "HumanoidStandup-v2", 180 | "Walker2d-v3", # mujoco 181 | "Hopper-v3", 182 | "Ant-v3", 183 | "HalfCheetah-v3", 184 | "Humanoid-v3", 185 | ]: 186 | self.env_type = "mujoco" 187 | elif env_id in [ 188 | "LunarLanderContinuous-v2", 189 | "BipedalWalker-v3", 190 | "BipedalWalkerHardcore-v3", 191 | ]: 192 | self.env_type = "box2d" 193 | 194 | super().__init__(wrapped_env) 195 | 196 | def step(self, action): 197 | state, reward, done, info = self.env.step(action) 198 | if self.env_type == "mujoco": 199 | custom_info = {'failed': done} 200 | if self.env_type == "box2d": 201 | custom_info = {'failed': reward <= -100} 202 | return state, reward, done, custom_info 203 | -------------------------------------------------------------------------------- /rlkit/util/hyperparameter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom hyperparameter functions. 
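
Provides random hyperparameter distributions (e.g. ``LinearFloatParam``,
``LogFloatParam``, ``EnumParam``) and sweepers for random search
(``RandomHyperparameterSweeper``) and grid search
(``DeterministicHyperparameterSweeper``) over experiment kwargs.

A minimal grid-search sketch; the parameter names below are illustrative, not
keys required elsewhere in this repo::

    sweeper = DeterministicHyperparameterSweeper(
        {'policy_lr': [1e-3, 3e-4], 'discount': [0.98, 0.99]},
        default_parameters={'reward_scale': 1.0},
    )
    for kwargs in sweeper.iterate_hyperparameters():
        print(kwargs)  # 4 merged dicts, one per grid point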
3 | """ 4 | import abc 5 | import copy 6 | import math 7 | import random 8 | import itertools 9 | from typing import List 10 | 11 | import rlkit.pythonplusplus as ppp 12 | 13 | 14 | class Hyperparameter(metaclass=abc.ABCMeta): 15 | def __init__(self, name): 16 | self._name = name 17 | 18 | @property 19 | def name(self): 20 | return self._name 21 | 22 | 23 | class RandomHyperparameter(Hyperparameter): 24 | def __init__(self, name): 25 | super().__init__(name) 26 | self._last_value = None 27 | 28 | @abc.abstractmethod 29 | def generate_next_value(self): 30 | """Return a value for the hyperparameter""" 31 | return 32 | 33 | def generate(self): 34 | self._last_value = self.generate_next_value() 35 | return self._last_value 36 | 37 | 38 | class EnumParam(RandomHyperparameter): 39 | def __init__(self, name, possible_values): 40 | super().__init__(name) 41 | self.possible_values = possible_values 42 | 43 | def generate_next_value(self): 44 | return random.choice(self.possible_values) 45 | 46 | 47 | class LogFloatParam(RandomHyperparameter): 48 | """ 49 | Return something ranging from [min_value + offset, max_value + offset], 50 | distributed with a log. 51 | """ 52 | def __init__(self, name, min_value, max_value, *, offset=0): 53 | super(LogFloatParam, self).__init__(name) 54 | self._linear_float_param = LinearFloatParam("log_" + name, 55 | math.log(min_value), 56 | math.log(max_value)) 57 | self.offset = offset 58 | 59 | def generate_next_value(self): 60 | return math.e ** (self._linear_float_param.generate()) + self.offset 61 | 62 | 63 | class LinearFloatParam(RandomHyperparameter): 64 | def __init__(self, name, min_value, max_value): 65 | super(LinearFloatParam, self).__init__(name) 66 | self._min = min_value 67 | self._delta = max_value - min_value 68 | 69 | def generate_next_value(self): 70 | return random.random() * self._delta + self._min 71 | 72 | 73 | class LogIntParam(RandomHyperparameter): 74 | def __init__(self, name, min_value, max_value, *, offset=0): 75 | super().__init__(name) 76 | self._linear_float_param = LinearFloatParam("log_" + name, 77 | math.log(min_value), 78 | math.log(max_value)) 79 | self.offset = offset 80 | 81 | def generate_next_value(self): 82 | return int( 83 | math.e ** (self._linear_float_param.generate()) + self.offset 84 | ) 85 | 86 | 87 | class LinearIntParam(RandomHyperparameter): 88 | def __init__(self, name, min_value, max_value): 89 | super(LinearIntParam, self).__init__(name) 90 | self._min = min_value 91 | self._max = max_value 92 | 93 | def generate_next_value(self): 94 | return random.randint(self._min, self._max) 95 | 96 | 97 | class FixedParam(RandomHyperparameter): 98 | def __init__(self, name, value): 99 | super().__init__(name) 100 | self._value = value 101 | 102 | def generate_next_value(self): 103 | return self._value 104 | 105 | 106 | class Sweeper(object): 107 | pass 108 | 109 | 110 | class RandomHyperparameterSweeper(Sweeper): 111 | def __init__(self, hyperparameters=None, default_kwargs=None): 112 | if default_kwargs is None: 113 | default_kwargs = {} 114 | self._hyperparameters = hyperparameters or [] 115 | self._validate_hyperparameters() 116 | self._default_kwargs = default_kwargs 117 | 118 | def _validate_hyperparameters(self): 119 | names = set() 120 | for hp in self._hyperparameters: 121 | name = hp.name 122 | if name in names: 123 | raise Exception("Hyperparameter '{0}' already added.".format( 124 | name)) 125 | names.add(name) 126 | 127 | def set_default_parameters(self, default_kwargs): 128 | self._default_kwargs = default_kwargs 
129 | 130 | def generate_random_hyperparameters(self): 131 | hyperparameters = {} 132 | for hp in self._hyperparameters: 133 | hyperparameters[hp.name] = hp.generate() 134 | hyperparameters = ppp.dot_map_dict_to_nested_dict(hyperparameters) 135 | return ppp.merge_recursive_dicts( 136 | hyperparameters, 137 | copy.deepcopy(self._default_kwargs), 138 | ignore_duplicate_keys_in_second_dict=True, 139 | ) 140 | 141 | def sweep_hyperparameters(self, function, num_configs): 142 | returned_value_and_params = [] 143 | for _ in range(num_configs): 144 | kwargs = self.generate_random_hyperparameters() 145 | score = function(**kwargs) 146 | returned_value_and_params.append((score, kwargs)) 147 | 148 | return returned_value_and_params 149 | 150 | 151 | class DeterministicHyperparameterSweeper(Sweeper): 152 | """ 153 | Do a grid search over hyperparameters based on a predefined set of 154 | hyperparameters. 155 | """ 156 | def __init__(self, hyperparameters, default_parameters=None): 157 | """ 158 | 159 | :param hyperparameters: A dictionary of the form 160 | ``` 161 | { 162 | 'hp_1': [value1, value2, value3], 163 | 'hp_2': [value1, value2, value3], 164 | ... 165 | } 166 | ``` 167 | This format is like the param_grid in SciKit-Learn: 168 | http://scikit-learn.org/stable/modules/grid_search.html#exhaustive-grid-search 169 | :param default_parameters: Default key-value pairs to add to the 170 | dictionary. 171 | """ 172 | self._hyperparameters = hyperparameters 173 | self._default_kwargs = default_parameters or {} 174 | named_hyperparameters = [] 175 | for name, values in self._hyperparameters.items(): 176 | named_hyperparameters.append( 177 | [(name, v) for v in values] 178 | ) 179 | self._hyperparameters_dicts = [ 180 | ppp.dot_map_dict_to_nested_dict(dict(tuple_list)) 181 | for tuple_list in itertools.product(*named_hyperparameters) 182 | ] 183 | 184 | def iterate_hyperparameters(self): 185 | """ 186 | Iterate over the hyperparameters in a grid-manner. 187 | 188 | :return: List of dictionaries. Each dictionary is a map from name to 189 | hyperpameter. 190 | """ 191 | return [ 192 | ppp.merge_recursive_dicts( 193 | hyperparameters, 194 | copy.deepcopy(self._default_kwargs), 195 | ignore_duplicate_keys_in_second_dict=True, 196 | ) 197 | for hyperparameters in self._hyperparameters_dicts 198 | ] 199 | 200 | 201 | # TODO(vpong): Test this 202 | class DeterministicSweeperCombiner(object): 203 | """ 204 | A simple wrapper to combiner multiple DeterministicHyperParameterSweeper's 205 | """ 206 | def __init__(self, sweepers: List[DeterministicHyperparameterSweeper]): 207 | self._sweepers = sweepers 208 | 209 | def iterate_list_of_hyperparameters(self): 210 | """ 211 | Usage: 212 | 213 | ``` 214 | sweeper1 = DeterministicHyperparameterSweeper(...) 215 | sweeper2 = DeterministicHyperparameterSweeper(...) 216 | combiner = DeterministicSweeperCombiner([sweeper1, sweeper2]) 217 | 218 | for params_1, params_2 in combiner.iterate_list_of_hyperparameters(): 219 | # param_1 = {...} 220 | # param_2 = {...} 221 | ``` 222 | :return: Generator of hyperparameters, in the same order as provided 223 | sweepers. 
224 | """ 225 | return itertools.product( 226 | sweeper.iterate_hyperparameters() 227 | for sweeper in self._sweepers 228 | ) -------------------------------------------------------------------------------- /rlkit/torch/sac/sac.py: -------------------------------------------------------------------------------- 1 | import gtimer as gt 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | from torch import nn as nn 8 | 9 | import rlkit.torch.pytorch_util as ptu 10 | from rlkit.core.eval_util import create_stats_ordered_dict 11 | from rlkit.torch.torch_rl_algorithm import TorchTrainer 12 | 13 | 14 | class SACTrainer(TorchTrainer): 15 | 16 | def __init__( 17 | self, 18 | env, 19 | policy, 20 | qf1, 21 | qf2, 22 | target_qf1, 23 | target_qf2, 24 | discount=0.99, 25 | reward_scale=1.0, 26 | alpha=1., 27 | policy_lr=1e-3, 28 | qf_lr=1e-3, 29 | optimizer_class=optim.Adam, 30 | soft_target_tau=1e-2, 31 | target_update_period=1, 32 | clip_norm=0., 33 | use_automatic_entropy_tuning=True, 34 | target_entropy=None, 35 | ): 36 | super().__init__() 37 | self.env = env 38 | self.policy = policy 39 | self.qf1 = qf1 40 | self.qf2 = qf2 41 | self.target_qf1 = target_qf1 42 | self.target_qf2 = target_qf2 43 | self.soft_target_tau = soft_target_tau 44 | self.target_update_period = target_update_period 45 | 46 | self.use_automatic_entropy_tuning = use_automatic_entropy_tuning 47 | if self.use_automatic_entropy_tuning: 48 | if target_entropy: 49 | self.target_entropy = target_entropy 50 | else: 51 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # heuristic value from Tuomas 52 | self.log_alpha = ptu.zeros(1, requires_grad=True) 53 | self.alpha_optimizer = optimizer_class( 54 | [self.log_alpha], 55 | lr=policy_lr, 56 | ) 57 | else: 58 | self.alpha = alpha 59 | 60 | self.qf_criterion = nn.MSELoss() 61 | 62 | self.policy_optimizer = optimizer_class( 63 | self.policy.parameters(), 64 | lr=policy_lr, 65 | ) 66 | self.qf1_optimizer = optimizer_class( 67 | self.qf1.parameters(), 68 | lr=qf_lr, 69 | ) 70 | self.qf2_optimizer = optimizer_class( 71 | self.qf2.parameters(), 72 | lr=qf_lr, 73 | ) 74 | 75 | self.discount = discount 76 | self.reward_scale = reward_scale 77 | self.clip_norm = clip_norm 78 | self.eval_statistics = OrderedDict() 79 | self._n_train_steps_total = 0 80 | self._need_to_update_eval_statistics = True 81 | 82 | def train_from_torch(self, batch): 83 | rewards = batch['rewards'] 84 | terminals = batch['terminals'] 85 | obs = batch['observations'] 86 | actions = batch['actions'] 87 | next_obs = batch['next_observations'] 88 | gt.stamp('preback_start', unique=False) 89 | """ 90 | Update Alpha 91 | """ 92 | new_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy( 93 | obs, 94 | reparameterize=True, 95 | return_log_prob=True, 96 | ) 97 | if self.use_automatic_entropy_tuning: 98 | alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() 99 | self.alpha_optimizer.zero_grad() 100 | alpha_loss.backward() 101 | self.alpha_optimizer.step() 102 | alpha = self.log_alpha.exp() 103 | else: 104 | alpha_loss = 0 105 | alpha = self.alpha 106 | """ 107 | Update QF 108 | """ 109 | with torch.no_grad(): 110 | new_next_actions, _, _, new_log_pi, *_ = self.policy( 111 | next_obs, 112 | reparameterize=True, 113 | return_log_prob=True, 114 | ) 115 | target_q_values = torch.min( 116 | self.target_qf1(next_obs, new_next_actions), 117 | self.target_qf2(next_obs, new_next_actions), 118 | ) - alpha * new_log_pi 
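            # Descriptive note: target_q_values is the entropy-regularized
            # clipped double-Q estimate, min(Q1', Q2')(s', a') - alpha * log pi(a'|s'),
            # computed under torch.no_grad(); it feeds the soft Bellman target
            # q_target = reward_scale * r + (1 - done) * discount * target_q_values below.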
119 | 
120 |         q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
121 |         q1_pred = self.qf1(obs, actions)
122 |         q2_pred = self.qf2(obs, actions)
123 |         qf1_loss = self.qf_criterion(q1_pred, q_target)
124 |         qf2_loss = self.qf_criterion(q2_pred, q_target)
125 |         gt.stamp('preback_qf', unique=False)
126 | 
127 |         self.qf1_optimizer.zero_grad()
128 |         qf1_loss.backward()
129 |         self.qf1_optimizer.step()
130 |         gt.stamp('backward_qf1', unique=False)
131 | 
132 |         self.qf2_optimizer.zero_grad()
133 |         qf2_loss.backward()
134 |         self.qf2_optimizer.step()
135 |         gt.stamp('backward_qf2', unique=False)
136 |         """
137 |         Update Policy
138 |         """
139 |         q_new_actions = torch.min(
140 |             self.qf1(obs, new_actions),
141 |             self.qf2(obs, new_actions),
142 |         )
143 |         policy_loss = (alpha * log_pi - q_new_actions).mean()
144 |         gt.stamp('preback_policy', unique=False)
145 | 
146 |         self.policy_optimizer.zero_grad()
147 |         policy_loss.backward()
148 |         policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(), self.clip_norm)
149 |         self.policy_optimizer.step()
150 |         gt.stamp('backward_policy', unique=False)
151 |         """
152 |         Soft Updates
153 |         """
154 |         if self._n_train_steps_total % self.target_update_period == 0:
155 |             ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
156 |             ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)
157 |         """
158 |         Save some statistics for eval
159 |         """
160 |         if self._need_to_update_eval_statistics:
161 |             self._need_to_update_eval_statistics = False
162 | 
163 |             self.eval_statistics['QF1 Loss'] = qf1_loss.item()
164 |             self.eval_statistics['QF2 Loss'] = qf2_loss.item()
165 |             self.eval_statistics['Policy Loss'] = policy_loss.item()
166 |             self.eval_statistics['Policy Grad'] = policy_grad
167 |             self.eval_statistics.update(create_stats_ordered_dict(
168 |                 'Q1 Predictions',
169 |                 ptu.get_numpy(q1_pred),
170 |             ))
171 |             self.eval_statistics.update(create_stats_ordered_dict(
172 |                 'Q2 Predictions',
173 |                 ptu.get_numpy(q2_pred),
174 |             ))
175 |             self.eval_statistics.update(create_stats_ordered_dict(
176 |                 'Q Targets',
177 |                 ptu.get_numpy(q_target),
178 |             ))
179 |             self.eval_statistics.update(create_stats_ordered_dict(
180 |                 'Log Pis',
181 |                 ptu.get_numpy(log_pi),
182 |             ))
183 |             self.eval_statistics.update(create_stats_ordered_dict(
184 |                 'Policy mu',
185 |                 ptu.get_numpy(policy_mean),
186 |             ))
187 |             self.eval_statistics.update(create_stats_ordered_dict(
188 |                 'Policy log std',
189 |                 ptu.get_numpy(policy_log_std),
190 |             ))
191 |             if self.use_automatic_entropy_tuning:
192 |                 self.eval_statistics['Alpha'] = alpha.item()
193 |                 self.eval_statistics['Alpha Loss'] = alpha_loss.item()
194 |         self._n_train_steps_total += 1
195 | 
196 |     def get_diagnostics(self):
197 |         return self.eval_statistics
198 | 
199 |     def end_epoch(self, epoch):
200 |         self._need_to_update_eval_statistics = True
201 | 
202 |     @property
203 |     def networks(self):
204 |         return [
205 |             self.policy,
206 |             self.qf1,
207 |             self.qf2,
208 |             self.target_qf1,
209 |             self.target_qf2,
210 |         ]
211 | 
212 |     def get_snapshot(self):
213 |         return dict(
214 |             policy=self.policy.state_dict(),
215 |             qf1=self.qf1.state_dict(),
216 |             qf2=self.qf2.state_dict(),
217 |             target_qf1=self.target_qf1.state_dict(),
218 |             target_qf2=self.target_qf2.state_dict(),
219 |         )
220 | 
--------------------------------------------------------------------------------
/rlkit/samplers/data_collector/step_collector.py:
--------------------------------------------------------------------------------
1 | from collections import deque,
OrderedDict 2 | 3 | import numpy as np 4 | 5 | from rlkit.core.eval_util import create_stats_ordered_dict 6 | from rlkit.data_management.path_builder import PathBuilder 7 | from rlkit.samplers.data_collector.base import StepCollector 8 | 9 | 10 | class MdpStepCollector(StepCollector): 11 | 12 | def __init__( 13 | self, 14 | env, 15 | policy, 16 | max_num_epoch_paths_saved=None, 17 | render=False, 18 | render_kwargs=None, 19 | ): 20 | if render_kwargs is None: 21 | render_kwargs = {} 22 | self._env = env 23 | self._policy = policy 24 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 25 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 26 | self._render = render 27 | self._render_kwargs = render_kwargs 28 | 29 | self._num_steps_total = 0 30 | self._num_paths_total = 0 31 | self._obs = None # cache variable 32 | 33 | def get_epoch_paths(self): 34 | return self._epoch_paths 35 | 36 | def end_epoch(self, epoch): 37 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 38 | self._obs = None 39 | 40 | def get_diagnostics(self): 41 | path_lens = [len(path['actions']) for path in self._epoch_paths] 42 | stats = OrderedDict([ 43 | ('num steps total', self._num_steps_total), 44 | ('num paths total', self._num_paths_total), 45 | ]) 46 | stats.update(create_stats_ordered_dict( 47 | "path length", 48 | path_lens, 49 | always_show_all_stats=True, 50 | )) 51 | return stats 52 | 53 | def get_snapshot(self): 54 | return dict( 55 | env=self._env, 56 | policy=self._policy, 57 | ) 58 | 59 | def collect_new_steps( 60 | self, 61 | max_path_length, 62 | num_steps, 63 | discard_incomplete_paths, 64 | ): 65 | for _ in range(num_steps): 66 | self.collect_one_step(max_path_length, discard_incomplete_paths) 67 | 68 | def collect_one_step( 69 | self, 70 | max_path_length, 71 | discard_incomplete_paths, 72 | ): 73 | if self._obs is None: 74 | self._start_new_rollout() 75 | 76 | action, agent_info = self._policy.get_action(self._obs) 77 | next_ob, reward, terminal, env_info = (self._env.step(action)) 78 | if self._render: 79 | self._env.render(**self._render_kwargs) 80 | terminal = np.array([terminal]) 81 | reward = np.array([reward]) 82 | # store path obs 83 | self._current_path_builder.add_all( 84 | observations=self._obs, 85 | actions=action, 86 | rewards=reward, 87 | next_observations=next_ob, 88 | terminals=terminal, 89 | agent_infos=agent_info, 90 | env_infos=env_info, 91 | ) 92 | if terminal or len(self._current_path_builder) >= max_path_length: 93 | self._handle_rollout_ending(max_path_length, discard_incomplete_paths) 94 | self._start_new_rollout() 95 | else: 96 | self._obs = next_ob 97 | 98 | def _start_new_rollout(self): 99 | self._current_path_builder = PathBuilder() 100 | self._obs = self._env.reset() 101 | 102 | def _handle_rollout_ending(self, max_path_length, discard_incomplete_paths): 103 | if len(self._current_path_builder) > 0: 104 | path = self._current_path_builder.get_all_stacked() 105 | path_len = len(path['actions']) 106 | if (path_len != max_path_length and not path['terminals'][-1] and discard_incomplete_paths): 107 | return 108 | self._epoch_paths.append(path) 109 | self._num_paths_total += 1 110 | self._num_steps_total += path_len 111 | 112 | 113 | class GoalConditionedStepCollector(StepCollector): 114 | 115 | def __init__( 116 | self, 117 | env, 118 | policy, 119 | max_num_epoch_paths_saved=None, 120 | render=False, 121 | render_kwargs=None, 122 | observation_key='observation', 123 | desired_goal_key='desired_goal', 124 | ): 125 | if render_kwargs 
is None: 126 | render_kwargs = {} 127 | self._env = env 128 | self._policy = policy 129 | self._max_num_epoch_paths_saved = max_num_epoch_paths_saved 130 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 131 | self._render = render 132 | self._render_kwargs = render_kwargs 133 | self._observation_key = observation_key 134 | self._desired_goal_key = desired_goal_key 135 | 136 | self._num_steps_total = 0 137 | self._num_paths_total = 0 138 | self._obs = None # cache variable 139 | 140 | def get_epoch_paths(self): 141 | return self._epoch_paths 142 | 143 | def end_epoch(self, epoch): 144 | self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved) 145 | self._obs = None 146 | 147 | def get_diagnostics(self): 148 | path_lens = [len(path['actions']) for path in self._epoch_paths] 149 | stats = OrderedDict([ 150 | ('num steps total', self._num_steps_total), 151 | ('num paths total', self._num_paths_total), 152 | ]) 153 | stats.update(create_stats_ordered_dict( 154 | "path length", 155 | path_lens, 156 | always_show_all_stats=True, 157 | )) 158 | return stats 159 | 160 | def get_snapshot(self): 161 | return dict( 162 | env=self._env, 163 | policy=self._policy, 164 | observation_key=self._observation_key, 165 | desired_goal_key=self._desired_goal_key, 166 | ) 167 | 168 | def start_collection(self): 169 | self._start_new_rollout() 170 | 171 | def end_collection(self): 172 | epoch_paths = self.get_epoch_paths() 173 | return epoch_paths 174 | 175 | def collect_new_steps( 176 | self, 177 | max_path_length, 178 | num_steps, 179 | discard_incomplete_paths, 180 | ): 181 | for _ in range(num_steps): 182 | self.collect_one_step(max_path_length, discard_incomplete_paths) 183 | 184 | def collect_one_step( 185 | self, 186 | max_path_length, 187 | discard_incomplete_paths, 188 | ): 189 | if self._obs is None: 190 | self._start_new_rollout() 191 | 192 | new_obs = np.hstack(( 193 | self._obs[self._observation_key], 194 | self._obs[self._desired_goal_key], 195 | )) 196 | action, agent_info = self._policy.get_action(new_obs) 197 | next_ob, reward, terminal, env_info = (self._env.step(action)) 198 | if self._render: 199 | self._env.render(**self._render_kwargs) 200 | terminal = np.array([terminal]) 201 | reward = np.array([reward]) 202 | # store path obs 203 | self._current_path_builder.add_all( 204 | observations=self._obs, 205 | actions=action, 206 | rewards=reward, 207 | next_observations=next_ob, 208 | terminals=terminal, 209 | agent_infos=agent_info, 210 | env_infos=env_info, 211 | ) 212 | if terminal or len(self._current_path_builder) >= max_path_length: 213 | self._handle_rollout_ending(max_path_length, discard_incomplete_paths) 214 | self._start_new_rollout() 215 | else: 216 | self._obs = next_ob 217 | 218 | def _start_new_rollout(self): 219 | self._current_path_builder = PathBuilder() 220 | self._obs = self._env.reset() 221 | 222 | def _handle_rollout_ending(self, max_path_length, discard_incomplete_paths): 223 | if len(self._current_path_builder) > 0: 224 | path = self._current_path_builder.get_all_stacked() 225 | path_len = len(path['actions']) 226 | if (path_len != max_path_length and not path['terminals'][-1] and discard_incomplete_paths): 227 | return 228 | self._epoch_paths.append(path) 229 | self._num_paths_total += 1 230 | self._num_steps_total += path_len 231 | -------------------------------------------------------------------------------- /rlkit/envs/vecenv.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import 
numpy as np 3 | from abc import ABC, abstractmethod 4 | from multiprocessing import Process, Pipe 5 | 6 | from rlkit.envs.env_utils import CloudpickleWrapper 7 | 8 | 9 | class BaseVectorEnv(ABC, gym.Wrapper): 10 | """Base class for vectorized environments wrapper. Usage: 11 | :: 12 | 13 | env_num = 8 14 | envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)]) 15 | assert len(envs) == env_num 16 | 17 | It accepts a list of environment generators. In other words, an environment 18 | generator ``efn`` of a specific task means that ``efn()`` returns the 19 | environment of the given task, for example, ``gym.make(task)``. 20 | 21 | All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`. 22 | Here are some other usages: 23 | :: 24 | 25 | envs.seed(2) # which is equal to the next line 26 | envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env 27 | obs = envs.reset() # reset all environments 28 | obs = envs.reset([0, 5, 7]) # reset 3 specific environments 29 | obs, rew, done, info = envs.step([1] * 8) # step synchronously 30 | envs.render() # render all environments 31 | envs.close() # close all environments 32 | """ 33 | 34 | def __init__(self, env_fns): 35 | self._env_fns = env_fns 36 | self.env_num = len(env_fns) 37 | 38 | def __len__(self): 39 | """Return len(self), which is the number of environments.""" 40 | return self.env_num 41 | 42 | @abstractmethod 43 | def reset(self, id=None): 44 | """Reset the state of all the environments and return initial 45 | observations if id is ``None``, otherwise reset the specific 46 | environments with given id, either an int or a list. 47 | """ 48 | pass 49 | 50 | @abstractmethod 51 | def step(self, action): 52 | """Run one timestep of all the environments’ dynamics. When the end of 53 | episode is reached, you are responsible for calling reset(id) to reset 54 | this environment’s state. 55 | 56 | Accept a batch of action and return a tuple (obs, rew, done, info). 57 | 58 | :param numpy.ndarray action: a batch of action provided by the agent. 59 | 60 | :return: A tuple including four items: 61 | 62 | * ``obs`` a numpy.ndarray, the agent's observation of current \ 63 | environments 64 | * ``rew`` a numpy.ndarray, the amount of rewards returned after \ 65 | previous actions 66 | * ``done`` a numpy.ndarray, whether these episodes have ended, in \ 67 | which case further step() calls will return undefined results 68 | * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ 69 | information (helpful for debugging, and sometimes learning) 70 | """ 71 | pass 72 | 73 | @abstractmethod 74 | def seed(self, seed=None): 75 | """Set the seed for all environments. Accept ``None``, an int (which 76 | will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list. 77 | """ 78 | pass 79 | 80 | @abstractmethod 81 | def render(self, **kwargs): 82 | """Render all of the environments.""" 83 | pass 84 | 85 | @abstractmethod 86 | def close(self): 87 | """Close all of the environments.""" 88 | pass 89 | 90 | 91 | class VectorEnv(BaseVectorEnv): 92 | """Dummy vectorized environment wrapper, implemented in for-loop. 93 | 94 | .. seealso:: 95 | 96 | Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed 97 | explanation. 
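
    Each wrapped environment is stepped sequentially in the current process,
    which is the simplest option for cheap environments or debugging. A minimal
    construction sketch, with ``task`` as a placeholder env id::

        envs = VectorEnv([lambda: gym.make(task) for _ in range(4)])
        obs = envs.reset()
        obs, rew, done, info = envs.step([envs.action_space.sample() for _ in range(4)])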
98 | """ 99 | 100 | def __init__(self, env_fns): 101 | super().__init__(env_fns) 102 | self.envs = [_() for _ in env_fns] 103 | self.observation_space = self.envs[0].observation_space 104 | self.action_space = self.envs[0].action_space 105 | 106 | def reset(self, id=None): 107 | if id is None: 108 | self._obs = np.stack([e.reset() for e in self.envs]) 109 | else: 110 | if np.isscalar(id): 111 | id = [id] 112 | for i in id: 113 | self._obs[i] = self.envs[i].reset() 114 | return self._obs 115 | 116 | def step(self, action): 117 | assert len(action) == self.env_num 118 | result = [e.step(a) for e, a in zip(self.envs, action)] 119 | self._obs, self._rew, self._done, self._info = zip(*result) 120 | self._obs = np.stack(self._obs) 121 | self._rew = np.stack(self._rew) 122 | self._done = np.stack(self._done) 123 | self._info = np.stack(self._info) 124 | return self._obs, self._rew, self._done, self._info 125 | 126 | def seed(self, seed=None): 127 | if np.isscalar(seed): 128 | seed = [seed + _ for _ in range(self.env_num)] 129 | elif seed is None: 130 | seed = [seed] * self.env_num 131 | result = [] 132 | for e, s in zip(self.envs, seed): 133 | if hasattr(e, 'seed'): 134 | result.append(e.seed(s)) 135 | return result 136 | 137 | def render(self, **kwargs): 138 | result = [] 139 | for e in self.envs: 140 | if hasattr(e, 'render'): 141 | result.append(e.render(**kwargs)) 142 | return result 143 | 144 | def close(self): 145 | return [e.close() for e in self.envs] 146 | 147 | 148 | def worker(parent, p, env_fn_wrapper): 149 | parent.close() 150 | env = env_fn_wrapper.data() 151 | try: 152 | while True: 153 | cmd, data = p.recv() 154 | if cmd == 'step': 155 | p.send(env.step(data)) 156 | elif cmd == 'reset': 157 | p.send(env.reset()) 158 | elif cmd == 'close': 159 | p.send(env.close()) 160 | p.close() 161 | break 162 | elif cmd == 'render': 163 | p.send(env.render(**data) if hasattr(env, 'render') else None) 164 | elif cmd == 'seed': 165 | p.send(env.seed(data) if hasattr(env, 'seed') else None) 166 | else: 167 | p.close() 168 | raise NotImplementedError 169 | except KeyboardInterrupt: 170 | p.close() 171 | 172 | 173 | class SubprocVectorEnv(BaseVectorEnv): 174 | """Vectorized environment wrapper based on subprocess. 175 | 176 | .. seealso:: 177 | 178 | Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed 179 | explanation. 
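
    Each environment runs in its own worker process (see ``worker`` above) and
    talks to the main process over a ``multiprocessing.Pipe``; call ``close()``
    when finished so the workers are joined. Construction mirrors
    :class:`VectorEnv`, e.g. ``SubprocVectorEnv([lambda: gym.make(task) for _ in range(8)])``
    with ``task`` as a placeholder env id.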
180 | """ 181 | 182 | def __init__(self, env_fns): 183 | dummy_env = env_fns[0]() 184 | self.observation_space = dummy_env.observation_space 185 | self.action_space = dummy_env.action_space 186 | super().__init__(env_fns) 187 | self.closed = False 188 | self.parent_remote, self.child_remote = \ 189 | zip(*[Pipe() for _ in range(self.env_num)]) 190 | self.processes = [ 191 | Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn)), daemon=True) 192 | for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns) 193 | ] 194 | for p in self.processes: 195 | p.start() 196 | for c in self.child_remote: 197 | c.close() 198 | 199 | def step(self, action): 200 | assert len(action) == self.env_num 201 | for p, a in zip(self.parent_remote, action): 202 | p.send(['step', a]) 203 | result = [p.recv() for p in self.parent_remote] 204 | self._obs, self._rew, self._done, self._info = zip(*result) 205 | self._obs = np.stack(self._obs) 206 | self._rew = np.stack(self._rew) 207 | self._done = np.stack(self._done) 208 | self._info = np.stack(self._info) 209 | return self._obs, self._rew, self._done, self._info 210 | 211 | def reset(self, id=None): 212 | if id is None: 213 | for p in self.parent_remote: 214 | p.send(['reset', None]) 215 | self._obs = np.stack([p.recv() for p in self.parent_remote]) 216 | return self._obs 217 | else: 218 | if np.isscalar(id): 219 | id = [id] 220 | for i in id: 221 | self.parent_remote[i].send(['reset', None]) 222 | for i in id: 223 | self._obs[i] = self.parent_remote[i].recv() 224 | return self._obs 225 | 226 | def seed(self, seed=None): 227 | if np.isscalar(seed): 228 | seed = [seed + _ for _ in range(self.env_num)] 229 | elif seed is None: 230 | seed = [seed] * self.env_num 231 | for p, s in zip(self.parent_remote, seed): 232 | p.send(['seed', s]) 233 | return [p.recv() for p in self.parent_remote] 234 | 235 | def render(self, **kwargs): 236 | for p in self.parent_remote: 237 | p.send(['render', kwargs]) 238 | return [p.recv() for p in self.parent_remote] 239 | 240 | def close(self): 241 | if self.closed: 242 | return 243 | for p in self.parent_remote: 244 | p.send(['close', None]) 245 | result = [p.recv() for p in self.parent_remote] 246 | self.closed = True 247 | for p in self.processes: 248 | p.join() 249 | return result 250 | --------------------------------------------------------------------------------