├── mbpo ├── __init__.py ├── models │ ├── __init__.py │ ├── constructor.py │ ├── utils.py │ └── fake_env.py ├── algorithms │ └── __init__.py ├── scripts │ ├── __init__.py │ └── console_scripts.py ├── utils │ ├── __init__.py │ ├── filesystem.py │ ├── writer.py │ ├── visualization.py │ └── logging.py ├── static │ ├── halfcheetah.py │ ├── humanoid_truncated_obs.py │ ├── inverted_pendulum.py │ ├── ant_truncated_obs.py │ ├── walker2d.py │ ├── hopper.py │ ├── inverted_double_pendulum.py │ └── __init__.py └── env │ ├── __init__.py │ ├── ant.py │ └── humanoid.py ├── examples ├── __init__.py ├── config │ ├── walker2d │ │ └── 0.py │ ├── custom │ │ └── 0.py │ ├── halfcheetah │ │ └── 0.py │ ├── hopper │ │ └── 0.py │ ├── ant │ │ └── 0.py │ ├── humanoid │ │ └── 0.py │ ├── inverted_pendulum │ │ └── 0.py │ ├── inverted_double_pendulum │ │ └── 0.py │ └── __init__.py └── development │ ├── __init__.py │ ├── simulate_policy.py │ ├── base.py │ └── main.py ├── softlearning ├── __init__.py ├── misc │ ├── __init__.py │ ├── plotter.py │ ├── kernel.py │ └── utils.py ├── models │ ├── __init__.py │ ├── utils.py │ └── feedforward.py ├── policies │ ├── __init__.py │ ├── utils.py │ ├── uniform_policy.py │ ├── base_policy.py │ └── gaussian_policy.py ├── scripts │ ├── __init__.py │ └── console_scripts.py ├── environments │ ├── __init__.py │ ├── gym │ │ ├── mujoco │ │ │ ├── __init__.py │ │ │ └── image_pusher_2d.py │ │ ├── robotics │ │ │ └── __init__.py │ │ ├── wrappers │ │ │ ├── __init__.py │ │ │ └── normalize_action.py │ │ ├── __init__.py │ │ └── multi_goal.py │ ├── adapters │ │ ├── __init__.py │ │ ├── gym_adapter.py │ │ └── softlearning_env.py │ ├── dm_control │ │ └── __init__.py │ ├── helpers.py │ └── utils.py ├── preprocessors │ ├── __init__.py │ ├── utils.py │ └── convnet.py ├── value_functions │ ├── __init__.py │ ├── vanilla.py │ ├── utils.py │ └── value_function.py ├── algorithms │ ├── __init__.py │ └── utils.py ├── distributions │ ├── __init__.py │ └── squash_bijector.py ├── utils │ ├── numpy.py │ └── keras.py ├── replay_pools │ ├── __init__.py │ ├── extra_policy_info_replay_pool.py │ ├── utils.py │ ├── replay_pool.py │ ├── union_pool.py │ ├── flexible_replay_pool.py │ ├── trajectory_replay_pool.py │ └── simple_replay_pool.py ├── samplers │ ├── __init__.py │ ├── dummy_sampler.py │ ├── base_sampler.py │ ├── extra_policy_info_sampler.py │ ├── utils.py │ ├── simple_sampler.py │ ├── explore_sampler.py │ └── remote_sampler.py └── softlearning.md ├── .gitmodules ├── environment ├── gpu-env.yml └── requirements.txt ├── setup.py ├── LICENSE ├── .gitignore └── README.md /mbpo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mbpo/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mbpo/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mbpo/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
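Editor's note on the layout above: the tree shows three top-level packages (mbpo, softlearning, examples); setup.py further down installs them via find_packages, and mbpo/env/__init__.py (also included below) registers the two truncated-observation environments with gym. A minimal sketch of that registration path, assuming the package has been installed (e.g. pip install -e .) and a working mujoco-py/gym setup — not part of the repository itself:

# Sketch only: assumes `pip install -e .` has been run and MuJoCo is available.
import gym

from mbpo.env import register_mbpo_environments

# Registers the custom environment specs and returns their gym ids.
gym_ids = register_mbpo_environments()
print(gym_ids)  # ('AntTruncatedObs-v2', 'HumanoidTruncatedObs-v2')

# AntTruncatedObs drops the external contact forces (cfrc_ext) from the
# observation, keeping only qpos[2:] and qvel.
env = gym.make('AntTruncatedObs-v2')
observation = env.reset()
print(observation.shape)  # (27,) rather than the 111-dim observation of Ant-v2
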
/softlearning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/policies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/environments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/value_functions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/environments/gym/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /softlearning/environments/gym/robotics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mbpo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # from .filesystem import * 2 | # from .launcher import * -------------------------------------------------------------------------------- /softlearning/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .sql import SQL 2 | from .sac import SAC 3 | -------------------------------------------------------------------------------- /softlearning/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .real_nvp_flow import ConditionalRealNVPFlow 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "viskit"] 2 | path = viskit 3 | url = https://github.com/vitchyr/viskit.git 4 | -------------------------------------------------------------------------------- /softlearning/environments/gym/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .normalize_action import NormalizeActionWrapper 2 | -------------------------------------------------------------------------------- /mbpo/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdir(path): 4 | if not os.path.exists(path): 5 | 
os.mkdir(path) -------------------------------------------------------------------------------- /softlearning/environments/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that provides adapters between SoftlearningEnv and other universes""" 2 | -------------------------------------------------------------------------------- /softlearning/utils/numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def softmax(x): 5 | max_x = np.max(x) 6 | exp_x = np.exp(x - max_x) 7 | return exp_x / np.sum(exp_x) 8 | -------------------------------------------------------------------------------- /softlearning/replay_pools/__init__.py: -------------------------------------------------------------------------------- 1 | from .simple_replay_pool import SimpleReplayPool 2 | from .extra_policy_info_replay_pool import ExtraPolicyInfoReplayPool 3 | from .union_pool import UnionPool 4 | from .trajectory_replay_pool import TrajectoryReplayPool 5 | -------------------------------------------------------------------------------- /softlearning/environments/dm_control/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom DeepMind Control Suite environments. 2 | 3 | Every class inside this module should extend a dm_control.suite.Task class. The 4 | # file structure should be similar to dm_control's file structure. 5 | """ 6 | -------------------------------------------------------------------------------- /softlearning/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import BaseSampler 2 | from .dummy_sampler import DummySampler 3 | from .simple_sampler import SimpleSampler 4 | from .remote_sampler import RemoteSampler 5 | from .extra_policy_info_sampler import ExtraPolicyInfoSampler 6 | from .utils import rollout, rollouts 7 | -------------------------------------------------------------------------------- /environment/gpu-env.yml: -------------------------------------------------------------------------------- 1 | ## copy of https://github.com/rail-berkeley/softlearning/blob/master/environment.yml 2 | 3 | name: mbpo 4 | channels: 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - python=3.6.5 9 | - pip>=18.0 10 | - conda>=4.5.9 11 | - patchelf=0.9 12 | - pip: 13 | - -r requirements.txt 14 | -------------------------------------------------------------------------------- /mbpo/static/halfcheetah.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | done = np.array([False]).repeat(len(obs)) 10 | done = done[:,None] 11 | return done 12 | -------------------------------------------------------------------------------- /mbpo/scripts/console_scripts.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | @click.group(invoke_without_command=True) 4 | @click.argument("data_path", required=True, nargs=1) 5 | @click.argument('kwargs', nargs=-1) 6 | def cli(data_path, kwargs): 7 | from viskit.frontend import main 8 | main() 9 | 10 | 11 | def main(): 12 | return cli() 13 | 14 | if __name__ == "__main__": 15 | main() 16 | 
-------------------------------------------------------------------------------- /softlearning/environments/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def random_point_in_circle(angle_range=(0, 2*np.pi), radius=(0, 25)): 5 | angle = np.random.uniform(*angle_range) 6 | radius = radius if np.isscalar(radius) else np.random.uniform(*radius) 7 | x, y = np.cos(angle) * radius, np.sin(angle) * radius 8 | point = np.array([x, y]) 9 | return point 10 | -------------------------------------------------------------------------------- /mbpo/static/humanoid_truncated_obs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pdb 4 | 5 | class StaticFns: 6 | 7 | @staticmethod 8 | def termination_fn(obs, act, next_obs): 9 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 10 | 11 | z = next_obs[:,0] 12 | done = (z < 1.0) + (z > 2.0) 13 | 14 | done = done[:,None] 15 | return done -------------------------------------------------------------------------------- /softlearning/samplers/dummy_sampler.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import BaseSampler 2 | 3 | 4 | class DummySampler(BaseSampler): 5 | def __init__(self, batch_size, max_path_length): 6 | super(DummySampler, self).__init__( 7 | max_path_length=max_path_length, 8 | min_pool_size=0, 9 | batch_size=batch_size) 10 | 11 | def sample(self): 12 | pass 13 | -------------------------------------------------------------------------------- /mbpo/static/inverted_pendulum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pdb 4 | 5 | class StaticFns: 6 | 7 | @staticmethod 8 | def termination_fn(obs, act, next_obs): 9 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 10 | 11 | notdone = np.isfinite(next_obs).all(axis=-1) \ 12 | * (np.abs(next_obs[:,1]) <= .2) 13 | done = ~notdone 14 | 15 | done = done[:,None] 16 | 17 | return done -------------------------------------------------------------------------------- /mbpo/static/ant_truncated_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | x = next_obs[:, 0] 10 | not_done = np.isfinite(next_obs).all(axis=-1) \ 11 | * (x >= 0.2) \ 12 | * (x <= 1.0) 13 | 14 | done = ~not_done 15 | done = done[:,None] 16 | return done 17 | -------------------------------------------------------------------------------- /mbpo/static/walker2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | height = next_obs[:, 0] 10 | angle = next_obs[:, 1] 11 | not_done = (height > 0.8) \ 12 | * (height < 2.0) \ 13 | * (angle > -1.0) \ 14 | * (angle < 1.0) 15 | done = ~not_done 16 | done = done[:,None] 17 | return done 18 | -------------------------------------------------------------------------------- /mbpo/static/hopper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class 
StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | height = next_obs[:, 0] 10 | angle = next_obs[:, 1] 11 | not_done = np.isfinite(next_obs).all(axis=-1) \ 12 | * np.abs(next_obs[:,1:] < 100).all(axis=-1) \ 13 | * (height > .7) \ 14 | * (np.abs(angle) < .2) 15 | 16 | done = ~not_done 17 | done = done[:,None] 18 | return done 19 | -------------------------------------------------------------------------------- /mbpo/static/inverted_double_pendulum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pdb 4 | 5 | class StaticFns: 6 | 7 | @staticmethod 8 | def termination_fn(obs, act, next_obs): 9 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 10 | 11 | sin1, cos1 = next_obs[:,1], next_obs[:,3] 12 | sin2, cos2 = next_obs[:,2], next_obs[:,4] 13 | theta_1 = np.arctan2(sin1, cos1) 14 | theta_2 = np.arctan2(sin2, cos2) 15 | y = 0.6 * (cos1 + np.cos(theta_1 + theta_2)) 16 | 17 | done = y <= 1 18 | 19 | done = done[:,None] 20 | return done -------------------------------------------------------------------------------- /mbpo/env/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | MBPO_ENVIRONMENT_SPECS = ( 4 | { 5 | 'id': 'AntTruncatedObs-v2', 6 | 'entry_point': (f'mbpo.env.ant:AntTruncatedObsEnv'), 7 | }, 8 | { 9 | 'id': 'HumanoidTruncatedObs-v2', 10 | 'entry_point': (f'mbpo.env.humanoid:HumanoidTruncatedObsEnv'), 11 | }, 12 | ) 13 | 14 | def register_mbpo_environments(): 15 | for mbpo_environment in MBPO_ENVIRONMENT_SPECS: 16 | gym.register(**mbpo_environment) 17 | 18 | gym_ids = tuple( 19 | environment_spec['id'] 20 | for environment_spec in MBPO_ENVIRONMENT_SPECS) 21 | 22 | return gym_ids -------------------------------------------------------------------------------- /softlearning/models/utils.py: -------------------------------------------------------------------------------- 1 | def build_metric_learner_from_variant(variant, env, evaluation_data): 2 | sampler_params = variant['sampler_params'] 3 | metric_learner_params = variant['metric_learner_params'] 4 | metric_learner_params.update({ 5 | 'observation_shape': env.observation_space.shape, 6 | 'max_distance': sampler_params['kwargs']['max_path_length'], 7 | 'evaluation_data': evaluation_data 8 | }) 9 | 10 | metric_learner = MetricLearner(**metric_learner_params) 11 | return metric_learner 12 | 13 | 14 | def get_model_from_variant(variant, env, *args, **kwargs): 15 | pass 16 | -------------------------------------------------------------------------------- /softlearning/replay_pools/extra_policy_info_replay_pool.py: -------------------------------------------------------------------------------- 1 | from .simple_replay_pool import SimpleReplayPool 2 | 3 | 4 | class ExtraPolicyInfoReplayPool(SimpleReplayPool): 5 | def __init__(self, *args, **kwargs): 6 | super(ExtraPolicyInfoReplayPool, self).__init__(*args, **kwargs) 7 | 8 | fields = { 9 | 'raw_actions': { 10 | 'shape': self._action_space.shape, 11 | 'dtype': 'float32' 12 | }, 13 | 'log_pis': { 14 | 'shape': (1, ), 15 | 'dtype': 'float32' 16 | } 17 | } 18 | 19 | self.add_fields(fields) 20 | -------------------------------------------------------------------------------- /softlearning/distributions/squash_bijector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import tensorflow as tf 3 | import tensorflow_probability as tfp 4 | 5 | 6 | class SquashBijector(tfp.bijectors.Bijector): 7 | def __init__(self, validate_args=False, name="tanh"): 8 | super(SquashBijector, self).__init__( 9 | forward_min_event_ndims=0, 10 | validate_args=validate_args, 11 | name=name) 12 | 13 | def _forward(self, x): 14 | return tf.nn.tanh(x) 15 | 16 | def _inverse(self, y): 17 | return tf.atanh(y) 18 | 19 | def _forward_log_det_jacobian(self, x): 20 | return 2. * (np.log(2.) - x - tf.nn.softplus(-2. * x)) 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='mbpo', 6 | packages=find_packages(), 7 | version='0.0.1', 8 | description='Model-based policy optimization', 9 | long_description=open('./README.md').read(), 10 | author='Michael Janner', 11 | author_email='janner@berkeley.edu', 12 | url='https://people.eecs.berkeley.edu/~janner/mbpo/', 13 | entry_points={ 14 | 'console_scripts': ( 15 | 'mbpo=softlearning.scripts.console_scripts:main', 16 | 'viskit=mbpo.scripts.console_scripts:main' 17 | ) 18 | }, 19 | requires=(), 20 | zip_safe=True, 21 | license='MIT' 22 | ) 23 | -------------------------------------------------------------------------------- /mbpo/static/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import importlib 4 | import pdb 5 | 6 | 7 | def import_fns(path, file, fns_name='StaticFns'): 8 | full_path = os.path.join(path, file) 9 | import_path = full_path.replace('/', '.') 10 | module = importlib.import_module(import_path) 11 | fns = getattr(module, fns_name) 12 | return fns 13 | 14 | cwd = 'mbpo/static' 15 | files = os.listdir(cwd) 16 | ## remove __init__.py 17 | files = filter(lambda x: '__' not in x, files) 18 | ## env.py --> env 19 | files = map(lambda x: x.replace('.py', ''), files) 20 | 21 | ## {env: StaticFns, ... 
} 22 | static_fns = {file.replace('_', ''): import_fns(cwd, file) for file in files} 23 | 24 | sys.modules[__name__] = static_fns 25 | 26 | -------------------------------------------------------------------------------- /softlearning/environments/utils.py: -------------------------------------------------------------------------------- 1 | from .adapters.gym_adapter import ( 2 | GYM_ENVIRONMENTS, 3 | GymAdapter, 4 | ) 5 | 6 | import pdb 7 | 8 | ENVIRONMENTS = { 9 | 'gym': GYM_ENVIRONMENTS, 10 | } 11 | 12 | ADAPTERS = { 13 | 'gym': GymAdapter, 14 | } 15 | 16 | 17 | def get_environment(universe, domain, task, environment_params): 18 | env = ADAPTERS[universe](domain, task, **environment_params) 19 | return env 20 | 21 | 22 | def get_environment_from_params(environment_params): 23 | universe = environment_params['universe'] 24 | task = environment_params['task'] 25 | domain = environment_params['domain'] 26 | environment_kwargs = environment_params.get('kwargs', {}).copy() 27 | 28 | return get_environment(universe, domain, task, environment_kwargs) 29 | -------------------------------------------------------------------------------- /softlearning/environments/gym/wrappers/normalize_action.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | 5 | 6 | __all__ = ['NormalizeActionWrapper'] 7 | 8 | 9 | class NormalizeActionWrapper(gym.ActionWrapper): 10 | """Rescale the action space of the environment.""" 11 | 12 | def action(self, action): 13 | if not isinstance(self.env.action_space, spaces.Box): 14 | return action 15 | 16 | # rescale the action 17 | low, high = self.env.action_space.low, self.env.action_space.high 18 | scaled_action = low + (action + 1.0) * (high - low) / 2.0 19 | scaled_action = np.clip(scaled_action, low, high) 20 | 21 | return scaled_action 22 | 23 | def reverse_action(self, action): 24 | raise NotImplementedError 25 | 26 | 27 | normalize = NormalizeActionWrapper 28 | -------------------------------------------------------------------------------- /examples/config/walker2d/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'Walker2d', 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 20, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 'reward_scale': 1.0, 21 | 22 | 'model_train_freq': 250, 23 | 'model_retain_epochs': 1, 24 | 'rollout_batch_size': 100e3, 25 | 'deterministic': False, 26 | 'num_networks': 7, 27 | 'num_elites': 5, 28 | 'real_ratio': 0.05, 29 | 'target_entropy': -3, 30 | 'max_model_t': None, 31 | 'rollout_schedule': [20, 150, 1, 1], 32 | } 33 | } -------------------------------------------------------------------------------- /examples/config/custom/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': TODO, 4 | 'domain': TODO, 5 | 'task': TODO, 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 20, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 
'reward_scale': 1.0, 21 | 22 | 'model_train_freq': 250, 23 | 'model_retain_epochs': 1, 24 | 'rollout_batch_size': 100e3, 25 | 'deterministic': False, 26 | 'num_networks': 7, 27 | 'num_elites': 5, 28 | 'real_ratio': 0.05, 29 | 'target_entropy': -1, 30 | 'max_model_t': None, 31 | 'rollout_schedule': [20, 150, 1, 15], 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /examples/config/halfcheetah/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'HalfCheetah', 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 40, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 'reward_scale': 1.0, 21 | 22 | 'model_train_freq': 250, 23 | 'model_retain_epochs': 1, 24 | 'rollout_batch_size': 100e3, 25 | 'deterministic': False, 26 | 'num_networks': 7, 27 | 'num_elites': 5, 28 | 'real_ratio': 0.05, 29 | 'target_entropy': -3, 30 | 'max_model_t': None, 31 | 'rollout_schedule': [20, 150, 1, 1], 32 | } 33 | } -------------------------------------------------------------------------------- /examples/config/hopper/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'Hopper', 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 20, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 'reward_scale': 1.0, 21 | 22 | 'model_train_freq': 250, 23 | 'model_retain_epochs': 1, 24 | 'rollout_batch_size': 100e3, 25 | 'deterministic': False, 26 | 'num_networks': 7, 27 | 'num_elites': 5, 28 | 'real_ratio': 0.05, 29 | 'target_entropy': -1, 30 | 'max_model_t': None, 31 | 'rollout_schedule': [20, 150, 1, 15], 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /examples/config/ant/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'AntTruncatedObs', ## mbpo/env/ant.py 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 20, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 'reward_scale': 1.0, 21 | 22 | 'model_train_freq': 250, 23 | 'model_retain_epochs': 1, 24 | 'rollout_batch_size': 100e3, 25 | 'deterministic': False, 26 | 'num_networks': 7, 27 | 'num_elites': 5, 28 | 'real_ratio': 0.05, 29 | 'target_entropy': -4, 30 | 'max_model_t': None, 31 | 'rollout_schedule': [20, 100, 1, 25], 32 | } 33 | } -------------------------------------------------------------------------------- /examples/config/humanoid/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'HumanoidTruncatedObs', ## mbpo/env/humanoid.py 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 
'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 20, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 'reward_scale': 1.0, 21 | 22 | 'model_train_freq': 1000, 23 | 'model_retain_epochs': 5, 24 | 'rollout_batch_size': 100e3, 25 | 'deterministic': False, 26 | 'num_networks': 7, 27 | 'num_elites': 5, 28 | 'real_ratio': 0.05, 29 | 'target_entropy': -2, 30 | 'max_model_t': None, 31 | 'rollout_schedule': [20, 300, 1, 25], 32 | 'hidden_dim': 400, 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /examples/config/inverted_pendulum/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'InvertedPendulum', 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'n_epochs': 80, ## 20k steps 12 | 'epoch_length': 250, 13 | 'train_every_n_steps': 1, 14 | 'n_train_repeat': 10, 15 | 'eval_render_mode': None, 16 | 'eval_n_episodes': 1, 17 | 'eval_deterministic': True, 18 | 19 | 'discount': 0.99, 20 | 'tau': 5e-3, 21 | 'reward_scale': 1.0, 22 | 23 | 'model_train_freq': 250, 24 | 'model_retain_epochs': 1, 25 | 'rollout_batch_size': 100e3, 26 | 'deterministic': False, 27 | 'num_networks': 7, 28 | 'num_elites': 5, 29 | 'real_ratio': 0.05, 30 | 'target_entropy': -0.05, 31 | 'max_model_t': None, 32 | 'rollout_schedule': [1, 15, 1, 1], 33 | 'hidden_dim': 200, 34 | 'n_initial_exploration_steps': 500, 35 | } 36 | } -------------------------------------------------------------------------------- /examples/config/inverted_double_pendulum/0.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'InvertedDoublePendulum', 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'n_epochs': 80, ## 20k steps 12 | 'epoch_length': 250, 13 | 'train_every_n_steps': 1, 14 | 'n_train_repeat': 20, 15 | 'eval_render_mode': None, 16 | 'eval_n_episodes': 1, 17 | 'eval_deterministic': True, 18 | 19 | 'discount': 0.99, 20 | 'tau': 5e-3, 21 | 'reward_scale': 1.0, 22 | 23 | 'model_train_freq': 250, 24 | 'model_retain_epochs': 1, 25 | 'rollout_batch_size': 100e3, 26 | 'deterministic': False, 27 | 'num_networks': 7, 28 | 'num_elites': 5, 29 | 'real_ratio': 0.05, 30 | 'target_entropy': -0.5, 31 | 'max_model_t': None, 32 | 'rollout_schedule': [1, 15, 1, 1], 33 | 'hidden_dim': 200, 34 | 'n_initial_exploration_steps': 500, 35 | } 36 | } 37 | 38 | -------------------------------------------------------------------------------- /softlearning/replay_pools/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from . 
import ( 4 | simple_replay_pool, 5 | extra_policy_info_replay_pool, 6 | union_pool, 7 | trajectory_replay_pool) 8 | 9 | 10 | POOL_CLASSES = { 11 | 'SimpleReplayPool': simple_replay_pool.SimpleReplayPool, 12 | 'TrajectoryReplayPool': trajectory_replay_pool.TrajectoryReplayPool, 13 | 'ExtraPolicyInfoReplayPool': ( 14 | extra_policy_info_replay_pool.ExtraPolicyInfoReplayPool), 15 | 'UnionPool': union_pool.UnionPool, 16 | } 17 | 18 | DEFAULT_REPLAY_POOL = 'SimpleReplayPool' 19 | 20 | 21 | def get_replay_pool_from_variant(variant, env, *args, **kwargs): 22 | replay_pool_params = variant['replay_pool_params'] 23 | replay_pool_type = replay_pool_params['type'] 24 | replay_pool_kwargs = deepcopy(replay_pool_params['kwargs']) 25 | 26 | replay_pool = POOL_CLASSES[replay_pool_type]( 27 | *args, 28 | observation_space=env.observation_space, 29 | action_space=env.action_space, 30 | **replay_pool_kwargs, 31 | **kwargs) 32 | 33 | return replay_pool 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Michael Janner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /softlearning/utils/keras.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import tensorflow as tf 4 | 5 | 6 | class PicklableKerasModel(tf.keras.Model): 7 | def __getstate__(self): 8 | with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd: 9 | tf.keras.models.save_model(self, fd.name, overwrite=True) 10 | model_str = fd.read() 11 | d = {'model_str': model_str} 12 | 13 | return d 14 | 15 | def __setstate__(self, state): 16 | with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd: 17 | fd.write(state['model_str']) 18 | fd.flush() 19 | 20 | loaded_model = tf.keras.models.load_model( 21 | fd.name, custom_objects={ 22 | self.__class__.__name__: self.__class__}) 23 | 24 | self.__dict__.update(loaded_model.__dict__.copy()) 25 | 26 | @classmethod 27 | def from_config(cls, *args, custom_objects=None, **kwargs): 28 | custom_objects = custom_objects or {} 29 | custom_objects[cls.__name__] = cls 30 | custom_objects['tf'] = tf 31 | return super(PicklableKerasModel, cls).from_config( 32 | *args, custom_objects=custom_objects, **kwargs) 33 | -------------------------------------------------------------------------------- /softlearning/replay_pools/replay_pool.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ReplayPool(object): 5 | """A class used to save and replay data.""" 6 | 7 | @abc.abstractmethod 8 | def add_sample(self, sample): 9 | """Add a transition tuple.""" 10 | pass 11 | 12 | @abc.abstractmethod 13 | def terminate_episode(self): 14 | """Clean up pool after episode termination.""" 15 | pass 16 | 17 | @property 18 | @abc.abstractmethod 19 | def size(self, **kwargs): 20 | pass 21 | 22 | def add_path(self, path): 23 | """Add a rollout to the replay pool. 24 | 25 | This default implementation naively goes through every step, but you 26 | may want to optimize this. 27 | 28 | NOTE: You should NOT call "terminate_episode" after calling add_path. 29 | It's assumed that this function handles the episode termination. 
30 | 31 | :param path: Dict like one outputted by railrl.samplers.util.rollout 32 | """ 33 | self.add_samples(path) 34 | self.terminate_episode() 35 | 36 | @abc.abstractmethod 37 | def random_batch(self, batch_size): 38 | """Return a random batch of size `batch_size`.""" 39 | pass 40 | -------------------------------------------------------------------------------- /examples/config/__init__.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 'type': 'MBPO', 3 | 'universe': 'gym', 4 | 'domain': 'Hopper', 5 | 'task': 'v2', 6 | 7 | 'log_dir': '~/ray_mbpo/', 8 | 'exp_name': 'defaults', 9 | 10 | 'kwargs': { 11 | 'epoch_length': 1000, 12 | 'train_every_n_steps': 1, 13 | 'n_train_repeat': 2, #20, 14 | 'eval_render_mode': None, 15 | 'eval_n_episodes': 1, 16 | 'eval_deterministic': True, 17 | 18 | 'discount': 0.99, 19 | 'tau': 5e-3, 20 | 'reward_scale': 1.0, 21 | #### 22 | 'model_reset_freq': 1000, 23 | 'model_train_freq': 250, # 250 24 | # 'retain_model_epochs': 2, 25 | 'model_pool_size': 2e6, 26 | 'rollout_batch': 100e3, # 40e3 27 | 'rollout_length': 1, 28 | 'deterministic': False, 29 | 'num_networks': 7, 30 | 'num_elites': 5, 31 | 'real_ratio': 0.05, 32 | 'entropy_mult': 0.5, 33 | # 'target_entropy': -1.5, 34 | 'max_model_t': 1e10, 35 | # 'max_dev': 0.25, 36 | # 'marker': 'early-stop_10rep_stochastic', 37 | 'rollout_length_params': [20, 150, 1, 1], ## epoch, loss, length 38 | # 'marker': 'dump', 39 | } 40 | } -------------------------------------------------------------------------------- /softlearning/value_functions/vanilla.py: -------------------------------------------------------------------------------- 1 | from softlearning.models.feedforward import feedforward_model 2 | 3 | 4 | def create_feedforward_Q_function(observation_shape, 5 | action_shape, 6 | *args, 7 | observation_preprocessor=None, 8 | name='feedforward_Q', 9 | **kwargs): 10 | input_shapes = (observation_shape, action_shape) 11 | preprocessors = (observation_preprocessor, None) 12 | return feedforward_model( 13 | input_shapes, 14 | *args, 15 | output_size=1, 16 | preprocessors=preprocessors, 17 | name=name, 18 | **kwargs) 19 | 20 | 21 | def create_feedforward_V_function(observation_shape, 22 | *args, 23 | observation_preprocessor=None, 24 | name='feedforward_V', 25 | **kwargs): 26 | input_shapes = (observation_shape, ) 27 | preprocessors = (observation_preprocessor, None) 28 | return feedforward_model( 29 | input_shapes, 30 | *args, 31 | output_size=1, 32 | preprocessors=preprocessors, 33 | **kwargs) 34 | -------------------------------------------------------------------------------- /softlearning/algorithms/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | 4 | def create_SAC_algorithm(variant, *args, **kwargs): 5 | from .sac import SAC 6 | 7 | algorithm = SAC(*args, **kwargs) 8 | 9 | return algorithm 10 | 11 | 12 | def create_SQL_algorithm(variant, *args, **kwargs): 13 | from .sql import SQL 14 | 15 | algorithm = SQL(*args, **kwargs) 16 | 17 | return algorithm 18 | 19 | def create_MVE_algorithm(variant, *args, **kwargs): 20 | from .mve_sac import MVESAC 21 | 22 | algorithm = MVESAC(*args, **kwargs) 23 | 24 | return algorithm 25 | 26 | def create_MBPO_algorithm(variant, *args, **kwargs): 27 | from mbpo.algorithms.mbpo import MBPO 28 | 29 | algorithm = MBPO(*args, **kwargs) 30 | 31 | return algorithm 32 | 33 | 34 | ALGORITHM_CLASSES = { 35 | 'SAC': create_SAC_algorithm, 36 | 'SQL': 
create_SQL_algorithm, 37 | 'MBPO': create_MBPO_algorithm, 38 | } 39 | 40 | 41 | def get_algorithm_from_variant(variant, 42 | *args, 43 | **kwargs): 44 | algorithm_params = variant['algorithm_params'] 45 | algorithm_type = algorithm_params['type'] 46 | algorithm_kwargs = deepcopy(algorithm_params['kwargs']) 47 | algorithm = ALGORITHM_CLASSES[algorithm_type]( 48 | variant, *args, **algorithm_kwargs, **kwargs) 49 | 50 | return algorithm 51 | -------------------------------------------------------------------------------- /examples/development/__init__.py: -------------------------------------------------------------------------------- 1 | """Provides functions that are utilized by the command line interface. 2 | 3 | In particular, the examples are exposed to the command line interface 4 | (defined in `softlearning.scripts.console_scripts`) through the 5 | `get_trainable_class`, `get_variant_spec`, and `get_parser` functions. 6 | """ 7 | 8 | 9 | def get_trainable_class(*args, **kwargs): 10 | from .main import ExperimentRunner 11 | return ExperimentRunner 12 | 13 | 14 | # def get_variant_spec(command_line_args, *args, **kwargs): 15 | # from .variants import get_variant_spec 16 | # variant_spec = get_variant_spec(command_line_args, *args, **kwargs) 17 | # return variant_spec 18 | 19 | def get_params_from_file(filepath, params_name='params'): 20 | import importlib 21 | from dotmap import DotMap 22 | module = importlib.import_module(filepath) 23 | params = getattr(module, params_name) 24 | params = DotMap(params) 25 | return params 26 | 27 | def get_variant_spec(command_line_args, *args, **kwargs): 28 | from .base import get_variant_spec 29 | import importlib 30 | params = get_params_from_file(command_line_args.config) 31 | variant_spec = get_variant_spec(command_line_args, *args, params, **kwargs) 32 | return variant_spec 33 | 34 | def get_parser(): 35 | from examples.utils import get_parser 36 | parser = get_parser() 37 | return parser 38 | -------------------------------------------------------------------------------- /softlearning/replay_pools/union_pool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .replay_pool import ReplayPool 4 | 5 | 6 | class UnionPool(ReplayPool): 7 | def __init__(self, pools): 8 | pool_sizes = np.array([b.size for b in pools]) 9 | self._total_size = sum(pool_sizes) 10 | self._normalized_pool_sizes = pool_sizes / self._total_size 11 | 12 | self.pools = pools 13 | 14 | def add_sample(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | def terminate_episode(self): 18 | raise NotImplementedError 19 | 20 | @property 21 | def size(self): 22 | return self._total_size 23 | 24 | def add_path(self, **kwargs): 25 | raise NotImplementedError 26 | 27 | def random_batch(self, batch_size): 28 | 29 | # TODO: Hack 30 | partial_batch_sizes = self._normalized_pool_sizes * batch_size 31 | partial_batch_sizes = partial_batch_sizes.astype(int) 32 | partial_batch_sizes[0] = batch_size - sum(partial_batch_sizes[1:]) 33 | 34 | partial_batches = [ 35 | pool.random_batch(partial_batch_size) for pool, 36 | partial_batch_size in zip(self.pools, partial_batch_sizes) 37 | ] 38 | 39 | def all_values(key): 40 | return [partial_batch[key] for partial_batch in partial_batches] 41 | 42 | keys = partial_batches[0].keys() 43 | 44 | return {key: np.concatenate(all_values(key), axis=0) for key in keys} 45 | -------------------------------------------------------------------------------- 
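Editor's note: the configuration files under examples/config are plain Python modules exposing a params dict, and get_params_from_file in examples/development/__init__.py above loads one of them by dotted module path and wraps it in a DotMap; the top-level 'type' field is what ultimately selects the MBPO entry of the ALGORITHM_CLASSES registry (presumably via the variant spec built in examples/development/base.py, not shown here). A minimal sketch, assuming the repository root is on PYTHONPATH (or the package is installed with pip install -e .) and dotmap is available:

# Sketch only: loads examples/config/hopper/0.py through its dotted module path.
from examples.development import get_params_from_file

params = get_params_from_file('examples.config.hopper.0')

print(params['type'])                   # 'MBPO' -- the key used for the algorithm lookup
print(params.domain, params.task)       # Hopper v2
print(params.kwargs.n_train_repeat)     # 20
print(params.kwargs.rollout_schedule)   # [20, 150, 1, 15]
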
/softlearning/models/feedforward.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | from softlearning.utils.keras import PicklableKerasModel 5 | 6 | 7 | def feedforward_model(input_shapes, 8 | output_size, 9 | hidden_layer_sizes, 10 | activation='relu', 11 | output_activation='linear', 12 | preprocessors=None, 13 | name='feedforward_model', 14 | *args, 15 | **kwargs): 16 | inputs = [ 17 | tf.keras.layers.Input(shape=input_shape) 18 | for input_shape in input_shapes 19 | ] 20 | 21 | if preprocessors is None: 22 | preprocessors = (None, ) * len(inputs) 23 | 24 | preprocessed_inputs = [ 25 | preprocessor(input_) if preprocessor is not None else input_ 26 | for preprocessor, input_ in zip(preprocessors, inputs) 27 | ] 28 | 29 | concatenated = tf.keras.layers.Lambda( 30 | lambda x: tf.concat(x, axis=-1) 31 | )(preprocessed_inputs) 32 | 33 | out = concatenated 34 | for units in hidden_layer_sizes: 35 | out = tf.keras.layers.Dense( 36 | units, *args, activation=activation, **kwargs 37 | )(out) 38 | 39 | out = tf.keras.layers.Dense( 40 | output_size, *args, activation=output_activation, **kwargs 41 | )(out) 42 | 43 | model = PicklableKerasModel(inputs, out, name=name) 44 | 45 | return model 46 | -------------------------------------------------------------------------------- /mbpo/models/constructor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from mbpo.models.fc import FC 5 | from mbpo.models.bnn import BNN 6 | 7 | def construct_model(obs_dim=11, act_dim=3, rew_dim=1, hidden_dim=200, num_networks=7, num_elites=5, session=None): 8 | print('[ BNN ] Observation dim {} | Action dim: {} | Hidden dim: {}'.format(obs_dim, act_dim, hidden_dim)) 9 | params = {'name': 'BNN', 'num_networks': num_networks, 'num_elites': num_elites, 'sess': session} 10 | model = BNN(params) 11 | 12 | model.add(FC(hidden_dim, input_dim=obs_dim+act_dim, activation="swish", weight_decay=0.000025)) 13 | model.add(FC(hidden_dim, activation="swish", weight_decay=0.00005)) 14 | model.add(FC(hidden_dim, activation="swish", weight_decay=0.000075)) 15 | model.add(FC(hidden_dim, activation="swish", weight_decay=0.000075)) 16 | model.add(FC(obs_dim+rew_dim, weight_decay=0.0001)) 17 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 18 | return model 19 | 20 | def format_samples_for_training(samples): 21 | obs = samples['observations'] 22 | act = samples['actions'] 23 | next_obs = samples['next_observations'] 24 | rew = samples['rewards'] 25 | delta_obs = next_obs - obs 26 | inputs = np.concatenate((obs, act), axis=-1) 27 | outputs = np.concatenate((rew, delta_obs), axis=-1) 28 | return inputs, outputs 29 | 30 | def reset_model(model): 31 | model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name) 32 | model.sess.run(tf.initialize_vars(model_vars)) 33 | 34 | if __name__ == '__main__': 35 | model = construct_model() 36 | -------------------------------------------------------------------------------- /softlearning/policies/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from softlearning.preprocessors.utils import get_preprocessor_from_params 4 | 5 | 6 | def get_gaussian_policy(env, Q, **kwargs): 7 | from .gaussian_policy import FeedforwardGaussianPolicy 8 | policy = FeedforwardGaussianPolicy( 9 | input_shapes=(env.active_observation_shape, ), 10 | 
output_shape=env.action_space.shape, 11 | **kwargs) 12 | 13 | return policy 14 | 15 | 16 | def get_uniform_policy(env, *args, **kwargs): 17 | from .uniform_policy import UniformPolicy 18 | policy = UniformPolicy( 19 | input_shapes=(env.active_observation_shape, ), 20 | output_shape=env.action_space.shape) 21 | 22 | return policy 23 | 24 | 25 | POLICY_FUNCTIONS = { 26 | 'GaussianPolicy': get_gaussian_policy, 27 | 'UniformPolicy': get_uniform_policy, 28 | } 29 | 30 | 31 | def get_policy(policy_type, *args, **kwargs): 32 | return POLICY_FUNCTIONS[policy_type](*args, **kwargs) 33 | 34 | 35 | def get_policy_from_variant(variant, env, Qs, *args, **kwargs): 36 | policy_params = variant['policy_params'] 37 | policy_type = policy_params['type'] 38 | policy_kwargs = deepcopy(policy_params['kwargs']) 39 | 40 | preprocessor_params = policy_kwargs.pop('preprocessor_params', None) 41 | preprocessor = get_preprocessor_from_params(env, preprocessor_params) 42 | 43 | policy = POLICY_FUNCTIONS[policy_type]( 44 | env, 45 | *args, 46 | Q=Qs[0], 47 | preprocessor=preprocessor, 48 | **policy_kwargs, 49 | **kwargs) 50 | 51 | return policy 52 | -------------------------------------------------------------------------------- /softlearning/preprocessors/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | 4 | def get_convnet_preprocessor(observation_shape, 5 | name='convnet_preprocessor', 6 | **kwargs): 7 | from .convnet import convnet_preprocessor 8 | preprocessor = convnet_preprocessor( 9 | input_shapes=(observation_shape, ), name=name, **kwargs) 10 | 11 | return preprocessor 12 | 13 | 14 | def get_feedforward_preprocessor(observation_shape, 15 | name='feedforward_preprocessor', 16 | **kwargs): 17 | from softlearning.models.feedforward import feedforward_model 18 | preprocessor = feedforward_model( 19 | input_shapes=(observation_shape, ), name=name, **kwargs) 20 | 21 | return preprocessor 22 | 23 | 24 | PREPROCESSOR_FUNCTIONS = { 25 | 'convnet_preprocessor': get_convnet_preprocessor, 26 | 'feedforward_preprocessor': get_feedforward_preprocessor, 27 | None: lambda *args, **kwargs: None 28 | } 29 | 30 | 31 | def get_preprocessor_from_params(env, preprocessor_params, *args, **kwargs): 32 | if preprocessor_params is None: 33 | return None 34 | 35 | preprocessor_type = preprocessor_params.get('type', None) 36 | preprocessor_kwargs = deepcopy(preprocessor_params.get('kwargs', {})) 37 | 38 | if preprocessor_type is None: 39 | return None 40 | 41 | preprocessor = PREPROCESSOR_FUNCTIONS[ 42 | preprocessor_type]( 43 | env.active_observation_shape, 44 | *args, 45 | **preprocessor_kwargs, 46 | **kwargs) 47 | 48 | return preprocessor 49 | 50 | 51 | def get_preprocessor_from_variant(variant, env, *args, **kwargs): 52 | preprocessor_params = variant['preprocessor_params'] 53 | return get_preprocessor_from_params( 54 | env, preprocessor_params, *args, **kwargs) 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.stl 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # 
PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # soft learning specific things 109 | *.swp 110 | .idea 111 | *.mp4 112 | data/ 113 | vis/ 114 | tmp/ 115 | vendor/* 116 | .pkl 117 | 118 | 119 | .mujoco/ 120 | .vscode/ 121 | -------------------------------------------------------------------------------- /softlearning/value_functions/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from softlearning.preprocessors.utils import get_preprocessor_from_params 4 | from . import vanilla 5 | 6 | 7 | def create_double_value_function(value_fn, *args, **kwargs): 8 | # TODO(hartikainen): The double Q-function should support the same 9 | # interface as the regular ones. Implement the double min-thing 10 | # as a Keras layer. 
11 | value_fns = tuple(value_fn(*args, **kwargs) for i in range(2)) 12 | return value_fns 13 | 14 | 15 | VALUE_FUNCTIONS = { 16 | 'feedforward_V_function': ( 17 | vanilla.create_feedforward_V_function), 18 | 'double_feedforward_Q_function': lambda *args, **kwargs: ( 19 | create_double_value_function( 20 | vanilla.create_feedforward_Q_function, *args, **kwargs)), 21 | } 22 | 23 | 24 | def get_Q_function_from_variant(variant, env, *args, **kwargs): 25 | Q_params = variant['Q_params'] 26 | Q_type = Q_params['type'] 27 | Q_kwargs = deepcopy(Q_params['kwargs']) 28 | 29 | preprocessor_params = Q_kwargs.pop('preprocessor_params', None) 30 | preprocessor = get_preprocessor_from_params(env, preprocessor_params) 31 | 32 | return VALUE_FUNCTIONS[Q_type]( 33 | observation_shape=env.active_observation_shape, 34 | action_shape=env.action_space.shape, 35 | *args, 36 | observation_preprocessor=preprocessor, 37 | **Q_kwargs, 38 | **kwargs) 39 | 40 | 41 | def get_V_function_from_variant(variant, env, *args, **kwargs): 42 | V_params = variant['V_params'] 43 | V_type = V_params['type'] 44 | V_kwargs = deepcopy(V_params['kwargs']) 45 | 46 | preprocessor_params = V_kwargs.pop('preprocessor_params', None) 47 | preprocessor = get_preprocessor_from_params(env, preprocessor_params) 48 | 49 | return VALUE_FUNCTIONS[V_type]( 50 | observation_shape=env.active_observation_shape, 51 | *args, 52 | observation_preprocessor=preprocessor, 53 | **V_kwargs, 54 | **kwargs) 55 | -------------------------------------------------------------------------------- /softlearning/value_functions/value_function.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from serializable import Serializable 5 | 6 | 7 | class SumQFunction(Serializable): 8 | def __init__(self, 9 | observation_shape, 10 | action_shape, 11 | q_functions): 12 | self._Serializable__initialize(locals()) 13 | 14 | self.q_functions = q_functions 15 | 16 | assert len(observation_shape) == 1, observation_shape 17 | self._Do = observation_shape[0] 18 | assert len(action_shape) == 1, action_shape 19 | self._Da = action_shape[0] 20 | 21 | self._observations_ph = tf.placeholder( 22 | tf.float32, shape=(None, self._Do), name='observations') 23 | self._actions_ph = tf.placeholder( 24 | tf.float32, shape=(None, self._Da), name='actions') 25 | 26 | self._output = self.output_for( 27 | self._observations_ph, self._actions_ph, reuse=True) 28 | 29 | def output_for(self, observations, actions, reuse=False): 30 | outputs = [ 31 | qf.output_for(observations, actions, reuse=reuse) 32 | for qf in self.q_functions 33 | ] 34 | output = tf.add_n(outputs) 35 | return output 36 | 37 | def _eval(self, observations, actions): 38 | feeds = { 39 | self._observations_ph: observations, 40 | self._actions_ph: actions 41 | } 42 | 43 | return tf.keras.backend.get_session().run(self._output, feeds) 44 | 45 | def get_param_values(self): 46 | all_values_list = [qf.get_param_values() for qf in self.q_functions] 47 | 48 | return np.concatenate(all_values_list) 49 | 50 | def set_param_values(self, all_values): 51 | param_sizes = [qf.get_param_values().size for qf in self.q_functions] 52 | split_points = np.cumsum(param_sizes)[:-1] 53 | 54 | all_values_list = np.split(all_values, split_points) 55 | 56 | for values, qf in zip(all_values_list, self.q_functions): 57 | qf.set_param_values(values) 58 | -------------------------------------------------------------------------------- /mbpo/env/ant.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | 5 | class AntTruncatedObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 6 | """ 7 | External forces (sim.data.cfrc_ext) are removed from the observation. 8 | Otherwise identical to Ant-v2 from 9 | https://github.com/openai/gym/blob/master/gym/envs/mujoco/ant.py 10 | """ 11 | def __init__(self): 12 | mujoco_env.MujocoEnv.__init__(self, 'ant.xml', 5) 13 | utils.EzPickle.__init__(self) 14 | 15 | def step(self, a): 16 | xposbefore = self.get_body_com("torso")[0] 17 | self.do_simulation(a, self.frame_skip) 18 | xposafter = self.get_body_com("torso")[0] 19 | forward_reward = (xposafter - xposbefore)/self.dt 20 | ctrl_cost = .5 * np.square(a).sum() 21 | contact_cost = 0.5 * 1e-3 * np.sum( 22 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 23 | survive_reward = 1.0 24 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 25 | state = self.state_vector() 26 | notdone = np.isfinite(state).all() \ 27 | and state[2] >= 0.2 and state[2] <= 1.0 28 | done = not notdone 29 | ob = self._get_obs() 30 | return ob, reward, done, dict( 31 | reward_forward=forward_reward, 32 | reward_ctrl=-ctrl_cost, 33 | reward_contact=-contact_cost, 34 | reward_survive=survive_reward) 35 | 36 | def _get_obs(self): 37 | return np.concatenate([ 38 | self.sim.data.qpos.flat[2:], 39 | self.sim.data.qvel.flat, 40 | # np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 41 | ]) 42 | 43 | def reset_model(self): 44 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 45 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 46 | self.set_state(qpos, qvel) 47 | return self._get_obs() 48 | 49 | def viewer_setup(self): 50 | self.viewer.cam.distance = self.model.stat.extent * 0.5 -------------------------------------------------------------------------------- /softlearning/policies/uniform_policy.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import tensorflow as tf 4 | 5 | from .base_policy import BasePolicy 6 | 7 | 8 | class UniformPolicy(BasePolicy): 9 | def __init__(self, input_shapes, output_shape, action_range=(-1.0, 1.0)): 10 | super(UniformPolicy, self).__init__() 11 | self._Serializable__initialize(locals()) 12 | 13 | self.inputs = [ 14 | tf.keras.layers.Input(shape=input_shape) 15 | for input_shape in input_shapes 16 | ] 17 | self._action_range = action_range 18 | 19 | x = tf.keras.layers.Lambda( 20 | lambda x: tf.concat(x, axis=-1) 21 | )(self.inputs) 22 | 23 | actions = tf.keras.layers.Lambda( 24 | lambda x: tf.random.uniform( 25 | (tf.shape(x)[0], output_shape[0]), 26 | *action_range) 27 | )(x) 28 | 29 | self.actions_model = tf.keras.Model(self.inputs, actions) 30 | 31 | self.actions_input = tf.keras.Input(shape=output_shape) 32 | 33 | log_pis = tf.keras.layers.Lambda( 34 | lambda x: tf.tile(tf.log([ 35 | (action_range[1] - action_range[0]) / 2.0 36 | ])[None], (tf.shape(x)[0], 1)) 37 | )(self.actions_input) 38 | 39 | self.log_pis_model = tf.keras.Model( 40 | (*self.inputs, self.actions_input), log_pis) 41 | 42 | def get_weights(self): 43 | return [] 44 | 45 | def set_weights(self, *args, **kwargs): 46 | return 47 | 48 | @property 49 | def trainable_variables(self): 50 | return [] 51 | 52 | def reset(self): 53 | pass 54 | 55 | def actions(self, conditions): 56 | return self.actions_model(conditions) 57 | 58 | def 
log_pis(self, conditions, actions): 59 | return self.log_pis_model([*conditions, actions]) 60 | 61 | def actions_np(self, conditions): 62 | return self.actions_model.predict(conditions) 63 | 64 | def log_pis_np(self, conditions, actions): 65 | return self.log_pis_model.predict([*conditions, actions]) 66 | 67 | def get_diagnostics(self, conditions): 68 | return OrderedDict({}) 69 | -------------------------------------------------------------------------------- /environment/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.32.1 2 | gpflow==1.4.1 3 | flask==1.0.2 4 | tensorboardX==1.8 5 | absl-py==0.6.1 6 | asn1crypto==0.24.0 7 | astor==0.7.1 8 | atomicwrites==1.2.1 9 | attrs==18.2.0 10 | boto3==1.9.57 11 | botocore==1.12.57 12 | cachetools==3.0.0 13 | cffi==1.11.5 14 | chardet==3.0.4 15 | Click==7.0 16 | cloudpickle==0.6.1 17 | colorama==0.3.9 18 | cryptography==2.3.1 19 | cycler==0.10.0 20 | Cython==0.29.1 21 | dask==1.0.0 22 | decorator==4.3.0 23 | docutils==0.14 24 | dotmap==1.3.8 25 | deepdiff==3.3.0 26 | flatbuffers==1.10 27 | funcsigs==1.0.2 28 | future==0.17.1 29 | gast==0.2.0 30 | gitdb2==2.0.5 31 | GitPython==2.1.11 32 | glfw==1.7.0 33 | google-api-python-client==1.7.5 34 | google-auth==1.6.1 35 | google-auth-httplib2==0.0.3 36 | grpcio==1.16.1 37 | gtimer==1.0.0b5 38 | gym==0.12.0 39 | h5py==2.8.0 40 | httplib2==0.12.0 41 | idna==2.7 42 | imageio==2.4.1 43 | jmespath==0.9.3 44 | Keras-Applications==1.0.6 45 | Keras-Preprocessing==1.0.5 46 | kiwisolver==1.0.1 47 | lockfile==0.12.2 48 | Markdown==3.0.1 49 | matplotlib==3.0.2 50 | more-itertools==4.3.0 51 | git+https://github.com/jannerm/mujoco-py.git@v1.50.1.68 52 | git+https://github.com/vitchyr/multiworld.git@d76b3dae2e8cbca02924f93d6cc0239c552f6408 53 | networkx==2.2 54 | numpy==1.15.4 55 | pandas==0.23.4 56 | Pillow==6.2.0 57 | plotly==1.9.6 58 | pluggy==0.8.0 59 | protobuf==3.6.1 60 | py==1.7.0 61 | pyasn1==0.4.4 62 | pyasn1-modules==0.2.2 63 | pycosat==0.6.3 64 | pycparser==2.19 65 | pygame==1.9.4 66 | pyglet==1.3.2 67 | pyOpenSSL==18.0.0 68 | pyparsing==2.3.0 69 | PySocks==1.6.8 70 | pytest==4.0.1 71 | python-dateutil==2.7.5 72 | pytz==2018.7 73 | PyWavelets==1.0.1 74 | PyYAML==4.2b4 75 | psutil==5.4.8 76 | ray[rllib,debug]==0.6.4 77 | redis==3.0.1 78 | requests==2.20.1 79 | rsa==3.4.2 80 | s3transfer==0.1.13 81 | scikit-image==0.14.2 82 | scikit-learn==0.20.1 83 | scipy==1.1.0 84 | git+https://github.com/hartikainen/serializable.git@76516385a3a716ed4a2a9ad877e2d5cbcf18d4e6 85 | setproctitle==1.1.10 86 | six==1.11.0 87 | smmap2==2.0.5 88 | tensorboard==1.13.1 89 | tensorflow-gpu==1.13.1 90 | tensorflow-estimator==1.13.0 91 | tensorflow-probability==0.6.0 92 | termcolor==1.1.0 93 | toolz==0.9.0 94 | uritemplate==3.0.0 95 | urllib3==1.24.2 96 | Werkzeug==0.15.3 -------------------------------------------------------------------------------- /mbpo/utils/writer.py: -------------------------------------------------------------------------------- 1 | import io 2 | import numpy as np 3 | import cv2 4 | import pdb 5 | 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | from mpl_toolkits.mplot3d import Axes3D 9 | import matplotlib.pyplot as plt 10 | import matplotlib.cm as cm 11 | 12 | import tensorboardX as tbx 13 | 14 | class Writer(): 15 | 16 | def __init__(self, log_dir): 17 | self.log_dir = log_dir 18 | self._writer = tbx.SummaryWriter(self.log_dir) 19 | self._data = {} 20 | self._data_3d = {} 21 | print('[ Writer ] Log dir: {}'.format(log_dir)) 22 | 23 | def 
__getitem__(self, key): 24 | if key in self._data: 25 | return self._data[key] 26 | else: 27 | return self._data_3d[key] 28 | 29 | def __getattr__(self, attr): 30 | return getattr(self._writer, attr) 31 | 32 | def _add_label(self, data, label): 33 | if label not in data: 34 | data[label] = 0 35 | 36 | def add_scalar(self, label, val, epoch): 37 | self._add_label(self._data, label) 38 | if epoch > self._data[label]: 39 | self._data[label] = epoch 40 | self._writer.add_scalar(label, val, epoch) 41 | 42 | def plot_cdfs(self, label, epoch, env_mean, model_mean, env_paths, model_paths): 43 | plt.clf() 44 | plt.plot(env_mean, linewidth=2, label='env', c='k') 45 | plt.plot(model_mean, linewidth=2, label='model', c='b') 46 | 47 | for path in env_paths: 48 | plt.plot(path['rewards'].cumsum(), alpha=0.5, c='k') 49 | for path in model_paths: 50 | plt.plot(path['rewards'].cumsum(), alpha=0.5, c='b') 51 | 52 | plt.ylabel('cumulative return') 53 | plt.xlabel('step') 54 | plt.legend() 55 | self._savefig(label, epoch) 56 | 57 | def _savefig(self, label, epoch): 58 | buf = io.BytesIO() 59 | plt.savefig(buf, format='png', layout = 'tight') 60 | buf.seek(0) 61 | img = cv2.imdecode(np.fromstring(buf.getvalue(), dtype=np.uint8), -1) 62 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 63 | img = img.transpose(2,0,1) / 255. 64 | self._writer.add_image(label, img, epoch) 65 | return img 66 | 67 | 68 | -------------------------------------------------------------------------------- /softlearning/samplers/base_sampler.py: -------------------------------------------------------------------------------- 1 | from collections import deque, OrderedDict 2 | from itertools import islice 3 | 4 | 5 | class BaseSampler(object): 6 | def __init__(self, 7 | max_path_length, 8 | min_pool_size, 9 | batch_size, 10 | store_last_n_paths=10): 11 | self._max_path_length = max_path_length 12 | self._min_pool_size = min_pool_size 13 | self._batch_size = batch_size 14 | self._store_last_n_paths = store_last_n_paths 15 | self._last_n_paths = deque(maxlen=store_last_n_paths) 16 | 17 | self.env = None 18 | self.policy = None 19 | self.pool = None 20 | 21 | def initialize(self, env, policy, pool): 22 | self.env = env 23 | self.policy = policy 24 | self.pool = pool 25 | 26 | def set_policy(self, policy): 27 | self.policy = policy 28 | 29 | def clear_last_n_paths(self): 30 | self._last_n_paths.clear() 31 | 32 | def get_last_n_paths(self, n=None): 33 | if n is None: 34 | n = self._store_last_n_paths 35 | 36 | last_n_paths = tuple(islice(self._last_n_paths, None, n)) 37 | 38 | return last_n_paths 39 | 40 | def sample(self): 41 | raise NotImplementedError 42 | 43 | def batch_ready(self): 44 | enough_samples = self.pool.size >= self._min_pool_size 45 | return enough_samples 46 | 47 | def random_batch(self, batch_size=None, **kwargs): 48 | batch_size = batch_size or self._batch_size 49 | return self.pool.random_batch(batch_size, **kwargs) 50 | 51 | def terminate(self): 52 | self.env.close() 53 | 54 | def get_diagnostics(self): 55 | diagnostics = OrderedDict({'pool-size': self.pool.size}) 56 | return diagnostics 57 | 58 | def __getstate__(self): 59 | state = { 60 | key: value for key, value in self.__dict__.items() 61 | if key not in ('env', 'policy', 'pool') 62 | } 63 | 64 | return state 65 | 66 | def __setstate__(self, state): 67 | self.__dict__.update(state) 68 | 69 | self.env = None 70 | self.policy = None 71 | self.pool = None 72 | -------------------------------------------------------------------------------- 
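`BaseSampler` above is the interface the rest of the training code programs against: a sampler is bound to an environment, policy, and replay pool with `initialize`, advanced one environment step at a time with `sample()`, and asked for training batches once `batch_ready()` reports that the pool holds at least `min_pool_size` transitions. A minimal usage sketch follows; `env` and `policy` are placeholders for an already-constructed softlearning environment adapter and policy (they are not defined here), and the pool constructor call mirrors the one in `softlearning/samplers/utils.py` further down.

```
from softlearning import replay_pools
from softlearning.samplers.simple_sampler import SimpleSampler

# `env` and `policy` are placeholders: a softlearning environment adapter
# (it must provide `convert_to_active_observation`) and a policy exposing
# `actions_np`; constructing them is outside the scope of this sketch.
pool = replay_pools.SimpleReplayPool(
    env.observation_space, env.action_space, max_size=int(1e6))
sampler = SimpleSampler(max_path_length=1000, min_pool_size=100, batch_size=256)
sampler.initialize(env, policy, pool)

# Collect experience until the pool is large enough to train on.
while not sampler.batch_ready():
    sampler.sample()

batch = sampler.random_batch()      # batch of `batch_size` transitions from the pool
print(sampler.get_diagnostics())    # OrderedDict including 'pool-size'
```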
/softlearning/samplers/extra_policy_info_sampler.py: -------------------------------------------------------------------------------- 1 | """Sampler that stores raw actions and log pis from policy.""" 2 | 3 | 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | 8 | from .simple_sampler import SimpleSampler 9 | 10 | 11 | class ExtraPolicyInfoSampler(SimpleSampler): 12 | def sample(self): 13 | if self._current_observation is None: 14 | self._current_observation = self.env.reset() 15 | 16 | observations = self.env.convert_to_active_observation( 17 | self._current_observation)[None] 18 | actions = self.policy.actions_np([observations]) 19 | log_pis = self.policy.log_pis_np([observations], actions) 20 | 21 | action = actions[0] 22 | log_pi = log_pis[0] 23 | 24 | next_observation, reward, terminal, info = self.env.step(action) 25 | self._path_length += 1 26 | self._path_return += reward 27 | self._total_samples += 1 28 | 29 | self._current_path['observations'].append(self._current_observation) 30 | self._current_path['actions'].append(action) 31 | self._current_path['rewards'].append([reward]) 32 | self._current_path['terminals'].append([terminal]) 33 | self._current_path['next_observations'].append(next_observation) 34 | self._current_path['infos'].append(info) 35 | # self._current_path['raw_actions'].append(raw_action) 36 | self._current_path['log_pis'].append(log_pi) 37 | 38 | if terminal or self._path_length >= self._max_path_length: 39 | last_path = { 40 | field_name: np.array(values) 41 | for field_name, values in self._current_path.items() 42 | } 43 | self.pool.add_path(last_path) 44 | self._last_n_paths.appendleft(last_path) 45 | 46 | self.policy.reset() 47 | self._current_observation = self.env.reset() 48 | 49 | self._max_path_return = max(self._max_path_return, 50 | self._path_return) 51 | self._last_path_return = self._path_return 52 | 53 | self._path_length = 0 54 | self._path_return = 0 55 | self._current_path = defaultdict(list) 56 | 57 | self._n_episodes += 1 58 | else: 59 | self._current_observation = next_observation 60 | 61 | return self._current_observation, reward, terminal, info 62 | -------------------------------------------------------------------------------- /softlearning/preprocessors/convnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from softlearning.models.feedforward import feedforward_model 4 | from softlearning.utils.keras import PicklableKerasModel 5 | 6 | 7 | def convnet_preprocessor( 8 | input_shapes, 9 | image_shape, 10 | output_size, 11 | conv_filters=(32, 32), 12 | conv_kernel_sizes=((5, 5), (5, 5)), 13 | pool_type='MaxPool2D', 14 | pool_sizes=((2, 2), (2, 2)), 15 | pool_strides=(2, 2), 16 | dense_hidden_layer_sizes=(64, 64), 17 | data_format='channels_last', 18 | name="convnet_preprocessor", 19 | make_picklable=True, 20 | *args, 21 | **kwargs): 22 | if data_format == 'channels_last': 23 | H, W, C = image_shape 24 | elif data_format == 'channels_first': 25 | C, H, W = image_shape 26 | 27 | inputs = [ 28 | tf.keras.layers.Input(shape=input_shape) 29 | for input_shape in input_shapes 30 | ] 31 | 32 | concatenated_input = tf.keras.layers.Lambda( 33 | lambda x: tf.concat(x, axis=-1) 34 | )(inputs) 35 | 36 | images_flat, input_raw = tf.keras.layers.Lambda( 37 | lambda x: [x[..., :H * W * C], x[..., H * W * C:]] 38 | )(concatenated_input) 39 | 40 | images = tf.keras.layers.Reshape(image_shape)(images_flat) 41 | 42 | conv_out = images 43 | for filters, kernel_size, 
pool_size, strides in zip( 44 | conv_filters, conv_kernel_sizes, pool_sizes, pool_strides): 45 | conv_out = tf.keras.layers.Conv2D( 46 | filters=filters, 47 | kernel_size=kernel_size, 48 | padding="SAME", 49 | activation=tf.nn.relu, 50 | *args, 51 | **kwargs 52 | )(conv_out) 53 | conv_out = getattr(tf.keras.layers, pool_type)( 54 | pool_size=pool_size, strides=strides 55 | )(conv_out) 56 | 57 | flattened = tf.keras.layers.Flatten()(conv_out) 58 | concatenated_output = tf.keras.layers.Lambda( 59 | lambda x: tf.concat(x, axis=-1) 60 | )([flattened, input_raw]) 61 | 62 | output = ( 63 | feedforward_model( 64 | input_shapes=(concatenated_output.shape[1:].as_list(), ), 65 | output_size=output_size, 66 | hidden_layer_sizes=dense_hidden_layer_sizes, 67 | activation='relu', 68 | output_activation='linear', 69 | *args, 70 | **kwargs 71 | )([concatenated_output]) 72 | if dense_hidden_layer_sizes 73 | else concatenated_output) 74 | 75 | model = PicklableKerasModel(inputs, output, name=name) 76 | 77 | return model 78 | -------------------------------------------------------------------------------- /mbpo/env/humanoid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import mujoco_env 3 | from gym import utils 4 | 5 | def mass_center(model, sim): 6 | mass = np.expand_dims(model.body_mass, 1) 7 | xpos = sim.data.xipos 8 | return (np.sum(mass * xpos, 0) / np.sum(mass))[0] 9 | 10 | class HumanoidTruncatedObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 11 | """ 12 | COM inertia (cinert), COM velocity (cvel), actuator forces (qfrc_actuator), 13 | and external forces (cfrc_ext) are removed from the observation. 14 | Otherwise identical to Humanoid-v2 from 15 | https://github.com/openai/gym/blob/master/gym/envs/mujoco/humanoid.py 16 | """ 17 | def __init__(self): 18 | mujoco_env.MujocoEnv.__init__(self, 'humanoid.xml', 5) 19 | utils.EzPickle.__init__(self) 20 | 21 | def _get_obs(self): 22 | data = self.sim.data 23 | return np.concatenate([data.qpos.flat[2:], 24 | data.qvel.flat, 25 | # data.cinert.flat, 26 | # data.cvel.flat, 27 | # data.qfrc_actuator.flat, 28 | # data.cfrc_ext.flat 29 | ]) 30 | 31 | def step(self, a): 32 | pos_before = mass_center(self.model, self.sim) 33 | self.do_simulation(a, self.frame_skip) 34 | pos_after = mass_center(self.model, self.sim) 35 | alive_bonus = 5.0 36 | data = self.sim.data 37 | lin_vel_cost = 0.25 * (pos_after - pos_before) / self.model.opt.timestep 38 | quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum() 39 | quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum() 40 | quad_impact_cost = min(quad_impact_cost, 10) 41 | reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus 42 | qpos = self.sim.data.qpos 43 | done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) 44 | return self._get_obs(), reward, done, dict(reward_linvel=lin_vel_cost, reward_quadctrl=-quad_ctrl_cost, reward_alive=alive_bonus, reward_impact=-quad_impact_cost) 45 | 46 | def reset_model(self): 47 | c = 0.01 48 | self.set_state( 49 | self.init_qpos + self.np_random.uniform(low=-c, high=c, size=self.model.nq), 50 | self.init_qvel + self.np_random.uniform(low=-c, high=c, size=self.model.nv,) 51 | ) 52 | return self._get_obs() 53 | 54 | def viewer_setup(self): 55 | self.viewer.cam.trackbodyid = 1 56 | self.viewer.cam.distance = self.model.stat.extent * 1.0 57 | self.viewer.cam.lookat[2] = 2.0 58 | self.viewer.cam.elevation = -20 59 | 60 | -------------------------------------------------------------------------------- 
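`HumanoidTruncatedObsEnv` above, like `AntTruncatedObsEnv` earlier, differs from the stock Gym task only in `_get_obs`, which keeps just `qpos[2:]` and `qvel`. A quick way to see the effect on the observation size, assuming a working MuJoCo / `mujoco-py` setup and that the class is importable from `mbpo.env.ant` (the import path is an assumption here):

```
import gym
from mbpo.env.ant import AntTruncatedObsEnv

truncated = AntTruncatedObsEnv()
standard = gym.make('Ant-v2')

# The truncated observation has (model.nq - 2) + model.nv entries; the
# standard Ant-v2 observation additionally contains the clipped external
# forces (cfrc_ext), so it is considerably larger.
print(truncated.reset().shape)
print(standard.reset().shape)
```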
/softlearning/misc/plotter.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | class QFPolicyPlotter: 7 | def __init__(self, Q, policy, obs_lst, default_action, n_samples): 8 | self._Q = Q 9 | self._policy = policy 10 | self._obs_lst = obs_lst 11 | self._default_action = np.array(default_action) 12 | self._n_samples = n_samples 13 | 14 | self._var_inds = np.where(np.isnan(default_action))[0] 15 | assert len(self._var_inds) == 2 16 | 17 | n_plots = len(obs_lst) 18 | 19 | x_size = 5 * n_plots 20 | y_size = 5 21 | 22 | fig = plt.figure(figsize=(x_size, y_size)) 23 | self._ax_lst = [] 24 | for i in range(n_plots): 25 | ax = fig.add_subplot(100 + n_plots * 10 + i + 1) 26 | ax.set_xlim((-1, 1)) 27 | ax.set_ylim((-1, 1)) 28 | ax.grid(True) 29 | self._ax_lst.append(ax) 30 | 31 | self._line_objects = list() 32 | 33 | def draw(self): 34 | # noinspection PyArgumentList 35 | [h.remove() for h in self._line_objects] 36 | self._line_objects = list() 37 | 38 | self._plot_level_curves() 39 | self._plot_action_samples() 40 | 41 | plt.draw() 42 | plt.pause(0.001) 43 | 44 | def _plot_level_curves(self): 45 | # Create mesh grid. 46 | xs = np.linspace(-1, 1, 50) 47 | ys = np.linspace(-1, 1, 50) 48 | xgrid, ygrid = np.meshgrid(xs, ys) 49 | N = len(xs)*len(ys) 50 | 51 | # Copy default values along the first axis and replace nans with 52 | # the mesh grid points. 53 | actions = np.tile(self._default_action.astype(np.float32), (N, 1)) 54 | actions[:, self._var_inds[0]] = xgrid.ravel() 55 | actions[:, self._var_inds[1]] = ygrid.ravel() 56 | 57 | for ax, obs in zip(self._ax_lst, self._obs_lst): 58 | observations = np.tile( 59 | obs[None].astype(np.float32), (actions.shape[0], 1)) 60 | 61 | Q_np = self._Q.predict((observations, actions)) 62 | Q_np = np.reshape(Q_np, xgrid.shape) 63 | 64 | cs = ax.contour(xgrid, ygrid, Q_np, 20) 65 | self._line_objects += cs.collections 66 | self._line_objects += ax.clabel( 67 | cs, inline=1, fontsize=10, fmt='%.2f') 68 | 69 | def _plot_action_samples(self): 70 | for ax, obs in zip(self._ax_lst, self._obs_lst): 71 | observations = np.ones((self._n_samples, 1)) * obs[None, :] 72 | actions = self._policy.actions_np([observations]) 73 | 74 | x, y = actions[:, 0], actions[:, 1] 75 | self._line_objects += ax.plot(x, y, 'b*') 76 | -------------------------------------------------------------------------------- /softlearning/samplers/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | 5 | from softlearning import replay_pools 6 | from . 
import ( 7 | dummy_sampler, 8 | extra_policy_info_sampler, 9 | remote_sampler, 10 | base_sampler, 11 | simple_sampler) 12 | 13 | 14 | def get_sampler_from_variant(variant, *args, **kwargs): 15 | SAMPLERS = { 16 | 'DummySampler': dummy_sampler.DummySampler, 17 | 'ExtraPolicyInfoSampler': ( 18 | extra_policy_info_sampler.ExtraPolicyInfoSampler), 19 | 'RemoteSampler': remote_sampler.RemoteSampler, 20 | 'Sampler': base_sampler.BaseSampler, 21 | 'SimpleSampler': simple_sampler.SimpleSampler, 22 | } 23 | 24 | sampler_params = variant['sampler_params'] 25 | sampler_type = sampler_params['type'] 26 | 27 | sampler_args = deepcopy(sampler_params.get('args', ())) 28 | sampler_kwargs = deepcopy(sampler_params.get('kwargs', {})) 29 | 30 | sampler = SAMPLERS[sampler_type]( 31 | *sampler_args, *args, **sampler_kwargs, **kwargs) 32 | 33 | return sampler 34 | 35 | 36 | def rollout(env, 37 | policy, 38 | path_length, 39 | callback=None, 40 | render_mode=None, 41 | break_on_terminal=True): 42 | observation_space = env.observation_space 43 | action_space = env.action_space 44 | 45 | pool = replay_pools.SimpleReplayPool( 46 | observation_space, action_space, max_size=path_length) 47 | sampler = simple_sampler.SimpleSampler( 48 | max_path_length=path_length, 49 | min_pool_size=None, 50 | batch_size=None) 51 | 52 | sampler.initialize(env, policy, pool) 53 | 54 | images = [] 55 | infos = [] 56 | 57 | t = 0 58 | for t in range(path_length): 59 | observation, reward, terminal, info = sampler.sample() 60 | infos.append(info) 61 | 62 | if callback is not None: 63 | callback(observation) 64 | 65 | if render_mode is not None: 66 | if render_mode == 'rgb_array': 67 | image = env.render(mode=render_mode) 68 | images.append(image) 69 | else: 70 | env.render() 71 | 72 | if terminal: 73 | policy.reset() 74 | if break_on_terminal: break 75 | 76 | assert pool._size == t + 1 77 | 78 | path = pool.batch_by_indices( 79 | np.arange(pool._size), 80 | observation_keys=getattr(env, 'observation_keys', None)) 81 | path['infos'] = infos 82 | 83 | if render_mode == 'rgb_array': 84 | path['images'] = np.stack(images, axis=0) 85 | 86 | return path 87 | 88 | 89 | def rollouts(n_paths, *args, **kwargs): 90 | paths = [rollout(*args, **kwargs) for i in range(n_paths)] 91 | return paths 92 | -------------------------------------------------------------------------------- /examples/development/simulate_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from distutils.util import strtobool 3 | import json 4 | import os 5 | import pickle 6 | 7 | import tensorflow as tf 8 | 9 | from softlearning.environments.utils import get_environment_from_params 10 | from softlearning.policies.utils import get_policy_from_variant 11 | from softlearning.samplers import rollouts 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('checkpoint_path', 17 | type=str, 18 | help='Path to the checkpoint.') 19 | parser.add_argument('--max-path-length', '-l', type=int, default=1000) 20 | parser.add_argument('--num-rollouts', '-n', type=int, default=10) 21 | parser.add_argument('--render-mode', '-r', 22 | type=str, 23 | default='human', 24 | choices=('human', 'rgb_array', None), 25 | help="Mode to render the rollouts in.") 26 | parser.add_argument('--deterministic', '-d', 27 | type=lambda x: bool(strtobool(x)), 28 | nargs='?', 29 | const=True, 30 | default=True, 31 | help="Evaluate policy deterministically.") 32 | 33 | args = parser.parse_args() 34 | 35 | return 
args 36 | 37 | 38 | def simulate_policy(args): 39 | session = tf.keras.backend.get_session() 40 | checkpoint_path = args.checkpoint_path.rstrip('/') 41 | experiment_path = os.path.dirname(checkpoint_path) 42 | 43 | variant_path = os.path.join(experiment_path, 'params.json') 44 | with open(variant_path, 'r') as f: 45 | variant = json.load(f) 46 | 47 | with session.as_default(): 48 | pickle_path = os.path.join(checkpoint_path, 'checkpoint.pkl') 49 | with open(pickle_path, 'rb') as f: 50 | picklable = pickle.load(f) 51 | 52 | environment_params = ( 53 | variant['environment_params']['evaluation'] 54 | if 'evaluation' in variant['environment_params'] 55 | else variant['environment_params']['training']) 56 | evaluation_environment = get_environment_from_params(environment_params) 57 | 58 | policy = ( 59 | get_policy_from_variant(variant, evaluation_environment, Qs=[None])) 60 | policy.set_weights(picklable['policy_weights']) 61 | 62 | with policy.set_deterministic(args.deterministic): 63 | paths = rollouts(args.num_rollouts, 64 | evaluation_environment, 65 | policy, 66 | path_length=args.max_path_length, 67 | render_mode=args.render_mode) 68 | 69 | if args.render_mode != 'human': 70 | from pprint import pprint; import pdb; pdb.set_trace() 71 | pass 72 | 73 | return paths 74 | 75 | 76 | if __name__ == '__main__': 77 | args = parse_args() 78 | simulate_policy(args) 79 | -------------------------------------------------------------------------------- /softlearning/misc/kernel.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def adaptive_isotropic_gaussian_kernel(xs, ys, h_min=1e-3): 8 | """Gaussian kernel with dynamic bandwidth. 9 | 10 | The bandwidth is adjusted dynamically to match median_distance / log(Kx). 11 | See [2] for more information. 12 | 13 | Args: 14 | xs(`tf.Tensor`): A tensor of shape (N x Kx x D) containing N sets of Kx 15 | particles of dimension D. This is the first kernel argument. 16 | ys(`tf.Tensor`): A tensor of shape (N x Ky x D) containing N sets of Kx 17 | particles of dimension D. This is the second kernel argument. 18 | h_min(`float`): Minimum bandwidth. 19 | 20 | Returns: 21 | `dict`: Returned dictionary has two fields: 22 | 'output': A `tf.Tensor` object of shape (N x Kx x Ky) representing 23 | the kernel matrix for inputs `xs` and `ys`. 24 | 'gradient': A 'tf.Tensor` object of shape (N x Kx x Ky x D) 25 | representing the gradient of the kernel with respect to `xs`. 26 | 27 | Reference: 28 | [2] Qiang Liu,Dilin Wang, "Stein Variational Gradient Descent: A General 29 | Purpose Bayesian Inference Algorithm," Neural Information Processing 30 | Systems (NIPS), 2016. 31 | """ 32 | Kx, D = xs.get_shape().as_list()[-2:] 33 | Ky, D2 = ys.get_shape().as_list()[-2:] 34 | assert D == D2 35 | 36 | leading_shape = tf.shape(xs)[:-2] 37 | 38 | # Compute the pairwise distances of left and right particles. 39 | diff = tf.expand_dims(xs, -2) - tf.expand_dims(ys, -3) 40 | # ... x Kx x Ky x D 41 | 42 | if LooseVersion(tf.__version__) <= LooseVersion('1.5.0'): 43 | dist_sq = tf.reduce_sum(diff**2, axis=-1, keep_dims=False) 44 | else: 45 | dist_sq = tf.reduce_sum(diff**2, axis=-1, keepdims=False) 46 | # ... x Kx x Ky 47 | 48 | # Get median. 
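    # Bandwidth heuristic from [2]: choosing h = median(dist_sq) / log(Kx)
    # makes each row of the kernel matrix sum to approximately
    # Kx * exp(-log(Kx)) = 1.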
49 | input_shape = tf.concat((leading_shape, [Kx * Ky]), axis=0) 50 | values, _ = tf.nn.top_k( 51 | input=tf.reshape(dist_sq, input_shape), 52 | k=(Kx * Ky // 2 + 1), # This is exactly true only if Kx*Ky is odd. 53 | sorted=True) # ... x floor(Ks*Kd/2) 54 | 55 | medians_sq = values[..., -1] # ... (shape) (last element is the median) 56 | 57 | h = medians_sq / np.log(Kx) # ... (shape) 58 | h = tf.maximum(h, h_min) 59 | h = tf.stop_gradient(h) # Just in case. 60 | h_expanded_twice = tf.expand_dims(tf.expand_dims(h, -1), -1) 61 | # ... x 1 x 1 62 | 63 | kappa = tf.exp(-dist_sq / h_expanded_twice) # ... x Kx x Ky 64 | 65 | # Construct the gradient 66 | h_expanded_thrice = tf.expand_dims(h_expanded_twice, -1) 67 | # ... x 1 x 1 x 1 68 | kappa_expanded = tf.expand_dims(kappa, -1) # ... x Kx x Ky x 1 69 | 70 | kappa_grad = -2 * diff / h_expanded_thrice * kappa_expanded 71 | # ... x Kx x Ky x D 72 | 73 | return {"output": kappa, "gradient": kappa_grad} 74 | -------------------------------------------------------------------------------- /mbpo/models/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | def get_required_argument(dotmap, key, message, default=None): 9 | val = dotmap.get(key, default) 10 | if val is default: 11 | raise ValueError(message) 12 | return val 13 | 14 | class TensorStandardScaler: 15 | """Helper class for automatically normalizing inputs into the network. 16 | """ 17 | def __init__(self, x_dim): 18 | """Initializes a scaler. 19 | 20 | Arguments: 21 | x_dim (int): The dimensionality of the inputs into the scaler. 22 | 23 | Returns: None. 24 | """ 25 | self.fitted = False 26 | with tf.variable_scope("Scaler"): 27 | self.mu = tf.get_variable( 28 | name="scaler_mu", shape=[1, x_dim], initializer=tf.constant_initializer(0.0), 29 | trainable=False 30 | ) 31 | self.sigma = tf.get_variable( 32 | name="scaler_std", shape=[1, x_dim], initializer=tf.constant_initializer(1.0), 33 | trainable=False 34 | ) 35 | 36 | self.cached_mu, self.cached_sigma = np.zeros([0, x_dim]), np.ones([1, x_dim]) 37 | 38 | def fit(self, data): 39 | """Runs two ops, one for assigning the mean of the data to the internal mean, and 40 | another for assigning the standard deviation of the data to the internal standard deviation. 41 | This function must be called within a 'with .as_default()' block. 42 | 43 | Arguments: 44 | data (np.ndarray): A numpy array containing the input 45 | 46 | Returns: None. 47 | """ 48 | mu = np.mean(data, axis=0, keepdims=True) 49 | sigma = np.std(data, axis=0, keepdims=True) 50 | sigma[sigma < 1e-12] = 1.0 51 | 52 | self.mu.load(mu) 53 | self.sigma.load(sigma) 54 | self.fitted = True 55 | self.cache() 56 | 57 | def transform(self, data): 58 | """Transforms the input matrix data using the parameters of this scaler. 59 | 60 | Arguments: 61 | data (np.array): A numpy array containing the points to be transformed. 62 | 63 | Returns: (np.array) The transformed dataset. 64 | """ 65 | return (data - self.mu) / self.sigma 66 | 67 | def inverse_transform(self, data): 68 | """Undoes the transformation performed by this scaler. 69 | 70 | Arguments: 71 | data (np.array): A numpy array containing the points to be transformed. 72 | 73 | Returns: (np.array) The transformed dataset. 
74 | """ 75 | return self.sigma * data + self.mu 76 | 77 | def get_vars(self): 78 | """Returns a list of variables managed by this object. 79 | 80 | Returns: (list) The list of variables. 81 | """ 82 | return [self.mu, self.sigma] 83 | 84 | def cache(self): 85 | """Caches current values of this scaler. 86 | 87 | Returns: None. 88 | """ 89 | self.cached_mu = self.mu.eval() 90 | self.cached_sigma = self.sigma.eval() 91 | 92 | def load_cache(self): 93 | """Loads values from the cache 94 | 95 | Returns: None. 96 | """ 97 | self.mu.load(self.cached_mu) 98 | self.sigma.load(self.cached_sigma) 99 | 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model-Based Policy Optimization 2 | 3 | Code to reproduce the experiments in [When to Trust Your Model: Model-Based Policy Optimization](https://arxiv.org/abs/1906.08253). 4 | 5 |

6 | 7 | 8 |

9 | 10 | ## Installation 11 | 1. Install [MuJoCo 1.50](https://www.roboti.us/index.html) at `~/.mujoco/mjpro150` and copy your license key to `~/.mujoco/mjkey.txt` 12 | 2. Clone `mbpo` 13 | ``` 14 | git clone --recursive https://github.com/jannerm/mbpo.git 15 | ``` 16 | 3. Create a conda environment and install mbpo 17 | ``` 18 | cd mbpo 19 | conda env create -f environment/gpu-env.yml 20 | conda activate mbpo 21 | pip install -e viskit 22 | pip install -e . 23 | ``` 24 | 25 | ## Usage 26 | Configuration files can be found in [`examples/config/`](examples/config). 27 | 28 | ``` 29 | mbpo run_local examples.development --config=examples.config.halfcheetah.0 --gpus=1 --trial-gpus=1 30 | ``` 31 | 32 | Currently only running locally is supported. 33 | 34 | #### New environments 35 | To run on a different environment, you can modify the provided [template](examples/config/custom/0.py). You will also need to provide the termination function for the environment in [`mbpo/static`](mbpo/static). If you name the file the lowercase version of the environment name, it will be found automatically. See [`hopper.py`](mbpo/static/hopper.py) for an example. 36 | 37 | #### Logging 38 | 39 | This codebase contains [viskit](https://github.com/vitchyr/viskit) as a submodule. You can view saved runs with: 40 | ``` 41 | viskit ~/ray_mbpo --port 6008 42 | ``` 43 | assuming you used the default [`log_dir`](examples/config/halfcheetah/0.py#L7). 44 | 45 | #### Hyperparameters 46 | 47 | The rollout length schedule is defined by a length-4 list in a [config file](examples/config/halfcheetah/0.py#L31). The format is `[start_epoch, end_epoch, start_length, end_length]`, so the following: 48 | ``` 49 | 'rollout_schedule': [20, 100, 1, 5] 50 | ``` 51 | corresponds to a model rollout length linearly increasing from 1 to 5 over epochs 20 to 100. 52 | 53 | If you want to speed up training in terms of wall clock time (but possibly make the runs less sample-efficient), you can set a timeout for model training ([`max_model_t`](examples/config/halfcheetah/0.py#L30), in seconds) or train the model less frequently (every [`model_train_freq`](examples/config/halfcheetah/0.py#L22) steps). 54 | 55 | ## Comparing to MBPO 56 | If you would like to compare to MBPO but do not have the resources to re-run all experiments, the learning curves found in Figure 2 of the paper (plus on the Humanoid environment) are available in this [shared folder](https://drive.google.com/drive/folders/1matvC7hPi5al9-5S2uL4GuXfT5rzO9qU?usp=sharing). See `plot.py` for an example of how to read the pickle files with the results. 57 | 58 | ## Reference 59 | 60 | ``` 61 | @inproceedings{janner2019mbpo, 62 | author = {Michael Janner and Justin Fu and Marvin Zhang and Sergey Levine}, 63 | title = {When to Trust Your Model: Model-Based Policy Optimization}, 64 | booktitle = {Advances in Neural Information Processing Systems}, 65 | year = {2019} 66 | } 67 | ``` 68 | 69 | ## Acknowledgments 70 | The underlying soft actor-critic implementation in MBPO comes from [Tuomas Haarnoja](https://scholar.google.com/citations?user=VT7peyEAAAAJ&hl=en) and [Kristian Hartikainen's](https://hartikainen.github.io/) [softlearning](https://github.com/rail-berkeley/softlearning) codebase. The modeling code is a slightly modified version of [Kurtland Chua's](https://kchua.github.io/) [PETS](https://github.com/kchua/handful-of-trials) implementation. 
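The termination function mentioned under "New environments" has to operate on batches: `mbpo/models/fake_env.py` (dumped later in this file) calls `config.termination_fn(obs, act, next_obs)` on `[batch_size, dim]` arrays and concatenates the result as a `[batch_size, 1]` column. The sketch below shows one possible shape for such a file for a hopper-like task; the `StaticFns` class name is an assumption, and the thresholds restate the standard Hopper-v2 termination rule rather than copying the repository's `hopper.py`.

```
import numpy as np


class StaticFns:

    @staticmethod
    def termination_fn(obs, act, next_obs):
        # Batched inputs: every argument is a [batch_size, dim] array.
        height = next_obs[:, 0]
        angle = next_obs[:, 1]
        not_done = (np.isfinite(next_obs).all(axis=-1)
                    * (np.abs(next_obs[:, 1:]) < 100).all(axis=-1)
                    * (height > 0.7)
                    * (np.abs(angle) < 0.2))
        # FakeEnv concatenates this with rewards and means, so keep a trailing axis.
        done = ~not_done
        return done[:, None]
```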
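The rollout schedule described under "Hyperparameters" is a simple linear ramp. Written out as a helper (this restates the README's rule; it is not the code the repository itself uses):

```
def rollout_length(epoch, schedule):
    """Interpolate the model rollout length for a given epoch using
    schedule = [start_epoch, end_epoch, start_length, end_length]."""
    start_epoch, end_epoch, start_length, end_length = schedule
    if epoch <= start_epoch:
        return start_length
    if epoch >= end_epoch:
        return end_length
    frac = (epoch - start_epoch) / (end_epoch - start_epoch)
    return int(start_length + frac * (end_length - start_length))


# With the example schedule above, rollouts grow from 1 to 5 over epochs 20 to 100.
assert rollout_length(10, [20, 100, 1, 5]) == 1
assert rollout_length(60, [20, 100, 1, 5]) == 3
assert rollout_length(100, [20, 100, 1, 5]) == 5
```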
71 | 72 | 73 | -------------------------------------------------------------------------------- /softlearning/policies/base_policy.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | from serializable import Serializable 6 | 7 | 8 | class BasePolicy(Serializable): 9 | def __init__(self): 10 | self._deterministic = False 11 | 12 | def reset(self): 13 | """Reset and clean the policy.""" 14 | raise NotImplementedError 15 | 16 | def actions(self, conditions): 17 | """Compute (symbolic) actions given conditions (observations)""" 18 | raise NotImplementedError 19 | 20 | def log_pis(self, conditions, actions): 21 | """Compute (symbolic) log probs for given observations and actions.""" 22 | raise NotImplementedError 23 | 24 | def actions_np(self, conditions): 25 | """Compute (numeric) actions given conditions (observations)""" 26 | raise NotImplementedError 27 | 28 | def log_pis_np(self, conditions, actions): 29 | """Compute (numeric) log probs for given observations and actions.""" 30 | raise NotImplementedError 31 | 32 | @contextmanager 33 | def set_deterministic(self, deterministic=True): 34 | """Context manager for changing the determinism of the policy. 35 | Args: 36 | set_deterministic (`bool`): Value to set the self._is_deterministic 37 | to during the context. The value will be reset back to the 38 | previous value when the context exits. 39 | """ 40 | was_deterministic = self._deterministic 41 | self._deterministic = deterministic 42 | yield 43 | self._deterministic = was_deterministic 44 | 45 | def get_diagnostics(self, conditions): 46 | """Return diagnostic information of the policy. 47 | 48 | Arguments: 49 | conditions: Observations to run the diagnostics for. 50 | Returns: 51 | diagnostics: OrderedDict of diagnostic information. 
52 | """ 53 | diagnostics = OrderedDict({}) 54 | return diagnostics 55 | 56 | def __getstate__(self): 57 | state = Serializable.__getstate__(self) 58 | state['pickled_weights'] = self.get_weights() 59 | 60 | return state 61 | 62 | def __setstate__(self, state): 63 | Serializable.__setstate__(self, state) 64 | self.set_weights(state['pickled_weights']) 65 | 66 | 67 | class LatentSpacePolicy(BasePolicy): 68 | def __init__(self, *args, smoothing_coefficient=None, **kwargs): 69 | super(LatentSpacePolicy, self).__init__(*args, **kwargs) 70 | 71 | assert smoothing_coefficient is None or 0 <= smoothing_coefficient <= 1 72 | self._smoothing_alpha = smoothing_coefficient or 0 73 | self._smoothing_beta = ( 74 | np.sqrt(1.0 - np.power(self._smoothing_alpha, 2.0)) 75 | / (1.0 - self._smoothing_alpha)) 76 | self._reset_smoothing_x() 77 | self._smooth_latents = False 78 | 79 | def _reset_smoothing_x(self): 80 | self._smoothing_x = np.zeros((1, *self._output_shape)) 81 | 82 | def actions_np(self, conditions): 83 | if self._deterministic: 84 | return self.deterministic_actions_model.predict(conditions) 85 | elif self._smoothing_alpha == 0: 86 | return self.actions_model.predict(conditions) 87 | else: 88 | alpha, beta = self._smoothing_alpha, self._smoothing_beta 89 | raw_latents = self.latents_model.predict(conditions) 90 | self._smoothing_x = ( 91 | alpha * self._smoothing_x + (1.0 - alpha) * raw_latents) 92 | latents = beta * self._smoothing_x 93 | 94 | return self.actions_model_for_fixed_latents.predict( 95 | [*conditions, latents]) 96 | 97 | def reset(self): 98 | self._reset_smoothing_x() 99 | -------------------------------------------------------------------------------- /softlearning/samplers/simple_sampler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | from .base_sampler import BaseSampler 6 | 7 | 8 | class SimpleSampler(BaseSampler): 9 | def __init__(self, **kwargs): 10 | super(SimpleSampler, self).__init__(**kwargs) 11 | 12 | self._path_length = 0 13 | self._path_return = 0 14 | self._current_path = defaultdict(list) 15 | self._last_path_return = 0 16 | self._max_path_return = -np.inf 17 | self._n_episodes = 0 18 | self._current_observation = None 19 | self._total_samples = 0 20 | 21 | def _process_observations(self, 22 | observation, 23 | action, 24 | reward, 25 | terminal, 26 | next_observation, 27 | info): 28 | processed_observation = { 29 | 'observations': observation, 30 | 'actions': action, 31 | 'rewards': [reward], 32 | 'terminals': [terminal], 33 | 'next_observations': next_observation, 34 | 'infos': info, 35 | } 36 | 37 | return processed_observation 38 | 39 | def sample(self): 40 | if self._current_observation is None: 41 | self._current_observation = self.env.reset() 42 | 43 | action = self.policy.actions_np([ 44 | self.env.convert_to_active_observation( 45 | self._current_observation)[None] 46 | ])[0] 47 | 48 | next_observation, reward, terminal, info = self.env.step(action) 49 | self._path_length += 1 50 | self._path_return += reward 51 | self._total_samples += 1 52 | 53 | processed_sample = self._process_observations( 54 | observation=self._current_observation, 55 | action=action, 56 | reward=reward, 57 | terminal=terminal, 58 | next_observation=next_observation, 59 | info=info, 60 | ) 61 | 62 | for key, value in processed_sample.items(): 63 | self._current_path[key].append(value) 64 | 65 | if terminal or self._path_length >= self._max_path_length: 66 | last_path = { 67 
| field_name: np.array(values) 68 | for field_name, values in self._current_path.items() 69 | } 70 | self.pool.add_path(last_path) 71 | self._last_n_paths.appendleft(last_path) 72 | 73 | self._max_path_return = max(self._max_path_return, 74 | self._path_return) 75 | self._last_path_return = self._path_return 76 | 77 | self.policy.reset() 78 | self._current_observation = None 79 | self._path_length = 0 80 | self._path_return = 0 81 | self._current_path = defaultdict(list) 82 | 83 | self._n_episodes += 1 84 | else: 85 | self._current_observation = next_observation 86 | 87 | return next_observation, reward, terminal, info 88 | 89 | def random_batch(self, batch_size=None, **kwargs): 90 | batch_size = batch_size or self._batch_size 91 | observation_keys = getattr(self.env, 'observation_keys', None) 92 | 93 | return self.pool.random_batch( 94 | batch_size, observation_keys=observation_keys, **kwargs) 95 | 96 | def get_diagnostics(self): 97 | diagnostics = super(SimpleSampler, self).get_diagnostics() 98 | diagnostics.update({ 99 | 'max-path-return': self._max_path_return, 100 | 'last-path-return': self._last_path_return, 101 | 'episodes': self._n_episodes, 102 | 'total-samples': self._total_samples, 103 | }) 104 | 105 | return diagnostics 106 | -------------------------------------------------------------------------------- /softlearning/samplers/explore_sampler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | from .base_sampler import BaseSampler 6 | 7 | 8 | class ExploreSampler(BaseSampler): 9 | def __init__(self, **kwargs): 10 | super(ExploreSampler, self).__init__(**kwargs) 11 | 12 | self._path_length = 0 13 | self._path_return = 0 14 | self._current_path = defaultdict(list) 15 | self._last_path_return = 0 16 | self._max_path_return = -np.inf 17 | self._n_episodes = 0 18 | self._current_observation = None 19 | self._total_samples = 0 20 | 21 | def _process_observations(self, 22 | observation, 23 | action, 24 | reward, 25 | terminal, 26 | next_observation, 27 | info): 28 | processed_observation = { 29 | 'observations': observation, 30 | 'actions': action, 31 | 'rewards': [reward], 32 | 'terminals': [terminal], 33 | 'next_observations': next_observation, 34 | 'infos': info, 35 | } 36 | 37 | return processed_observation 38 | 39 | def sample(self): 40 | if self._current_observation is None: 41 | self._current_observation = self.env.reset() 42 | self._s0 = self.env.unwrapped.state_vector() 43 | 44 | action = self.policy.actions_np([ 45 | self.env.convert_to_active_observation( 46 | self._current_observation)[None] 47 | ])[0] 48 | 49 | next_observation, reward, terminal, info = self.env.step(action) 50 | self._path_length += 1 51 | self._path_return += reward 52 | self._total_samples += 1 53 | 54 | processed_sample = self._process_observations( 55 | observation=self._current_observation, 56 | action=action, 57 | reward=reward, 58 | terminal=terminal, 59 | next_observation=next_observation, 60 | info=info, 61 | ) 62 | 63 | for key, value in processed_sample.items(): 64 | self._current_path[key].append(value) 65 | 66 | if terminal or self._path_length >= self._max_path_length: 67 | last_path = { 68 | field_name: np.array(values) 69 | for field_name, values in self._current_path.items() 70 | } 71 | self.pool.add_path(last_path) 72 | self._last_n_paths.appendleft(last_path) 73 | 74 | self._max_path_return = max(self._max_path_return, 75 | self._path_return) 76 | self._last_path_return = 
self._path_return 77 | 78 | self.policy.reset() 79 | self._current_observation = None 80 | self._path_length = 0 81 | self._path_return = 0 82 | self._current_path = defaultdict(list) 83 | 84 | self._n_episodes += 1 85 | else: 86 | self._current_observation = next_observation 87 | 88 | return next_observation, reward, terminal, info 89 | 90 | def random_batch(self, batch_size=None, **kwargs): 91 | batch_size = batch_size or self._batch_size 92 | observation_keys = getattr(self.env, 'observation_keys', None) 93 | 94 | return self.pool.random_batch( 95 | batch_size, observation_keys=observation_keys, **kwargs) 96 | 97 | def get_diagnostics(self): 98 | diagnostics = super(SimpleSampler, self).get_diagnostics() 99 | diagnostics.update({ 100 | 'max-path-return': self._max_path_return, 101 | 'last-path-return': self._last_path_return, 102 | 'episodes': self._n_episodes, 103 | 'total-samples': self._total_samples, 104 | }) 105 | 106 | return diagnostics 107 | -------------------------------------------------------------------------------- /softlearning/samplers/remote_sampler.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import OrderedDict 3 | 4 | import ray 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | 9 | from .base_sampler import BaseSampler 10 | from .utils import rollout 11 | 12 | 13 | class RemoteSampler(BaseSampler): 14 | def __init__(self, **kwargs): 15 | super(RemoteSampler, self).__init__(**kwargs) 16 | 17 | self._remote_environment = None 18 | self._remote_path = None 19 | self._n_episodes = 0 20 | self._total_samples = 0 21 | self._last_path_return = 0 22 | self._max_path_return = -np.inf 23 | 24 | def _create_remote_environment(self, env, policy): 25 | env_pkl = pickle.dumps(env) 26 | policy_pkl = pickle.dumps(policy) 27 | 28 | if not ray.is_initialized(): 29 | ray.init() 30 | 31 | self._remote_environment = _RemoteEnv.remote(env_pkl, policy_pkl) 32 | 33 | # Block until the env and policy is ready 34 | initialized = ray.get(self._remote_environment.initialized.remote()) 35 | assert initialized, initialized 36 | 37 | def initialize(self, env, policy, pool): 38 | super(RemoteSampler, self).initialize(env, policy, pool) 39 | self._create_remote_environment(env, policy) 40 | 41 | def wait_for_path(self, timeout=1): 42 | if self._remote_path is None: 43 | return [True] 44 | 45 | path_ready, _ = ray.wait([self._remote_path], timeout=timeout) 46 | return path_ready 47 | 48 | def sample(self, timeout=0): 49 | if self._remote_path is None: 50 | policy_params = self.policy.get_weights() 51 | self._remote_path = self._remote_environment.rollout.remote( 52 | policy_params, self._max_path_length) 53 | 54 | path_ready = self.wait_for_path(timeout=timeout) 55 | 56 | if len(path_ready) or not self.batch_ready(): 57 | path = ray.get(self._remote_path) 58 | self._last_n_paths.appendleft(path) 59 | 60 | self.pool.add_path(path) 61 | 62 | self._remote_path = None 63 | self._total_samples += len(path['observations']) 64 | self._last_path_return = np.sum(path['rewards']) 65 | self._max_path_return = max(self._max_path_return, 66 | self._last_path_return) 67 | self._n_episodes += 1 68 | 69 | def get_diagnostics(self): 70 | diagnostics = OrderedDict({ 71 | 'max-path-return': self._max_path_return, 72 | 'last-path-return': self._last_path_return, 73 | 'pool-size': self.pool.size, 74 | 'episodes': self._n_episodes, 75 | 'total-samples': self._total_samples, 76 | }) 77 | 78 | return diagnostics 79 | 80 | def 
__getstate__(self): 81 | super_state = super(RemoteSampler, self).__getstate__() 82 | state = { 83 | key: value for key, value in super_state.items() 84 | if key not in ('_remote_environment', '_remote_path') 85 | } 86 | 87 | return state 88 | 89 | def __setstate__(self, state): 90 | super(RemoteSampler, self).__setstate__(state) 91 | self._create_remote_environment(self.env, self.policy) 92 | self._remote_path = None 93 | 94 | 95 | @ray.remote 96 | class _RemoteEnv(object): 97 | def __init__(self, env_pkl, policy_pkl): 98 | self._session = tf.keras.backend.get_session() 99 | self._session.run(tf.global_variables_initializer()) 100 | 101 | self._env = pickle.loads(env_pkl) 102 | self._policy = pickle.loads(policy_pkl) 103 | 104 | if hasattr(self._env, 'initialize'): 105 | self._env.initialize() 106 | 107 | self._initialized = True 108 | 109 | def initialized(self): 110 | return self._initialized 111 | 112 | def rollout(self, policy_weights, path_length): 113 | self._policy.set_weights(policy_weights) 114 | path = rollout(self._env, self._policy, path_length) 115 | 116 | return path 117 | -------------------------------------------------------------------------------- /mbpo/models/fake_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import pdb 4 | 5 | class FakeEnv: 6 | 7 | def __init__(self, model, config): 8 | self.model = model 9 | self.config = config 10 | 11 | ''' 12 | x : [ batch_size, obs_dim + 1 ] 13 | means : [ num_models, batch_size, obs_dim + 1 ] 14 | vars : [ num_models, batch_size, obs_dim + 1 ] 15 | ''' 16 | def _get_logprob(self, x, means, variances): 17 | 18 | k = x.shape[-1] 19 | 20 | ## [ num_networks, batch_size ] 21 | log_prob = -1/2 * (k * np.log(2*np.pi) + np.log(variances).sum(-1) + (np.power(x-means, 2)/variances).sum(-1)) 22 | 23 | ## [ batch_size ] 24 | prob = np.exp(log_prob).sum(0) 25 | 26 | ## [ batch_size ] 27 | log_prob = np.log(prob) 28 | 29 | stds = np.std(means,0).mean(-1) 30 | 31 | return log_prob, stds 32 | 33 | def step(self, obs, act, deterministic=False): 34 | assert len(obs.shape) == len(act.shape) 35 | if len(obs.shape) == 1: 36 | obs = obs[None] 37 | act = act[None] 38 | return_single = True 39 | else: 40 | return_single = False 41 | 42 | inputs = np.concatenate((obs, act), axis=-1) 43 | ensemble_model_means, ensemble_model_vars = self.model.predict(inputs, factored=True) 44 | ensemble_model_means[:,:,1:] += obs 45 | ensemble_model_stds = np.sqrt(ensemble_model_vars) 46 | 47 | if deterministic: 48 | ensemble_samples = ensemble_model_means 49 | else: 50 | ensemble_samples = ensemble_model_means + np.random.normal(size=ensemble_model_means.shape) * ensemble_model_stds 51 | 52 | #### choose one model from ensemble 53 | num_models, batch_size, _ = ensemble_model_means.shape 54 | model_inds = self.model.random_inds(batch_size) 55 | batch_inds = np.arange(0, batch_size) 56 | samples = ensemble_samples[model_inds, batch_inds] 57 | model_means = ensemble_model_means[model_inds, batch_inds] 58 | model_stds = ensemble_model_stds[model_inds, batch_inds] 59 | #### 60 | 61 | log_prob, dev = self._get_logprob(samples, ensemble_model_means, ensemble_model_vars) 62 | 63 | rewards, next_obs = samples[:,:1], samples[:,1:] 64 | terminals = self.config.termination_fn(obs, act, next_obs) 65 | 66 | batch_size = model_means.shape[0] 67 | return_means = np.concatenate((model_means[:,:1], terminals, model_means[:,1:]), axis=-1) 68 | return_stds = np.concatenate((model_stds[:,:1], 
np.zeros((batch_size,1)), model_stds[:,1:]), axis=-1) 69 | 70 | if return_single: 71 | next_obs = next_obs[0] 72 | return_means = return_means[0] 73 | return_stds = return_stds[0] 74 | rewards = rewards[0] 75 | terminals = terminals[0] 76 | 77 | info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev} 78 | return next_obs, rewards, terminals, info 79 | 80 | ## for debugging computation graph 81 | def step_ph(self, obs_ph, act_ph, deterministic=False): 82 | assert len(obs_ph.shape) == len(act_ph.shape) 83 | 84 | inputs = tf.concat([obs_ph, act_ph], axis=1) 85 | # inputs = np.concatenate((obs, act), axis=-1) 86 | ensemble_model_means, ensemble_model_vars = self.model.create_prediction_tensors(inputs, factored=True) 87 | # ensemble_model_means, ensemble_model_vars = self.model.predict(inputs, factored=True) 88 | ensemble_model_means = tf.concat([ensemble_model_means[:,:,0:1], ensemble_model_means[:,:,1:] + obs_ph[None]], axis=-1) 89 | # ensemble_model_means[:,:,1:] += obs_ph 90 | ensemble_model_stds = tf.sqrt(ensemble_model_vars) 91 | # ensemble_model_stds = np.sqrt(ensemble_model_vars) 92 | 93 | if deterministic: 94 | ensemble_samples = ensemble_model_means 95 | else: 96 | # ensemble_samples = ensemble_model_means + np.random.normal(size=ensemble_model_means.shape) * ensemble_model_stds 97 | ensemble_samples = ensemble_model_means + tf.random.normal(tf.shape(ensemble_model_means)) * ensemble_model_stds 98 | 99 | samples = ensemble_samples[0] 100 | 101 | rewards, next_obs = samples[:,:1], samples[:,1:] 102 | terminals = self.config.termination_ph_fn(obs_ph, act_ph, next_obs) 103 | info = {} 104 | 105 | return next_obs, rewards, terminals, info 106 | 107 | def close(self): 108 | pass 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /softlearning/misc/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | import os 4 | import random 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | 10 | PROJECT_PATH = os.path.dirname( 11 | os.path.realpath(os.path.join(__file__, '..', '..'))) 12 | 13 | 14 | DEFAULT_SNAPSHOT_MODE = 'none' 15 | DEFAULT_SNAPSHOT_GAP = 1000 16 | 17 | 18 | def initialize_tf_variables(session, only_uninitialized=True): 19 | variables = tf.global_variables() + tf.local_variables() 20 | 21 | def is_initialized(variable): 22 | try: 23 | session.run(variable) 24 | return True 25 | except tf.errors.FailedPreconditionError: 26 | return False 27 | 28 | return False 29 | 30 | if only_uninitialized: 31 | variables = [ 32 | variable for variable in variables 33 | if not is_initialized(variable) 34 | ] 35 | 36 | session.run(tf.variables_initializer(variables)) 37 | 38 | 39 | def set_seed(seed): 40 | seed %= 4294967294 41 | random.seed(seed) 42 | np.random.seed(seed) 43 | tf.set_random_seed(seed) 44 | print("Using seed {}".format(seed)) 45 | 46 | 47 | def datetimestamp(divider='-', datetime_divider='T'): 48 | now = datetime.datetime.now() 49 | return now.strftime( 50 | '%Y{d}%m{d}%dT%H{d}%M{d}%S' 51 | ''.format(d=divider, dtd=datetime_divider)) 52 | 53 | 54 | def datestamp(divider='-'): 55 | return datetime.date.today().isoformat().replace('-', divider) 56 | 57 | 58 | def timestamp(divider='-'): 59 | now = datetime.datetime.now() 60 | time_now = datetime.datetime.time(now) 61 | return time_now.strftime( 62 | '%H{d}%M{d}%S'.format(d=divider)) 63 | 64 | 65 | def concat_obs_z(obs, z, num_skills): 66 | """Concatenates the 
observation to a one-hot encoding of Z.""" 67 | assert np.isscalar(z) 68 | z_one_hot = np.zeros(num_skills) 69 | z_one_hot[z] = 1 70 | return np.hstack([obs, z_one_hot]) 71 | 72 | 73 | def split_aug_obs(aug_obs, num_skills): 74 | """Splits an augmented observation into the observation and Z.""" 75 | (obs, z_one_hot) = (aug_obs[:-num_skills], aug_obs[-num_skills:]) 76 | z = np.where(z_one_hot == 1)[0][0] 77 | return (obs, z) 78 | 79 | 80 | def _make_dir(filename): 81 | folder = os.path.dirname(filename) 82 | if not os.path.exists(folder): 83 | os.makedirs(folder) 84 | 85 | 86 | def save_video(video_frames, filename): 87 | import cv2 88 | _make_dir(filename) 89 | 90 | video_frames = np.flip(video_frames, axis=-1) 91 | 92 | # Define the codec and create VideoWriter object 93 | fourcc = cv2.VideoWriter_fourcc(*'MJPG') 94 | fps = 30.0 95 | (height, width, _) = video_frames[0].shape 96 | writer = cv2.VideoWriter(filename, fourcc, fps, (width, height)) 97 | for video_frame in video_frames: 98 | writer.write(video_frame) 99 | writer.release() 100 | 101 | 102 | def deep_update(d, *us): 103 | d = d.copy() 104 | 105 | for u in us: 106 | u = u.copy() 107 | for k, v in u.items(): 108 | d[k] = ( 109 | deep_update(d.get(k, {}), v) 110 | if isinstance(v, collections.Mapping) 111 | else v) 112 | 113 | return d 114 | 115 | 116 | def get_git_rev(): 117 | try: 118 | import git 119 | except ImportError: 120 | print( 121 | "Warning: gitpython not installed." 122 | " Unable to log git rev." 123 | " Run `pip install gitpython` if you want git revs to be logged.") 124 | return None 125 | 126 | try: 127 | repo = git.Repo(os.getcwd()) 128 | git_rev = repo.active_branch.commit.name_rev 129 | except TypeError: 130 | git_rev = repo.head.object.name_rev 131 | 132 | return git_rev 133 | 134 | 135 | def flatten(unflattened, parent_key='', separator='.'): 136 | items = [] 137 | for k, v in unflattened.items(): 138 | if separator in k: 139 | raise ValueError( 140 | "Found separator ({}) from key ({})".format(separator, k)) 141 | new_key = parent_key + separator + k if parent_key else k 142 | if isinstance(v, collections.MutableMapping) and v: 143 | items.extend(flatten(v, new_key, separator=separator).items()) 144 | else: 145 | items.append((new_key, v)) 146 | 147 | return dict(items) 148 | 149 | 150 | def unflatten(flattened, separator='.'): 151 | result = {} 152 | for key, value in flattened.items(): 153 | parts = key.split(separator) 154 | d = result 155 | for part in parts[:-1]: 156 | if part not in d: 157 | d[part] = {} 158 | d = d[part] 159 | d[parts[-1]] = value 160 | 161 | return result 162 | -------------------------------------------------------------------------------- /mbpo/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | import numpy as np 4 | import cv2 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import pdb 9 | 10 | 11 | def plot_trajectories(writer, label, epoch, env_traj, model_traj, means, stds): 12 | state_dim = env_traj[0].size 13 | model_states = [[obs[s] for obs in model_traj] for s in range(state_dim)] 14 | env_states = [[obs[s] for obs in env_traj ] for s in range(state_dim)] 15 | 16 | means = [np.array([mean[s] for mean in means]) for s in range(state_dim)] 17 | stds = [np.array([std[s] for std in stds]) for s in range(state_dim)] 18 | 19 | cols = 1 20 | rows = math.ceil(state_dim / cols) 21 | 22 | plt.clf() 23 | fig, axes = plt.subplots(rows, cols, figsize = (9*cols, 
3*rows)) 24 | axes = axes.ravel() 25 | 26 | for i in range(state_dim): 27 | ax = axes[i] 28 | X = range(len(model_states[i])) 29 | 30 | ax.fill_between(X, means[i]+stds[i], means[i]-stds[i], color='r', alpha=0.5) 31 | ax.plot(env_states[i], color='k') 32 | ax.plot(model_states[i], color='b') 33 | ax.plot(means[i], color='r') 34 | 35 | if i == 0: 36 | ax.set_title('reward') 37 | elif i == 1: 38 | ax.set_title('terminal') 39 | else: 40 | ax.set_title('state dim {}'.format(i-2)) 41 | plt.tight_layout() 42 | 43 | buf = io.BytesIO() 44 | plt.savefig(buf, format='png', layout = 'tight') 45 | buf.seek(0) 46 | 47 | img = cv2.imdecode(np.fromstring(buf.getvalue(), dtype=np.uint8), -1) 48 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 49 | img = img.transpose(2,0,1) / 255. 50 | 51 | writer.add_image(label, img, epoch) 52 | 53 | plt.close() 54 | 55 | 56 | ''' 57 | writer video : [ batch x channels x timesteps x height x width ] 58 | ''' 59 | def record_trajectories(writer, label, epoch, env_images, model_images=None): 60 | traj_length = len(env_images) 61 | if model_images is not None: 62 | assert len(env_images) == len(model_images) 63 | images = [np.concatenate((env_img, model_img)) for (env_img, model_img) in zip(env_images, model_images)] 64 | else: 65 | images = env_images 66 | 67 | ## [ traj_length, 2 * H, W, C ] 68 | images = np.array(images) 69 | images = torch.Tensor(images) 70 | 71 | ## [ traj_length, C, 2 * H, W ] 72 | images = images.permute(0,3,1,2) 73 | ## [ B, traj_length, C, 2 * H, W ] 74 | images = images.unsqueeze(0) 75 | 76 | images = images / 255. 77 | images = images[:,:,0].unsqueeze(2) 78 | 79 | print('[ Visualization ] Saving to {}'.format(label)) 80 | fps = min(max(traj_length / 5, 2), 30) 81 | writer.add_video('video_' + label, images, epoch, fps = fps) 82 | 83 | 84 | def visualize_policy(real_env, fake_env, policy, writer, timestep, max_steps=100, focus=None, label='model_vis', img_dim=128): 85 | init_obs = real_env.reset() 86 | obs = init_obs.copy() 87 | 88 | observations_r = [obs] 89 | observations_f = [obs] 90 | rewards_r = [0] 91 | rewards_f = [0] 92 | terminals_r = [False] 93 | terminals_f = [False] 94 | means_f = [np.concatenate((np.zeros(2), obs))] 95 | stds_f = [np.concatenate((np.zeros(2), obs*0))] 96 | actions = [] 97 | 98 | i = 0 99 | term_r, term_f = False, False 100 | while not (term_r and term_f) and i <= max_steps: 101 | 102 | act = policy.actions_np(obs[None])[0] 103 | if not term_r: 104 | next_obs_r, rew_r, term_r, info_r = real_env.step(act) 105 | observations_r.append(next_obs_r) 106 | rewards_r.append(rew_r) 107 | terminals_r.append(term_r) 108 | 109 | if not term_f: 110 | next_obs_f, rew_f, term_f, info_f = fake_env.step(obs, act) 111 | observations_f.append(next_obs_f) 112 | rewards_f.append(rew_f) 113 | terminals_f.append(term_f) 114 | means_f.append(info_f['mean']) 115 | stds_f.append(info_f['std']) 116 | 117 | actions.append(act) 118 | 119 | if not term_f: 120 | obs = next_obs_f 121 | else: 122 | obs = next_obs_r 123 | 124 | i += 1 125 | 126 | terminals_r = np.array([terminals_r]).astype(np.uint8).T 127 | terminals_f = np.array([terminals_f]).astype(np.uint8).T 128 | rewards_r = np.array([rewards_r]).T 129 | rewards_f = np.array([rewards_f]).T 130 | 131 | rewards_observations_r = np.concatenate((rewards_r, terminals_r, np.array(observations_r)), -1) 132 | rewards_observations_f = np.concatenate((rewards_f, terminals_f, np.array(observations_f)), -1) 133 | plot_trajectories(writer, label, timestep, rewards_observations_r, rewards_observations_f, 
means_f, stds_f) 134 | record_trajectories(writer, label, epoch, images_r) 135 | 136 | -------------------------------------------------------------------------------- /mbpo/utils/logging.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import pdb 4 | 5 | class Progress: 6 | 7 | def __init__(self, total, name = 'Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100): 8 | self.total = total 9 | self.name = name 10 | self.ncol = ncol 11 | self.max_length = max_length 12 | self.indent = indent 13 | self.line_width = line_width 14 | self._speed_update_freq = speed_update_freq 15 | 16 | self._step = 0 17 | self._prev_line = '\033[F' 18 | self._clear_line = ' ' * self.line_width 19 | 20 | self._pbar_size = self.ncol * self.max_length 21 | self._complete_pbar = '#' * self._pbar_size 22 | self._incomplete_pbar = ' ' * self._pbar_size 23 | 24 | self.lines = [''] 25 | self.fraction = '{} / {}'.format(0, self.total) 26 | 27 | self.resume() 28 | 29 | 30 | def update(self, n=1): 31 | self._step += n 32 | if self._step % self._speed_update_freq == 0: 33 | self._time0 = time.time() 34 | self._step0 = self._step 35 | 36 | def resume(self): 37 | self._skip_lines = 1 38 | print('\n', end='') 39 | self._time0 = time.time() 40 | self._step0 = self._step 41 | 42 | def pause(self): 43 | self._clear() 44 | self._skip_lines = 1 45 | 46 | def set_description(self, params=[]): 47 | 48 | ############ 49 | # Position # 50 | ############ 51 | self._clear() 52 | 53 | ########### 54 | # Percent # 55 | ########### 56 | percent, fraction = self._format_percent(self._step, self.total) 57 | self.fraction = fraction 58 | 59 | ######### 60 | # Speed # 61 | ######### 62 | speed = self._format_speed(self._step) 63 | 64 | ########## 65 | # Params # 66 | ########## 67 | num_params = len(params) 68 | nrow = math.ceil(num_params / self.ncol) 69 | params_split = self._chunk(params, self.ncol) 70 | params_string, lines = self._format(params_split) 71 | self.lines = lines 72 | 73 | 74 | description = '{} | {}{}'.format(percent, speed, params_string) 75 | print(description) 76 | self._skip_lines = nrow + 1 77 | 78 | def append_description(self, descr): 79 | self.lines.append(descr) 80 | 81 | def _clear(self): 82 | position = self._prev_line * self._skip_lines 83 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 84 | print(position, end='') 85 | print(empty) 86 | print(position, end='') 87 | 88 | def _format_percent(self, n, total): 89 | if total: 90 | percent = n / float(total) 91 | 92 | complete_entries = int(percent * self._pbar_size) 93 | incomplete_entries = self._pbar_size - complete_entries 94 | 95 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 96 | fraction = '{} / {}'.format(n, total) 97 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100)) 98 | else: 99 | fraction = '{}'.format(n) 100 | string = '{} iterations'.format(n) 101 | return string, fraction 102 | 103 | def _format_speed(self, n): 104 | num_steps = n - self._step0 105 | t = time.time() - self._time0 106 | speed = num_steps / t 107 | string = '{:.1f} Hz'.format(speed) 108 | if num_steps > 0: 109 | self._speed = string 110 | return string 111 | 112 | def _chunk(self, l, n): 113 | return [l[i:i+n] for i in range(0, len(l), n)] 114 | 115 | def _format(self, chunks): 116 | lines = [self._format_chunk(chunk) for chunk in chunks] 117 | lines.insert(0,'') 118 | padding = '\n' + ' '*self.indent 
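## e.g. with ncol=3 and params = [('A','1'), ('B','2'), ('C','3'), ('D','4')]
## (purely illustrative values), _chunk gives [[('A','1'),('B','2'),('C','3')], [('D','4')]],
## so lines == ['', 'A : 1 | B : 2 | C : 3', 'D : 4'] and the join below puts
## each chunk on its own indented row of the printed description.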
119 | string = padding.join(lines) 120 | return string, lines 121 | 122 | def _format_chunk(self, chunk): 123 | line = ' | '.join([self._format_param(param) for param in chunk]) 124 | return line 125 | 126 | def _format_param(self, param): 127 | k, v = param 128 | return '{} : {}'.format(k, v)[:self.max_length] 129 | 130 | def stamp(self): 131 | if self.lines != ['']: 132 | params = ' | '.join(self.lines) 133 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 134 | self._clear() 135 | print(string, end='\n') 136 | self._skip_lines = 1 137 | else: 138 | self._clear() 139 | self._skip_lines = 0 140 | 141 | def close(self): 142 | self.pause() 143 | 144 | class Silent: 145 | 146 | def __init__(self, *args, **kwargs): 147 | pass 148 | 149 | def __getattr__(self, attr): 150 | return lambda *args: None 151 | 152 | 153 | if __name__ == '__main__': 154 | silent = Silent() 155 | silent.update() 156 | silent.stamp() 157 | 158 | num_steps = 1000 159 | progress = Progress(num_steps) 160 | for i in range(num_steps): 161 | progress.update() 162 | params = [ 163 | ['A', '{:06d}'.format(i)], 164 | ['B', '{:06d}'.format(i)], 165 | ['C', '{:06d}'.format(i)], 166 | ['D', '{:06d}'.format(i)], 167 | ['E', '{:06d}'.format(i)], 168 | ['F', '{:06d}'.format(i)], 169 | ['G', '{:06d}'.format(i)], 170 | ['H', '{:06d}'.format(i)], 171 | ] 172 | progress.set_description(params) 173 | time.sleep(0.01) 174 | progress.close() 175 | -------------------------------------------------------------------------------- /softlearning/environments/gym/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom Gym environments. 2 | 3 | Every class inside this module should extend a gym.Env class. The file 4 | structure should be similar to gym.envs file structure, e.g. if you're 5 | implementing a mujoco env, you would implement it under gym.mujoco submodule. 
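As a minimal usage sketch: once `register_environments()` (defined at the
bottom of this module) has been called, the ids registered here can be
created like any other Gym environment (the mujoco ids additionally assume
a working mujoco-py installation), e.g.:

    import gym

    env = gym.make('HalfCheetah-Parameterizable-v3')
    observation = env.reset()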
6 | """ 7 | 8 | import gym 9 | 10 | from mbpo.env import register_mbpo_environments 11 | 12 | CUSTOM_GYM_ENVIRONMENTS_PATH = __package__ 13 | MUJOCO_ENVIRONMENTS_PATH = f'{CUSTOM_GYM_ENVIRONMENTS_PATH}.mujoco' 14 | 15 | MUJOCO_ENVIRONMENT_SPECS = ( 16 | { 17 | 'id': 'Swimmer-Parameterizable-v3', 18 | 'entry_point': (f'gym.envs.mujoco.swimmer_v3:SwimmerEnv'), 19 | }, 20 | { 21 | 'id': 'Hopper-Parameterizable-v3', 22 | 'entry_point': (f'gym.envs.mujoco.hopper_v3:HopperEnv'), 23 | }, 24 | { 25 | 'id': 'Walker2d-Parameterizable-v3', 26 | 'entry_point': (f'gym.envs.mujoco.walker2d_v3:Walker2dEnv'), 27 | }, 28 | { 29 | 'id': 'HalfCheetah-Parameterizable-v3', 30 | 'entry_point': (f'gym.envs.mujoco.half_cheetah_v3:HalfCheetahEnv'), 31 | }, 32 | { 33 | 'id': 'Ant-Parameterizable-v3', 34 | 'entry_point': (f'gym.envs.mujoco.ant_v3:AntEnv'), 35 | }, 36 | { 37 | 'id': 'Humanoid-Parameterizable-v3', 38 | 'entry_point': (f'gym.envs.mujoco.humanoid_v3:HumanoidEnv'), 39 | }, 40 | { 41 | 'id': 'Pusher2d-Default-v0', 42 | 'entry_point': (f'{MUJOCO_ENVIRONMENTS_PATH}' 43 | '.pusher_2d:Pusher2dEnv'), 44 | }, 45 | { 46 | 'id': 'Pusher2d-DefaultReach-v0', 47 | 'entry_point': (f'{MUJOCO_ENVIRONMENTS_PATH}' 48 | '.pusher_2d:ForkReacherEnv'), 49 | }, 50 | { 51 | 'id': 'Pusher2d-ImageDefault-v0', 52 | 'entry_point': (f'{MUJOCO_ENVIRONMENTS_PATH}' 53 | '.image_pusher_2d:ImagePusher2dEnv'), 54 | }, 55 | { 56 | 'id': 'Pusher2d-ImageReach-v0', 57 | 'entry_point': (f'{MUJOCO_ENVIRONMENTS_PATH}' 58 | '.image_pusher_2d:ImageForkReacher2dEnv'), 59 | }, 60 | { 61 | 'id': 'Pusher2d-BlindReach-v0', 62 | 'entry_point': (f'{MUJOCO_ENVIRONMENTS_PATH}' 63 | '.image_pusher_2d:BlindForkReacher2dEnv'), 64 | }, 65 | ) 66 | 67 | GENERAL_ENVIRONMENT_SPECS = ( 68 | { 69 | 'id': 'MultiGoal-Default-v0', 70 | 'entry_point': (f'{CUSTOM_GYM_ENVIRONMENTS_PATH}' 71 | '.multi_goal:MultiGoalEnv') 72 | }, 73 | ) 74 | 75 | MULTIWORLD_ENVIRONMENT_SPECS = ( 76 | { 77 | 'id': 'Point2DEnv-Default-v0', 78 | 'entry_point': 'multiworld.envs.pygame.point2d:Point2DEnv' 79 | }, 80 | { 81 | 'id': 'Point2DEnv-Wall-v0', 82 | 'entry_point': 'multiworld.envs.pygame.point2d:Point2DWallEnv' 83 | }, 84 | ) 85 | 86 | MUJOCO_ENVIRONMENTS = tuple( 87 | environment_spec['id'] 88 | for environment_spec in MUJOCO_ENVIRONMENT_SPECS) 89 | 90 | 91 | GENERAL_ENVIRONMENTS = tuple( 92 | environment_spec['id'] 93 | for environment_spec in GENERAL_ENVIRONMENT_SPECS) 94 | 95 | 96 | MULTIWORLD_ENVIRONMENTS = tuple( 97 | environment_spec['id'] 98 | for environment_spec in MULTIWORLD_ENVIRONMENT_SPECS) 99 | 100 | GYM_ENVIRONMENTS = ( 101 | *MUJOCO_ENVIRONMENTS, 102 | *GENERAL_ENVIRONMENTS, 103 | *MULTIWORLD_ENVIRONMENTS, 104 | ) 105 | 106 | 107 | def register_mujoco_environments(): 108 | """Register softlearning mujoco environments.""" 109 | for mujoco_environment in MUJOCO_ENVIRONMENT_SPECS: 110 | gym.register(**mujoco_environment) 111 | 112 | gym_ids = tuple( 113 | environment_spec['id'] 114 | for environment_spec in MUJOCO_ENVIRONMENT_SPECS) 115 | 116 | return gym_ids 117 | 118 | 119 | def register_general_environments(): 120 | """Register gym environments that don't fall under a specific category.""" 121 | for general_environment in GENERAL_ENVIRONMENT_SPECS: 122 | gym.register(**general_environment) 123 | 124 | gym_ids = tuple( 125 | environment_spec['id'] 126 | for environment_spec in GENERAL_ENVIRONMENT_SPECS) 127 | 128 | return gym_ids 129 | 130 | 131 | def register_multiworld_environments(): 132 | """Register custom environments from multiworld package.""" 133 | for 
multiworld_environment in MULTIWORLD_ENVIRONMENT_SPECS: 134 | gym.register(**multiworld_environment) 135 | 136 | gym_ids = tuple( 137 | environment_spec['id'] 138 | for environment_spec in MULTIWORLD_ENVIRONMENT_SPECS) 139 | 140 | return gym_ids 141 | 142 | 143 | def register_environments(): 144 | registered_mujoco_environments = register_mujoco_environments() 145 | registered_general_environments = register_general_environments() 146 | registered_multiworld_environments = register_multiworld_environments() 147 | registered_mbpo_environments = register_mbpo_environments() 148 | 149 | return ( 150 | *registered_mujoco_environments, 151 | *registered_general_environments, 152 | *registered_multiworld_environments, 153 | *registered_mbpo_environments, 154 | ) 155 | -------------------------------------------------------------------------------- /softlearning/environments/gym/mujoco/image_pusher_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from softlearning.environments.helpers import random_point_in_circle 4 | from .pusher_2d import Pusher2dEnv 5 | 6 | 7 | class ImagePusher2dEnv(Pusher2dEnv): 8 | def __init__(self, image_shape, *args, **kwargs): 9 | self._Serializable__initialize(locals()) 10 | self.image_shape = image_shape 11 | Pusher2dEnv.__init__(self, *args, **kwargs) 12 | 13 | def _get_obs(self): 14 | width, height = self.image_shape[:2] 15 | image = self.render(mode='rgb_array', width=width, height=height) 16 | image = ((2.0 / 255.0) * image - 1.0) 17 | 18 | return np.concatenate([ 19 | image.reshape(-1), 20 | self.sim.data.qpos.flat[self.JOINT_INDS], 21 | self.sim.data.qvel.flat[self.JOINT_INDS], 22 | ]).reshape(-1) 23 | 24 | def step(self, action): 25 | """Step, computing reward from 'true' observations and not images.""" 26 | 27 | reward_observations = super(ImagePusher2dEnv, self)._get_obs() 28 | reward, info = self.compute_reward(reward_observations, action) 29 | 30 | self.do_simulation(action, self.frame_skip) 31 | 32 | observation = self._get_obs() 33 | done = False 34 | 35 | return observation, reward, done, info 36 | 37 | def viewer_setup(self): 38 | self.viewer.cam.trackbodyid = 0 39 | self.viewer.cam.lookat[:3] = [0, 0, 0] 40 | self.viewer.cam.distance = 3.5 41 | self.viewer.cam.elevation = -90 42 | self.viewer.cam.azimuth = 0 43 | self.viewer.cam.trackbodyid = -1 44 | 45 | 46 | class ImageForkReacher2dEnv(ImagePusher2dEnv): 47 | def __init__(self, 48 | arm_goal_distance_cost_coeff, 49 | arm_object_distance_cost_coeff, 50 | *args, 51 | **kwargs): 52 | self._Serializable__initialize(locals()) 53 | 54 | self._arm_goal_distance_cost_coeff = arm_goal_distance_cost_coeff 55 | self._arm_object_distance_cost_coeff = arm_object_distance_cost_coeff 56 | 57 | super(ImageForkReacher2dEnv, self).__init__(*args, **kwargs) 58 | 59 | def compute_reward(self, observations, actions): 60 | is_batch = True 61 | if observations.ndim == 1: 62 | observations = observations[None] 63 | actions = actions[None] 64 | is_batch = False 65 | else: 66 | raise NotImplementedError('Might be broken.') 67 | 68 | arm_pos = observations[:, -6:-4] 69 | goal_pos = self.get_body_com('goal')[:2][None] 70 | object_pos = observations[:, -3:-1] 71 | 72 | arm_goal_dists = np.linalg.norm(arm_pos - goal_pos, axis=1) 73 | arm_object_dists = np.linalg.norm(arm_pos - object_pos, axis=1) 74 | ctrl_costs = np.sum(actions**2, axis=1) 75 | 76 | costs = ( 77 | + self._arm_goal_distance_cost_coeff * arm_goal_dists 78 | + self._arm_object_distance_cost_coeff 
* arm_object_dists 79 | + self._ctrl_cost_coeff * ctrl_costs) 80 | 81 | rewards = -costs 82 | 83 | if not is_batch: 84 | rewards = rewards.squeeze() 85 | arm_goal_dists = arm_goal_dists.squeeze() 86 | arm_object_dists = arm_object_dists.squeeze() 87 | 88 | return rewards, { 89 | 'arm_goal_distance': arm_goal_dists, 90 | 'arm_object_distance': arm_object_dists, 91 | } 92 | 93 | def reset_model(self): 94 | qpos = np.random.uniform( 95 | low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos.squeeze() 96 | 97 | # qpos[self.JOINT_INDS[0]] = np.random.uniform(-np.pi, np.pi) 98 | # qpos[self.JOINT_INDS[1]] = np.random.uniform( 99 | # -np.pi/2, np.pi/2) + np.pi/4 100 | # qpos[self.JOINT_INDS[2]] = np.random.uniform( 101 | # -np.pi/2, np.pi/2) + np.pi/2 102 | 103 | target_position = np.array(random_point_in_circle( 104 | angle_range=(0, 2*np.pi), radius=(0.6, 1.2))) 105 | target_position[1] += 1.0 106 | 107 | qpos[self.TARGET_INDS] = target_position 108 | # qpos[self.TARGET_INDS] = [1.0, 2.0] 109 | # qpos[self.TARGET_INDS] = self.init_qpos.squeeze()[self.TARGET_INDS] 110 | 111 | puck_position = np.random.uniform([-1.0], [1.0], size=[2]) 112 | puck_position = ( 113 | np.sign(puck_position) 114 | * np.maximum(np.abs(puck_position), 1/2)) 115 | puck_position[np.where(puck_position == 0)] = 1.0 116 | # puck_position[1] += 1.0 117 | # puck_position = np.random.uniform( 118 | # low=[0.3, -1.0], high=[1.0, -0.4]), 119 | 120 | qpos[self.PUCK_INDS] = puck_position 121 | 122 | qvel = self.init_qvel.copy().squeeze() 123 | qvel[self.PUCK_INDS] = 0 124 | qvel[self.TARGET_INDS] = 0 125 | 126 | # TODO: remnants from rllab -> gym conversion 127 | # qacc = np.zeros(self.sim.data.qacc.shape[0]) 128 | # ctrl = np.zeros(self.sim.data.ctrl.shape[0]) 129 | # full_state = np.concatenate((qpos, qvel, qacc, ctrl)) 130 | 131 | # super(Pusher2dEnv, self).reset(full_state) 132 | 133 | self.set_state(qpos, qvel) 134 | 135 | return self._get_obs() 136 | 137 | 138 | class BlindForkReacher2dEnv(ImageForkReacher2dEnv): 139 | def _get_obs(self): 140 | return np.concatenate([ 141 | self.sim.data.qpos.flat[self.JOINT_INDS], 142 | self.sim.data.qvel.flat[self.JOINT_INDS], 143 | ]).reshape(-1) 144 | -------------------------------------------------------------------------------- /softlearning/environments/adapters/gym_adapter.py: -------------------------------------------------------------------------------- 1 | """Implements a GymAdapter that converts Gym envs into SoftlearningEnv.""" 2 | 3 | import numpy as np 4 | import gym 5 | from gym import spaces, wrappers 6 | 7 | from .softlearning_env import SoftlearningEnv 8 | from softlearning.environments.gym import register_environments 9 | from softlearning.environments.gym.wrappers import NormalizeActionWrapper 10 | from collections import defaultdict 11 | 12 | 13 | def parse_domain_task(gym_id): 14 | domain_task_parts = gym_id.split('-') 15 | domain = '-'.join(domain_task_parts[:1]) 16 | task = '-'.join(domain_task_parts[1:]) 17 | 18 | return domain, task 19 | 20 | 21 | CUSTOM_GYM_ENVIRONMENT_IDS = register_environments() 22 | CUSTOM_GYM_ENVIRONMENTS = defaultdict(list) 23 | 24 | for gym_id in CUSTOM_GYM_ENVIRONMENT_IDS: 25 | domain, task = parse_domain_task(gym_id) 26 | CUSTOM_GYM_ENVIRONMENTS[domain].append(task) 27 | 28 | CUSTOM_GYM_ENVIRONMENTS = dict(CUSTOM_GYM_ENVIRONMENTS) 29 | 30 | GYM_ENVIRONMENT_IDS = tuple(gym.envs.registry.env_specs.keys()) 31 | GYM_ENVIRONMENTS = defaultdict(list) 32 | 33 | 34 | for gym_id in GYM_ENVIRONMENT_IDS: 35 | domain, task = 
parse_domain_task(gym_id) 36 | GYM_ENVIRONMENTS[domain].append(task) 37 | 38 | GYM_ENVIRONMENTS = dict(GYM_ENVIRONMENTS) 39 | 40 | 41 | class GymAdapter(SoftlearningEnv): 42 | """Adapter that implements the SoftlearningEnv for Gym envs.""" 43 | 44 | def __init__(self, 45 | domain, 46 | task, 47 | *args, 48 | env=None, 49 | normalize=True, 50 | observation_keys=None, 51 | unwrap_time_limit=True, 52 | **kwargs): 53 | assert not args, ( 54 | "Gym environments don't support args. Use kwargs instead.") 55 | 56 | self.normalize = normalize 57 | self.observation_keys = observation_keys 58 | self.unwrap_time_limit = unwrap_time_limit 59 | 60 | self._Serializable__initialize(locals()) 61 | super(GymAdapter, self).__init__(domain, task, *args, **kwargs) 62 | 63 | if env is None: 64 | assert (domain is not None and task is not None), (domain, task) 65 | env_id = f"{domain}-{task}" 66 | env = gym.envs.make(env_id, **kwargs) 67 | else: 68 | assert domain is None and task is None, (domain, task) 69 | 70 | if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit: 71 | # Remove the TimeLimit wrapper that sets 'done = True' when 72 | # the time limit specified for each environment has been passed and 73 | # therefore the environment is not Markovian (terminal condition 74 | # depends on time rather than state). 75 | env = env.env 76 | 77 | if isinstance(env.observation_space, spaces.Dict): 78 | observation_keys = ( 79 | observation_keys or list(env.observation_space.spaces.keys())) 80 | if normalize: 81 | env = NormalizeActionWrapper(env) 82 | 83 | self._env = env 84 | 85 | @property 86 | def observation_space(self): 87 | observation_space = self._env.observation_space 88 | return observation_space 89 | 90 | @property 91 | def active_observation_shape(self): 92 | """Shape for the active observation based on observation_keys.""" 93 | if not isinstance(self._env.observation_space, spaces.Dict): 94 | return super(GymAdapter, self).active_observation_shape 95 | 96 | observation_keys = ( 97 | self.observation_keys 98 | or list(self._env.observation_space.spaces.keys())) 99 | 100 | active_size = sum( 101 | np.prod(self._env.observation_space.spaces[key].shape) 102 | for key in observation_keys) 103 | 104 | active_observation_shape = (active_size, ) 105 | 106 | return active_observation_shape 107 | 108 | def convert_to_active_observation(self, observation): 109 | if not isinstance(self._env.observation_space, spaces.Dict): 110 | return observation 111 | 112 | observation_keys = ( 113 | self.observation_keys 114 | or list(self._env.observation_space.spaces.keys())) 115 | 116 | observation = np.concatenate([ 117 | observation[key] for key in observation_keys 118 | ], axis=-1) 119 | 120 | return observation 121 | 122 | @property 123 | def action_space(self, *args, **kwargs): 124 | action_space = self._env.action_space 125 | if len(action_space.shape) > 1: 126 | raise NotImplementedError( 127 | "Action space ({}) is not flat, make sure to check the" 128 | " implemenation.".format(action_space)) 129 | return action_space 130 | 131 | def step(self, action, *args, **kwargs): 132 | # TODO(hartikainen): refactor this to always return an OrderedDict, 133 | # such that the observations for all the envs is consistent. Right now 134 | # some of the gym envs return np.array whereas others return dict. 
135 | # 136 | # Something like: 137 | # observation = OrderedDict() 138 | # observation['observation'] = env.step(action, *args, **kwargs) 139 | # return observation 140 | 141 | return self._env.step(action, *args, **kwargs) 142 | 143 | def reset(self, *args, **kwargs): 144 | return self._env.reset(*args, **kwargs) 145 | 146 | def render(self, *args, **kwargs): 147 | return self._env.render(*args, **kwargs) 148 | 149 | def close(self, *args, **kwargs): 150 | return self._env.close(*args, **kwargs) 151 | 152 | def seed(self, *args, **kwargs): 153 | return self._env.seed(*args, **kwargs) 154 | 155 | @property 156 | def unwrapped(self): 157 | return self._env.unwrapped 158 | 159 | def get_param_values(self, *args, **kwargs): 160 | raise NotImplementedError 161 | 162 | def set_param_values(self, *args, **kwargs): 163 | raise NotImplementedError 164 | -------------------------------------------------------------------------------- /softlearning/replay_pools/flexible_replay_pool.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pickle 3 | 4 | import numpy as np 5 | 6 | from .replay_pool import ReplayPool 7 | 8 | 9 | class FlexibleReplayPool(ReplayPool): 10 | def __init__(self, max_size, fields_attrs): 11 | super(FlexibleReplayPool, self).__init__() 12 | 13 | max_size = int(max_size) 14 | self._max_size = max_size 15 | 16 | self.fields = {} 17 | self.fields_attrs = {} 18 | 19 | self.add_fields(fields_attrs) 20 | 21 | self._pointer = 0 22 | self._size = 0 23 | self._samples_since_save = 0 24 | 25 | @property 26 | def size(self): 27 | return self._size 28 | 29 | @property 30 | def field_names(self): 31 | return list(self.fields.keys()) 32 | 33 | def add_fields(self, fields_attrs): 34 | self.fields_attrs.update(fields_attrs) 35 | 36 | for field_name, field_attrs in fields_attrs.items(): 37 | field_shape = (self._max_size, *field_attrs['shape']) 38 | initializer = field_attrs.get('initializer', np.zeros) 39 | self.fields[field_name] = initializer( 40 | field_shape, dtype=field_attrs['dtype']) 41 | 42 | def _advance(self, count=1): 43 | self._pointer = (self._pointer + count) % self._max_size 44 | self._size = min(self._size + count, self._max_size) 45 | self._samples_since_save += count 46 | 47 | def add_sample(self, sample): 48 | samples = { 49 | key: value[None, ...] 
50 | for key, value in sample.items() 51 | } 52 | self.add_samples(samples) 53 | 54 | def add_samples(self, samples): 55 | field_names = list(samples.keys()) 56 | num_samples = samples[field_names[0]].shape[0] 57 | 58 | index = np.arange( 59 | self._pointer, self._pointer + num_samples) % self._max_size 60 | 61 | for field_name in self.field_names: 62 | default_value = ( 63 | self.fields_attrs[field_name].get('default_value', 0.0)) 64 | values = samples.get(field_name, default_value) 65 | assert values.shape[0] == num_samples 66 | self.fields[field_name][index] = values 67 | 68 | self._advance(num_samples) 69 | 70 | def random_indices(self, batch_size): 71 | if self._size == 0: return np.arange(0, 0) 72 | return np.random.randint(0, self._size, batch_size) 73 | 74 | def random_batch(self, batch_size, field_name_filter=None, **kwargs): 75 | random_indices = self.random_indices(batch_size) 76 | return self.batch_by_indices( 77 | random_indices, field_name_filter=field_name_filter, **kwargs) 78 | 79 | def last_n_batch(self, last_n, field_name_filter=None, **kwargs): 80 | last_n_indices = np.arange( 81 | self._pointer - min(self.size, last_n), self._pointer 82 | ) % self._max_size 83 | return self.batch_by_indices( 84 | last_n_indices, field_name_filter=field_name_filter, **kwargs) 85 | 86 | def filter_fields(self, field_names, field_name_filter): 87 | if isinstance(field_name_filter, str): 88 | field_name_filter = [field_name_filter] 89 | 90 | if isinstance(field_name_filter, (list, tuple)): 91 | field_name_list = field_name_filter 92 | 93 | def filter_fn(field_name): 94 | return field_name in field_name_list 95 | 96 | else: 97 | filter_fn = field_name_filter 98 | 99 | filtered_field_names = [ 100 | field_name for field_name in field_names 101 | if filter_fn(field_name) 102 | ] 103 | 104 | return filtered_field_names 105 | 106 | def batch_by_indices(self, indices, field_name_filter=None): 107 | if np.any(indices % self._max_size > self.size): 108 | raise ValueError( 109 | "Tried to retrieve batch with indices greater than current" 110 | " size") 111 | 112 | field_names = self.field_names 113 | if field_name_filter is not None: 114 | field_names = self.filter_fields( 115 | field_names, field_name_filter) 116 | 117 | return { 118 | field_name: self.fields[field_name][indices] 119 | for field_name in field_names 120 | } 121 | 122 | def save_latest_experience(self, pickle_path): 123 | latest_samples = self.last_n_batch(self._samples_since_save) 124 | 125 | with gzip.open(pickle_path, 'wb') as f: 126 | pickle.dump(latest_samples, f) 127 | 128 | self._samples_since_save = 0 129 | 130 | def load_experience(self, experience_path): 131 | with gzip.open(experience_path, 'rb') as f: 132 | latest_samples = pickle.load(f) 133 | 134 | key = list(latest_samples.keys())[0] 135 | num_samples = latest_samples[key].shape[0] 136 | for field_name, data in latest_samples.items(): 137 | assert data.shape[0] == num_samples, data.shape 138 | 139 | self.add_samples(latest_samples) 140 | self._samples_since_save = 0 141 | 142 | def return_all_samples(self): 143 | return { 144 | field_name: self.fields[field_name][:self.size] 145 | for field_name in self.field_names 146 | } 147 | 148 | def __getstate__(self): 149 | state = self.__dict__.copy() 150 | state['fields'] = { 151 | field_name: self.fields[field_name][:self.size] 152 | for field_name in self.field_names 153 | } 154 | 155 | return state 156 | 157 | def __setstate__(self, state): 158 | if state['_size'] < state['_max_size']: 159 | pad_size = state['_max_size'] - 
state['_size'] 160 | for field_name in state['fields'].keys(): 161 | field_shape = state['fields_attrs'][field_name]['shape'] 162 | state['fields'][field_name] = np.concatenate(( 163 | state['fields'][field_name], 164 | np.zeros((pad_size, *field_shape)) 165 | ), axis=0) 166 | 167 | self.__dict__ = state 168 | -------------------------------------------------------------------------------- /softlearning/replay_pools/trajectory_replay_pool.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import gzip 3 | import pickle 4 | from itertools import islice 5 | 6 | import numpy as np 7 | 8 | from softlearning.utils.numpy import softmax 9 | from .replay_pool import ReplayPool 10 | 11 | 12 | def random_int_with_variable_range(mins, maxs): 13 | result = np.floor(np.random.uniform(mins, maxs)).astype(int) 14 | return result 15 | 16 | 17 | class TrajectoryReplayPool(ReplayPool): 18 | def __init__(self, 19 | observation_space, 20 | action_space, 21 | max_size): 22 | super(TrajectoryReplayPool, self).__init__() 23 | 24 | max_size = int(max_size) 25 | self._max_size = max_size 26 | 27 | self._trajectories = deque(maxlen=max_size) 28 | self._trajectory_lengths = deque(maxlen=max_size) 29 | self._num_samples = 0 30 | self._trajectories_since_save = 0 31 | 32 | @property 33 | def num_trajectories(self): 34 | return len(self._trajectories) 35 | 36 | @property 37 | def size(self): 38 | return sum(self._trajectory_lengths) 39 | 40 | @property 41 | def num_samples(self): 42 | return self._num_samples 43 | 44 | def add_paths(self, trajectories): 45 | self._trajectories += trajectories 46 | self._trajectory_lengths += [ 47 | trajectory[next(iter(trajectory.keys()))].shape[0] 48 | for trajectory in trajectories 49 | ] 50 | self._trajectories_since_save += len(trajectories) 51 | 52 | def add_path(self, trajectory): 53 | self.add_paths([trajectory]) 54 | 55 | def add_sample(self, sample): 56 | raise NotImplementedError( 57 | f"{self.__class__.__name__} only supports adding full paths at" 58 | " once.") 59 | 60 | def add_samples(self, samples): 61 | raise NotImplementedError( 62 | f"{self.__class__.__name__} only supports adding full paths at" 63 | " once.") 64 | 65 | def batch_by_indices(self, 66 | episode_indices, 67 | step_indices, 68 | field_name_filter=None): 69 | assert len(episode_indices) == len(step_indices) 70 | 71 | batch_size = len(episode_indices) 72 | trajectories = [self._trajectories[i] for i in episode_indices] 73 | 74 | batch = { 75 | field_name: np.empty( 76 | (batch_size, *values.shape[1:]), dtype=values.dtype) 77 | for field_name, values in trajectories[0].items() 78 | } 79 | 80 | for i, episode in enumerate(trajectories): 81 | for field_name, episode_values in episode.items(): 82 | batch[field_name][i] = episode_values[step_indices[i]] 83 | 84 | return batch 85 | 86 | def random_batch(self, batch_size, *args, **kwargs): 87 | num_trajectories = len(self._trajectories) 88 | if num_trajectories < 1: 89 | return {} 90 | 91 | trajectory_lengths = np.array(self._trajectory_lengths) 92 | trajectory_weights = trajectory_lengths / np.sum(trajectory_lengths) 93 | trajectory_probabilities = softmax(trajectory_weights) 94 | 95 | trajectory_indices = np.random.choice( 96 | np.arange(num_trajectories), 97 | size=batch_size, 98 | replace=True, 99 | p=trajectory_probabilities) 100 | first_key = next(iter( 101 | self._trajectories[trajectory_indices[0]].keys())) 102 | trajectory_lengths = np.array([ 103 | 
self._trajectories[trajectory_index][first_key].shape[0] 104 | for trajectory_index in trajectory_indices 105 | ]) 106 | 107 | step_indices = random_int_with_variable_range( 108 | np.zeros_like(trajectory_lengths, dtype=np.int64), 109 | trajectory_lengths) 110 | 111 | batch = self.batch_by_indices(trajectory_indices, step_indices) 112 | 113 | return batch 114 | 115 | def last_n_batch(self, last_n, field_name_filter=None, **kwargs): 116 | num_trajectories = len(self._trajectories) 117 | if num_trajectories < 1: 118 | return {} 119 | 120 | trajectory_indices = [] 121 | step_indices = [] 122 | 123 | trajectory_lengths = 0 124 | for trajectory_index in range(num_trajectories-1, -1, -1): 125 | trajectory = self._trajectories[trajectory_index] 126 | trajectory_length = trajectory[list(trajectory.keys())[0]].shape[0] 127 | 128 | steps_from_this_episode = min(trajectory_length, last_n - trajectory_lengths) 129 | step_indices += list(range( 130 | trajectory_length-1, 131 | trajectory_length - steps_from_this_episode - 1, 132 | -1)) 133 | trajectory_indices += [trajectory_index] * steps_from_this_episode 134 | 135 | trajectory_lengths += trajectory_length 136 | 137 | if trajectory_lengths >= last_n: 138 | break 139 | 140 | trajectory_indices = trajectory_indices[::-1] 141 | step_indices = step_indices[::-1] 142 | 143 | batch = self.batch_by_indices(trajectory_indices, step_indices) 144 | 145 | return batch 146 | 147 | def save_latest_experience(self, pickle_path): 148 | # deque doesn't support direct slicing, thus need to use islice 149 | num_trajectories = self.num_trajectories 150 | start_index = max(num_trajectories - self._trajectories_since_save, 0) 151 | end_index = num_trajectories 152 | 153 | latest_trajectories = tuple(islice( 154 | self._trajectories, start_index, end_index)) 155 | 156 | with gzip.open(pickle_path, 'wb') as f: 157 | pickle.dump(latest_trajectories, f) 158 | 159 | self._trajectories_since_save = 0 160 | 161 | def load_experience(self, experience_path): 162 | with gzip.open(experience_path, 'rb') as f: 163 | latest_trajectories = pickle.load(f) 164 | 165 | self.add_paths(latest_trajectories) 166 | self._trajectories_since_save = 0 167 | -------------------------------------------------------------------------------- /softlearning/replay_pools/simple_replay_pool.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | from gym.spaces import Box, Dict, Discrete 5 | import pdb 6 | 7 | from .flexible_replay_pool import FlexibleReplayPool 8 | 9 | 10 | def normalize_observation_fields(observation_space, name='observations'): 11 | if isinstance(observation_space, Dict): 12 | fields = [ 13 | normalize_observation_fields(child_observation_space, name) 14 | for name, child_observation_space 15 | in observation_space.spaces.items() 16 | ] 17 | fields = { 18 | 'observations.{}'.format(name): value 19 | for field in fields 20 | for name, value in field.items() 21 | } 22 | elif isinstance(observation_space, (Box, Discrete)): 23 | fields = { 24 | name: { 25 | 'shape': observation_space.shape, 26 | 'dtype': observation_space.dtype, 27 | } 28 | } 29 | else: 30 | raise NotImplementedError( 31 | "Observation space of type '{}' not supported." 
32 | "".format(type(observation_space))) 33 | 34 | return fields 35 | 36 | 37 | class SimpleReplayPool(FlexibleReplayPool): 38 | def __init__(self, observation_space, action_space, *args, **kwargs): 39 | self._observation_space = observation_space 40 | self._action_space = action_space 41 | 42 | observation_fields = normalize_observation_fields(observation_space) 43 | # It's a bit memory inefficient to save the observations twice, 44 | # but it makes the code *much* easier since you no longer have 45 | # to worry about termination conditions. 46 | observation_fields.update({ 47 | 'next_' + key: value 48 | for key, value in observation_fields.items() 49 | }) 50 | 51 | fields = { 52 | **observation_fields, 53 | **{ 54 | 'actions': { 55 | 'shape': self._action_space.shape, 56 | 'dtype': 'float32' 57 | }, 58 | 'rewards': { 59 | 'shape': (1, ), 60 | 'dtype': 'float32' 61 | }, 62 | # self.terminals[i] = a terminal was received at time i 63 | 'terminals': { 64 | 'shape': (1, ), 65 | 'dtype': 'bool' 66 | }, 67 | } 68 | } 69 | 70 | super(SimpleReplayPool, self).__init__( 71 | *args, fields_attrs=fields, **kwargs) 72 | 73 | def add_samples(self, samples): 74 | if not isinstance(self._observation_space, Dict): 75 | return super(SimpleReplayPool, self).add_samples(samples) 76 | 77 | dict_observations = defaultdict(list) 78 | for observation in samples['observations']: 79 | for key, value in observation.items(): 80 | dict_observations[key].append(value) 81 | 82 | dict_next_observations = defaultdict(list) 83 | for next_observation in samples['next_observations']: 84 | for key, value in next_observation.items(): 85 | dict_next_observations[key].append(value) 86 | 87 | samples.update( 88 | **{ 89 | f'observations.{observation_key}': np.array(values) 90 | for observation_key, values in dict_observations.items() 91 | }, 92 | **{ 93 | f'next_observations.{observation_key}': np.array(values) 94 | for observation_key, values in dict_next_observations.items() 95 | }, 96 | ) 97 | 98 | del samples['observations'] 99 | del samples['next_observations'] 100 | 101 | return super(SimpleReplayPool, self).add_samples(samples) 102 | 103 | # def add_model_samples(self, samples): 104 | # field_names = list(samples.keys()) 105 | # num_samples = samples[field_names[0]].shape[0] 106 | 107 | # index = np.arange( 108 | # self._pointer, self._pointer + num_samples) % self._max_size 109 | 110 | # for field_name in self.field_names: 111 | # values = samples[field_name] 112 | # assert values.shape[0] == num_samples 113 | # self.fields[field_name][index] = values 114 | 115 | # self._advance(num_samples) 116 | # pdb.set_trace() 117 | # field_names = samples.keys() 118 | # num_samples = samples['observations'].shape[0] 119 | # for i in range(num_samples): 120 | # sample = {field: samples[field][i] for field in field_names} 121 | # self.add_model_sample(sample) 122 | # pdb.set_trace() 123 | # # self.fields 124 | 125 | 126 | # def add_model_sample(self, sample): 127 | # # self._size 128 | # pass 129 | # self._advance() 130 | 131 | def batch_by_indices(self, 132 | indices, 133 | field_name_filter=None, 134 | observation_keys=None): 135 | if not isinstance(self._observation_space, Dict): 136 | return super(SimpleReplayPool, self).batch_by_indices( 137 | indices, field_name_filter=field_name_filter) 138 | 139 | batch = { 140 | field_name: self.fields[field_name][indices] 141 | for field_name in self.field_names 142 | } 143 | 144 | if observation_keys is None: 145 | observation_keys = tuple(self._observation_space.spaces.keys()) 146 | 147 
| observations = np.concatenate([ 148 | batch['observations.{}'.format(key)] 149 | for key in observation_keys 150 | ], axis=-1) 151 | 152 | next_observations = np.concatenate([ 153 | batch['next_observations.{}'.format(key)] 154 | for key in observation_keys 155 | ], axis=-1) 156 | 157 | batch['observations'] = observations 158 | batch['next_observations'] = next_observations 159 | 160 | if field_name_filter is not None: 161 | filtered_fields = self.filter_fields( 162 | batch.keys(), field_name_filter) 163 | batch = { 164 | field_name: batch[field_name] 165 | for field_name in filtered_fields 166 | } 167 | 168 | return batch 169 | 170 | def terminate_episode(self): 171 | pass 172 | -------------------------------------------------------------------------------- /softlearning/environments/gym/multi_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from gym.utils import EzPickle 5 | from gym import spaces 6 | from gym.envs.mujoco.mujoco_env import MujocoEnv 7 | 8 | 9 | class MultiGoalEnv(MujocoEnv, EzPickle): 10 | """ 11 | Move a 2D point mass to one of the goal positions. Cost is the distance to 12 | the closest goal. 13 | 14 | State: position. 15 | Action: velocity. 16 | """ 17 | def __init__(self, 18 | goal_reward=10, 19 | actuation_cost_coeff=30.0, 20 | distance_cost_coeff=1.0, 21 | init_sigma=0.1): 22 | EzPickle.__init__(**locals()) 23 | 24 | self.dynamics = PointDynamics(dim=2, sigma=0) 25 | self.init_mu = np.zeros(2, dtype=np.float32) 26 | self.init_sigma = init_sigma 27 | self.goal_positions = np.array( 28 | ( 29 | (5, 0), 30 | (-5, 0), 31 | (0, 5), 32 | (0, -5) 33 | ), 34 | dtype=np.float32) 35 | self.goal_threshold = 1.0 36 | self.goal_reward = goal_reward 37 | self.action_cost_coeff = actuation_cost_coeff 38 | self.distance_cost_coeff = distance_cost_coeff 39 | self.xlim = (-7, 7) 40 | self.ylim = (-7, 7) 41 | self.vel_bound = 1. 
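# Reward sketch (see compute_reward and step below): with the default
# coefficients above, reward = -(30.0 * ||action||^2 + 1.0 * min_g ||pos - g||^2),
# and goal_reward (10) is added once the distance to the nearest goal falls
# below goal_threshold (1.0). Worked example: at pos = (4, 0) with
# action = (0.1, 0), the nearest goal is (5, 0), so
# reward = -(30.0 * 0.01 + 1.0 * 1.0) = -1.3.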
42 | self.reset() 43 | self.observation = None 44 | 45 | self._ax = None 46 | self._env_lines = [] 47 | self.fixed_plots = None 48 | self.dynamic_plots = [] 49 | 50 | def reset(self): 51 | unclipped_observation = ( 52 | self.init_mu 53 | + self.init_sigma 54 | * np.random.normal(size=self.dynamics.s_dim)) 55 | self.observation = np.clip( 56 | unclipped_observation, 57 | self.observation_space.low, 58 | self.observation_space.high) 59 | return self.observation 60 | 61 | @property 62 | def observation_space(self): 63 | return spaces.Box( 64 | low=np.array((self.xlim[0], self.ylim[0])), 65 | high=np.array((self.xlim[1], self.ylim[1])), 66 | dtype=np.float32, 67 | shape=None) 68 | 69 | @property 70 | def action_space(self): 71 | return spaces.Box( 72 | low=-self.vel_bound, 73 | high=self.vel_bound, 74 | shape=(self.dynamics.a_dim, ), 75 | dtype=np.float32) 76 | 77 | def get_current_obs(self): 78 | return np.copy(self.observation) 79 | 80 | def step(self, action): 81 | action = action.ravel() 82 | 83 | action = np.clip( 84 | action, 85 | self.action_space.low, 86 | self.action_space.high).ravel() 87 | 88 | observation = self.dynamics.forward(self.observation, action) 89 | observation = np.clip( 90 | observation, 91 | self.observation_space.low, 92 | self.observation_space.high) 93 | 94 | reward = self.compute_reward(observation, action) 95 | dist_to_goal = np.amin([ 96 | np.linalg.norm(observation - goal_position) 97 | for goal_position in self.goal_positions 98 | ]) 99 | done = dist_to_goal < self.goal_threshold 100 | if done: 101 | reward += self.goal_reward 102 | 103 | self.observation = np.copy(observation) 104 | 105 | return observation, reward, done, {'pos': observation} 106 | 107 | def _init_plot(self): 108 | fig_env = plt.figure(figsize=(7, 7)) 109 | self._ax = fig_env.add_subplot(111) 110 | self._ax.axis('equal') 111 | 112 | self._env_lines = [] 113 | self._ax.set_xlim((-7, 7)) 114 | self._ax.set_ylim((-7, 7)) 115 | 116 | self._ax.set_title('Multigoal Environment') 117 | self._ax.set_xlabel('x') 118 | self._ax.set_ylabel('y') 119 | 120 | self._plot_position_cost(self._ax) 121 | 122 | def render_rollouts(self, paths=()): 123 | """Render for rendering the past rollouts of the environment.""" 124 | if self._ax is None: 125 | self._init_plot() 126 | 127 | # noinspection PyArgumentList 128 | [line.remove() for line in self._env_lines] 129 | self._env_lines = [] 130 | 131 | for path in paths: 132 | positions = np.stack([info['pos'] for info in path['infos']]) 133 | xx = positions[:, 0] 134 | yy = positions[:, 1] 135 | self._env_lines += self._ax.plot(xx, yy, 'b') 136 | 137 | plt.draw() 138 | plt.pause(0.01) 139 | 140 | def render(self, mode='human'): 141 | """Render for rendering the current state of the environment.""" 142 | pass 143 | 144 | def compute_reward(self, observation, action): 145 | # penalize the L2 norm of acceleration 146 | # noinspection PyTypeChecker 147 | action_cost = np.sum(action ** 2) * self.action_cost_coeff 148 | 149 | # penalize squared dist to goal 150 | cur_position = observation 151 | # noinspection PyTypeChecker 152 | goal_cost = self.distance_cost_coeff * np.amin([ 153 | np.sum((cur_position - goal_position) ** 2) 154 | for goal_position in self.goal_positions 155 | ]) 156 | 157 | # penalize staying with the log barriers 158 | costs = [action_cost, goal_cost] 159 | reward = -np.sum(costs) 160 | return reward 161 | 162 | def _plot_position_cost(self, ax): 163 | delta = 0.01 164 | x_min, x_max = tuple(1.1 * np.array(self.xlim)) 165 | y_min, y_max = tuple(1.1 * 
np.array(self.ylim)) 166 | X, Y = np.meshgrid( 167 | np.arange(x_min, x_max, delta), 168 | np.arange(y_min, y_max, delta) 169 | ) 170 | goal_costs = np.amin([ 171 | (X - goal_x) ** 2 + (Y - goal_y) ** 2 172 | for goal_x, goal_y in self.goal_positions 173 | ], axis=0) 174 | costs = goal_costs 175 | 176 | contours = ax.contour(X, Y, costs, 20) 177 | ax.clabel(contours, inline=1, fontsize=10, fmt='%.0f') 178 | ax.set_xlim([x_min, x_max]) 179 | ax.set_ylim([y_min, y_max]) 180 | goal = ax.plot(self.goal_positions[:, 0], 181 | self.goal_positions[:, 1], 'ro') 182 | return [contours, goal] 183 | 184 | 185 | class PointDynamics(object): 186 | """ 187 | State: position. 188 | Action: velocity. 189 | """ 190 | def __init__(self, dim, sigma): 191 | self.dim = dim 192 | self.sigma = sigma 193 | self.s_dim = dim 194 | self.a_dim = dim 195 | 196 | def forward(self, state, action): 197 | mu_next = state + action 198 | state_next = mu_next + self.sigma * \ 199 | np.random.normal(size=self.s_dim) 200 | return state_next 201 | -------------------------------------------------------------------------------- /examples/development/base.py: -------------------------------------------------------------------------------- 1 | from ray import tune 2 | import numpy as np 3 | import pdb 4 | 5 | from softlearning.misc.utils import get_git_rev, deep_update 6 | 7 | M = 256 8 | REPARAMETERIZE = True 9 | 10 | NUM_COUPLING_LAYERS = 2 11 | 12 | GAUSSIAN_POLICY_PARAMS_BASE = { 13 | 'type': 'GaussianPolicy', 14 | 'kwargs': { 15 | 'hidden_layer_sizes': (M, M), 16 | 'squash': True, 17 | } 18 | } 19 | 20 | GAUSSIAN_POLICY_PARAMS_FOR_DOMAIN = {} 21 | 22 | POLICY_PARAMS_BASE = { 23 | 'GaussianPolicy': GAUSSIAN_POLICY_PARAMS_BASE, 24 | } 25 | 26 | POLICY_PARAMS_BASE.update({ 27 | 'gaussian': POLICY_PARAMS_BASE['GaussianPolicy'], 28 | }) 29 | 30 | POLICY_PARAMS_FOR_DOMAIN = { 31 | 'GaussianPolicy': GAUSSIAN_POLICY_PARAMS_FOR_DOMAIN, 32 | } 33 | 34 | POLICY_PARAMS_FOR_DOMAIN.update({ 35 | 'gaussian': POLICY_PARAMS_FOR_DOMAIN['GaussianPolicy'], 36 | }) 37 | 38 | DEFAULT_MAX_PATH_LENGTH = 1000 39 | MAX_PATH_LENGTH_PER_DOMAIN = { 40 | 'Point2DEnv': 50, 41 | 'Pendulum': 200, 42 | } 43 | 44 | ALGORITHM_PARAMS_ADDITIONAL = { 45 | 'MBPO': { 46 | 'type': 'MBPO', 47 | 'kwargs': { 48 | 'reparameterize': REPARAMETERIZE, 49 | 'lr': 3e-4, 50 | 'target_update_interval': 1, 51 | 'tau': 5e-3, 52 | 'store_extra_policy_info': False, 53 | 'action_prior': 'uniform', 54 | 'n_initial_exploration_steps': int(5000), 55 | } 56 | }, 57 | 'SQL': { 58 | 'type': 'SQL', 59 | 'kwargs': { 60 | 'policy_lr': 3e-4, 61 | 'target_update_interval': 1, 62 | 'n_initial_exploration_steps': int(1e3), 63 | 'reward_scale': tune.sample_from(lambda spec: ( 64 | { 65 | 'Swimmer': 30, 66 | 'Hopper': 30, 67 | 'HalfCheetah': 30, 68 | 'Walker2d': 10, 69 | 'Ant': 300, 70 | 'Humanoid': 100, 71 | 'Pendulum': 1, 72 | }.get( 73 | spec.get('config', spec) 74 | ['environment_params'] 75 | ['training'] 76 | ['domain'], 77 | 1.0 78 | ), 79 | )), 80 | } 81 | }, 82 | 'MVE': { 83 | 'type': 'MVE', 84 | 'kwargs': { 85 | 'reparameterize': REPARAMETERIZE, 86 | 'lr': 3e-4, 87 | 'target_update_interval': 1, 88 | 'tau': 5e-3, 89 | 'target_entropy': 'auto', 90 | 'store_extra_policy_info': False, 91 | 'action_prior': 'uniform', 92 | 'n_initial_exploration_steps': int(5000), 93 | } 94 | }, 95 | } 96 | 97 | DEFAULT_NUM_EPOCHS = 200 98 | 99 | NUM_EPOCHS_PER_DOMAIN = { 100 | 'Hopper': int(1e3), 101 | 'HalfCheetah': int(3e3), 102 | 'Walker2d': int(3e3), 103 | 'Ant': int(3e3), 104 | 'Humanoid': int(1e4), 105 
| 'Pendulum': 10, 106 | } 107 | 108 | ALGORITHM_PARAMS_PER_DOMAIN = { 109 | **{ 110 | domain: { 111 | 'kwargs': { 112 | 'n_epochs': NUM_EPOCHS_PER_DOMAIN.get( 113 | domain, DEFAULT_NUM_EPOCHS), 114 | 'n_initial_exploration_steps': ( 115 | MAX_PATH_LENGTH_PER_DOMAIN.get( 116 | domain, DEFAULT_MAX_PATH_LENGTH 117 | ) * 10), 118 | } 119 | } for domain in NUM_EPOCHS_PER_DOMAIN 120 | } 121 | } 122 | 123 | ENVIRONMENT_PARAMS = { 124 | } 125 | 126 | NUM_CHECKPOINTS = 10 127 | 128 | 129 | def get_variant_spec_base(universe, domain, task, policy, algorithm, env_params): 130 | algorithm_params = deep_update( 131 | ALGORITHM_PARAMS_PER_DOMAIN.get(domain, {}), 132 | ALGORITHM_PARAMS_ADDITIONAL.get(algorithm, {}) 133 | ) 134 | algorithm_params = deep_update( 135 | algorithm_params, 136 | env_params 137 | ) 138 | 139 | variant_spec = { 140 | 'git_sha': get_git_rev(), 141 | 142 | 'environment_params': { 143 | 'training': { 144 | 'domain': domain, 145 | 'task': task, 146 | 'universe': universe, 147 | 'kwargs': ( 148 | ENVIRONMENT_PARAMS.get(domain, {}).get(task, {})), 149 | }, 150 | 'evaluation': tune.sample_from(lambda spec: ( 151 | spec.get('config', spec) 152 | ['environment_params'] 153 | ['training'] 154 | )), 155 | }, 156 | 'policy_params': deep_update( 157 | POLICY_PARAMS_BASE[policy], 158 | POLICY_PARAMS_FOR_DOMAIN[policy].get(domain, {}) 159 | ), 160 | 'Q_params': { 161 | 'type': 'double_feedforward_Q_function', 162 | 'kwargs': { 163 | 'hidden_layer_sizes': (M, M), 164 | } 165 | }, 166 | 'algorithm_params': algorithm_params, 167 | 'replay_pool_params': { 168 | 'type': 'SimpleReplayPool', 169 | 'kwargs': { 170 | 'max_size': tune.sample_from(lambda spec: ( 171 | { 172 | 'SimpleReplayPool': int(1e6), 173 | 'TrajectoryReplayPool': int(1e4), 174 | }.get( 175 | spec.get('config', spec) 176 | ['replay_pool_params'] 177 | ['type'], 178 | int(1e6)) 179 | )), 180 | } 181 | }, 182 | 'sampler_params': { 183 | 'type': 'SimpleSampler', 184 | 'kwargs': { 185 | 'max_path_length': MAX_PATH_LENGTH_PER_DOMAIN.get( 186 | domain, DEFAULT_MAX_PATH_LENGTH), 187 | 'min_pool_size': MAX_PATH_LENGTH_PER_DOMAIN.get( 188 | domain, DEFAULT_MAX_PATH_LENGTH), 189 | 'batch_size': 256, 190 | } 191 | }, 192 | 'run_params': { 193 | 'seed': tune.sample_from( 194 | lambda spec: np.random.randint(0, 10000)), 195 | 'checkpoint_at_end': True, 196 | 'checkpoint_frequency': NUM_EPOCHS_PER_DOMAIN.get( 197 | domain, DEFAULT_NUM_EPOCHS) // NUM_CHECKPOINTS, 198 | 'checkpoint_replay_pool': False, 199 | }, 200 | } 201 | 202 | return variant_spec 203 | 204 | def get_variant_spec(args, env_params): 205 | universe, domain, task = env_params.universe, env_params.domain, env_params.task 206 | 207 | variant_spec = get_variant_spec_base( 208 | universe, domain, task, args.policy, env_params.type, env_params) 209 | 210 | if args.checkpoint_replay_pool is not None: 211 | variant_spec['run_params']['checkpoint_replay_pool'] = ( 212 | args.checkpoint_replay_pool) 213 | 214 | return variant_spec 215 | -------------------------------------------------------------------------------- /softlearning/scripts/console_scripts.py: -------------------------------------------------------------------------------- 1 | """A command line interface that exposes softlearning examples to user. 2 | 3 | This package exposes the functions in examples.instrument module to the user 4 | through a cli, which allows seamless runs of examples in different modes (e.g. 5 | locally, in google compute engine, or ec2). 
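For instance, assuming setup.py wires this CLI up under the `mbpo` console
script (the config module below is one of the example configs shipped in
this repository), a local run looks something like:

    mbpo run_local examples.development --config=examples.config.halfcheetah.0 --gpus=1 --trial-gpus=1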
6 | 7 | 8 | There are two types of cli commands in this file (each have their corresponding 9 | function in examples.instrument): 10 | 1. run_example_* methods, which run the experiments by invoking 11 | `tune.run_experiments` function. 12 | 2. launch_example_* methods, which are helpers function to submit an 13 | example to be run in the cloud. In practice, these launch a cluster, 14 | and then run the `run_example_cluster` method with the provided 15 | arguments and options. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import logging 23 | 24 | import click 25 | 26 | from examples.instrument import ( 27 | run_example_dry, 28 | run_example_local, 29 | run_example_debug, 30 | run_example_cluster, 31 | launch_example_cluster, 32 | launch_example_gce, 33 | launch_example_ec2) 34 | 35 | 36 | logging.basicConfig(level=logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | logger.setLevel(logging.INFO) 39 | 40 | 41 | def add_options(options): 42 | def decorator(f): 43 | for option in options[::-1]: 44 | click.decorators._param_memo(f, option) 45 | return f 46 | return decorator 47 | 48 | 49 | @click.group() 50 | def cli(): 51 | pass 52 | 53 | 54 | @cli.command( 55 | name='run_example_dry', 56 | context_settings={'ignore_unknown_options': True}) 57 | @click.argument("example_module_name", required=True, type=str) 58 | @click.argument('example_argv', nargs=-1, type=click.UNPROCESSED) 59 | def run_example_dry_cmd(example_module_name, example_argv): 60 | """Print the variant spec and related information of an example.""" 61 | return run_example_dry(example_module_name, example_argv) 62 | 63 | 64 | @cli.command( 65 | name='run_local', 66 | context_settings={'ignore_unknown_options': True}) 67 | @click.argument("example_module_name", required=True, type=str) 68 | @click.argument('example_argv', nargs=-1, type=click.UNPROCESSED) 69 | def run_example_local_cmd(example_module_name, example_argv): 70 | """Run example locally, potentially parallelizing across cpus/gpus.""" 71 | return run_example_local(example_module_name, example_argv) 72 | 73 | 74 | @cli.command( 75 | name='run_example_debug', 76 | context_settings={'ignore_unknown_options': True}) 77 | @click.argument("example_module_name", required=True, type=str) 78 | @click.argument('example_argv', nargs=-1, type=click.UNPROCESSED) 79 | def run_example_debug_cmd(example_module_name, example_argv): 80 | """The debug mode limits tune trial runs to enable use of debugger.""" 81 | return run_example_debug(example_module_name, example_argv) 82 | 83 | 84 | @cli.command( 85 | name='run_example_cluster', 86 | context_settings={'ignore_unknown_options': True}) 87 | @click.argument("example_module_name", required=True, type=str) 88 | @click.argument('example_argv', nargs=-1, type=click.UNPROCESSED) 89 | def run_example_cluster_cmd(example_module_name, example_argv): 90 | """Run example on cluster mode. 91 | 92 | This functions is very similar to the local mode, except that it 93 | correctly sets the redis address to make ray/tune work on a cluster. 
94 | """ 95 | run_example_cluster(example_module_name, example_argv) 96 | 97 | 98 | @cli.command( 99 | name='launch_example_cluster', 100 | context_settings={ 101 | 'allow_extra_args': True, 102 | 'ignore_unknown_options': True 103 | }) 104 | @click.argument("example_module_name", required=True, type=str) 105 | @click.argument('example_argv', nargs=-1, type=click.UNPROCESSED) 106 | @click.option( 107 | "--config_file", 108 | required=False, 109 | type=str) 110 | @click.option( 111 | "--stop/--no-stop", 112 | is_flag=True, 113 | default=True, 114 | help="Stop the cluster after the command finishes running.") 115 | @click.option( 116 | "--start/--no-start", 117 | is_flag=True, 118 | default=True, 119 | help="Start the cluster if needed.") 120 | @click.option( 121 | "--screen/--no-screen", 122 | is_flag=True, 123 | default=False, 124 | help="Run the command in a screen.") 125 | @click.option( 126 | "--tmux/--no-tmux", 127 | is_flag=True, 128 | default=True, 129 | help="Run the command in tmux.") 130 | @click.option( 131 | "--override-cluster-name", 132 | required=False, 133 | type=str, 134 | help="Override the configured cluster name.") 135 | @click.option( 136 | "--port-forward", required=False, type=int, help="Port to forward.") 137 | def launch_example_cluster_cmd(*args, **kwargs): 138 | """Launches the example on autoscaled ray cluster through ray exec_cmd. 139 | 140 | This handles basic validation and sanity checks for the experiment, and 141 | then executes the command on autoscaled ray cluster. If necessary, it will 142 | also fill in more useful defaults for our workflow (i.e. for tmux and 143 | override_cluster_name). 144 | """ 145 | return launch_example_cluster(*args, **kwargs) 146 | 147 | 148 | @cli.command( 149 | name='launch_example_gce', 150 | context_settings={ 151 | 'allow_extra_args': True, 152 | 'ignore_unknown_options': True 153 | }) 154 | @add_options(launch_example_cluster_cmd.params) 155 | def launch_example_gce_cmd(*args, **kwargs): 156 | """Forwards call to `launch_example_cluster` after adding gce defaults. 157 | 158 | This optionally sets the ray autoscaler configuration file to the default 159 | gce configuration file, and then calls `launch_example_cluster` to 160 | execute the original command on autoscaled gce cluster by parsing the args. 161 | 162 | See `launch_example_cluster` for further details. 163 | """ 164 | return launch_example_gce(*args, **kwargs) 165 | 166 | 167 | @cli.command( 168 | name='launch_example_ec2', 169 | context_settings={ 170 | 'allow_extra_args': True, 171 | 'ignore_unknown_options': True 172 | }) 173 | @add_options(launch_example_cluster_cmd.params) 174 | def launch_example_ec2_cmd(*args, **kwargs): 175 | """Forwards call to `launch_example_cluster` after adding ec2 defaults. 176 | 177 | This optionally sets the ray autoscaler configuration file to the default 178 | ec2 configuration file, and then calls `launch_example_cluster` to 179 | execute the original command on autoscaled ec2 cluster by parsing the args. 180 | 181 | See `launch_example_cluster` for further details. 
182 | """ 183 | return launch_example_ec2(*args, **kwargs) 184 | 185 | 186 | cli.add_command(run_example_local_cmd) 187 | cli.add_command(run_example_dry_cmd) 188 | cli.add_command(run_example_cluster_cmd) 189 | 190 | # Alias for run_example_local 191 | cli.add_command(run_example_local_cmd, name='launch_example_local') 192 | # Alias for run_example_dry 193 | cli.add_command(run_example_dry_cmd, name='launch_example_dry') 194 | # Alias for run_example_debug 195 | cli.add_command(run_example_debug_cmd, name='launch_example_debug') 196 | cli.add_command(launch_example_cluster_cmd) 197 | cli.add_command(launch_example_gce_cmd) 198 | cli.add_command(launch_example_ec2_cmd) 199 | 200 | 201 | def main(): 202 | return cli() 203 | 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /softlearning/environments/adapters/softlearning_env.py: -------------------------------------------------------------------------------- 1 | """Implements the SoftlearningEnv that is usable in softlearning algorithms.""" 2 | 3 | from abc import ABCMeta, abstractmethod 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | from serializable import Serializable 8 | 9 | 10 | class SoftlearningEnv(Serializable, metaclass=ABCMeta): 11 | """The abstract Softlearning environment class. 12 | 13 | It's an abstract class defining the interface an adapter needs to implement 14 | in order to function with softlearning algorithms. It closely follows the 15 | gym.Env, yet that may not be the case in the future. 16 | 17 | The main API methods that users of this class need to know are: 18 | 19 | step 20 | reset 21 | render 22 | close 23 | seed 24 | 25 | And set the following attributes: 26 | 27 | action_space: The Space object corresponding to valid actions 28 | observation_space: The Space object corresponding to valid observations 29 | reward_range: A tuple corresponding to the min and max possible rewards 30 | 31 | The methods are accessed publicly as "step", "reset", etc.. The 32 | non-underscored versions are wrapper methods to which we may add 33 | functionality over time. 34 | """ 35 | 36 | # Set this in SOME subclasses 37 | metadata = {'render.modes': []} 38 | reward_range = (-float('inf'), float('inf')) 39 | spec = None 40 | 41 | # Set these in ALL subclasses 42 | action_space = None 43 | observation_space = None 44 | 45 | def __init__(self, domain, task, *args, **kwargs): 46 | """Initialize an environment based on domain and task. 47 | Keyword Arguments: 48 | domain -- 49 | task -- 50 | *args -- 51 | **kwargs -- 52 | """ 53 | self._Serializable__initialize(locals()) 54 | self._domain = domain 55 | self._task = task 56 | 57 | @property 58 | @abstractmethod 59 | def observation_space(self): 60 | raise NotImplementedError 61 | 62 | @property 63 | def active_observation_shape(self): 64 | return self.observation_space.shape 65 | 66 | def convert_to_active_observation(self, observation): 67 | return observation 68 | 69 | @property 70 | @abstractmethod 71 | def action_space(self): 72 | raise NotImplementedError 73 | 74 | @abstractmethod 75 | def step(self, action): 76 | """Run one timestep of the environment's dynamics. When end of 77 | episode is reached, you are responsible for calling `reset()` 78 | to reset this environment's state. 79 | 80 | Accepts an action and returns a tuple (observation, reward, done, info). 
81 | 82 | Args: 83 | action (object): an action provided by the agent 84 | 85 | Returns: 86 | observation (object): agent's observation of the current environment 87 | reward (float): amount of reward returned after previous action 88 | done (boolean): whether the episode has ended, in which case further step() calls will return undefined results 89 | info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning) 90 | """ 91 | raise NotImplementedError 92 | 93 | @abstractmethod 94 | def reset(self): 95 | """Resets the state of the environment and returns an initial observation. 96 | 97 | Returns: observation (object): the initial observation of the 98 | space. 99 | """ 100 | raise NotImplementedError 101 | 102 | @abstractmethod 103 | def render(self, mode='human'): 104 | """Renders the environment. 105 | 106 | The set of supported modes varies per environment. (And some 107 | environments do not support rendering at all.) By convention, 108 | if mode is: 109 | 110 | - human: render to the current display or terminal and 111 | return nothing. Usually for human consumption. 112 | - rgb_array: Return a numpy.ndarray with shape (x, y, 3), 113 | representing RGB values for an x-by-y pixel image, suitable 114 | for turning into a video. 115 | - ansi: Return a string (str) or StringIO.StringIO containing a 116 | terminal-style text representation. The text can include newlines 117 | and ANSI escape sequences (e.g. for colors). 118 | 119 | Note: 120 | Make sure that your class's metadata 'render.modes' key includes 121 | the list of supported modes. It's recommended to call super() 122 | in implementations to use the functionality of this method. 123 | 124 | Args: 125 | mode (str): the mode to render with 126 | close (bool): close all open renderings 127 | 128 | Example: 129 | 130 | class MyEnv(Env): 131 | metadata = {'render.modes': ['human', 'rgb_array']} 132 | 133 | def render(self, mode='human'): 134 | if mode == 'rgb_array': 135 | return np.array(...) # return RGB frame suitable for video 136 | elif mode == 'human': 137 | ... # pop up a window and render 138 | else: 139 | super(MyEnv, self).render(mode=mode) # just raise an exception 140 | """ 141 | raise NotImplementedError 142 | 143 | def render_rollouts(self, paths): 144 | """Renders past rollouts of the environment.""" 145 | if hasattr(self._env, 'render_rollouts'): 146 | return self._env.render_rollouts(paths) 147 | 148 | unwrapped_env = self.unwrapped 149 | if hasattr(unwrapped_env, 'render_rollouts'): 150 | return unwrapped_env.render_rollouts(paths) 151 | 152 | @abstractmethod 153 | def close(self): 154 | """Override close in your subclass to perform any necessary cleanup. 155 | 156 | Environments will automatically close() themselves when 157 | garbage collected or when the program exits. 158 | """ 159 | return 160 | 161 | @abstractmethod 162 | def seed(self, seed=None): 163 | """Sets the seed for this env's random number generator(s). 164 | 165 | Note: 166 | Some environments use multiple pseudorandom number generators. 167 | We want to capture all such seeds used in order to ensure that 168 | there aren't accidental correlations between multiple generators. 169 | 170 | Returns: 171 | list: Returns the list of seeds used in this env's random 172 | number generators. The first value in the list should be the 173 | "main" seed, or the value which a reproducer should pass to 174 | 'seed'. 
Often, the main seed equals the provided 'seed', but 175 | this won't be true if seed=None, for example. 176 | """ 177 | pass 178 | 179 | def copy(self): 180 | """Create a deep copy the environment. 181 | 182 | TODO: Investigate if this can be done somehow else, especially for gym 183 | envs. 184 | """ 185 | return Serializable.clone(self) 186 | 187 | @property 188 | @abstractmethod 189 | def unwrapped(self): 190 | """Completely unwrap this env. 191 | 192 | Returns: 193 | gym.Env: The base non-wrapped gym.Env instance 194 | """ 195 | return self._env 196 | 197 | def __str__(self): 198 | return '<{type_name}(domain={domain}, task={task}) <{env}>>'.format( 199 | type_name=type(self).__name__, 200 | domain=self._domain, 201 | task=self._task, 202 | env=self._env) 203 | 204 | @abstractmethod 205 | def get_param_values(self): 206 | raise NotImplementedError 207 | 208 | @abstractmethod 209 | def set_param_values(self, params): 210 | raise NotImplementedError 211 | 212 | def get_path_infos(self, paths, *args, **kwargs): 213 | """Log some general diagnostics from the env infos. 214 | 215 | TODO(hartikainen): These logs don't make much sense right now. Need to 216 | figure out better format for logging general env infos. 217 | """ 218 | keys = list(paths[0].get('infos', [{}])[0].keys()) 219 | 220 | results = defaultdict(list) 221 | 222 | for path in paths: 223 | path_results = { 224 | k: [ 225 | info[k] 226 | for info in path['infos'] 227 | ] for k in keys 228 | } 229 | for info_key, info_values in path_results.items(): 230 | info_values = np.array(info_values) 231 | results[info_key + '-first'].append(info_values[0]) 232 | results[info_key + '-last'].append(info_values[-1]) 233 | results[info_key + '-mean'].append(np.mean(info_values)) 234 | results[info_key + '-median'].append(np.median(info_values)) 235 | if np.array(info_values).dtype != np.dtype('bool'): 236 | results[info_key + '-range'].append(np.ptp(info_values)) 237 | 238 | aggregated_results = {} 239 | for key, value in results.items(): 240 | aggregated_results[key + '-mean'] = np.mean(value) 241 | 242 | return aggregated_results 243 | -------------------------------------------------------------------------------- /examples/development/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import glob 4 | import pickle 5 | import sys 6 | import pdb 7 | 8 | import tensorflow as tf 9 | from ray import tune 10 | 11 | from softlearning.environments.utils import get_environment_from_params 12 | from softlearning.algorithms.utils import get_algorithm_from_variant 13 | from softlearning.policies.utils import get_policy_from_variant, get_policy 14 | from softlearning.replay_pools.utils import get_replay_pool_from_variant 15 | from softlearning.samplers.utils import get_sampler_from_variant 16 | from softlearning.value_functions.utils import get_Q_function_from_variant 17 | 18 | from softlearning.misc.utils import set_seed, initialize_tf_variables 19 | from examples.instrument import run_example_local 20 | 21 | import mbpo.static 22 | 23 | class ExperimentRunner(tune.Trainable): 24 | def _setup(self, variant): 25 | set_seed(variant['run_params']['seed']) 26 | 27 | self._variant = variant 28 | 29 | gpu_options = tf.GPUOptions(allow_growth=True) 30 | session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 31 | tf.keras.backend.set_session(session) 32 | self._session = tf.keras.backend.get_session() 33 | 34 | self.train_generator = None 35 | self._built = False 36 | 37 
| def _stop(self): 38 | tf.reset_default_graph() 39 | tf.keras.backend.clear_session() 40 | 41 | def _build(self): 42 | variant = copy.deepcopy(self._variant) 43 | 44 | environment_params = variant['environment_params'] 45 | training_environment = self.training_environment = ( 46 | get_environment_from_params(environment_params['training'])) 47 | evaluation_environment = self.evaluation_environment = ( 48 | get_environment_from_params(environment_params['evaluation']) 49 | if 'evaluation' in environment_params 50 | else training_environment) 51 | 52 | replay_pool = self.replay_pool = ( 53 | get_replay_pool_from_variant(variant, training_environment)) 54 | sampler = self.sampler = get_sampler_from_variant(variant) 55 | Qs = self.Qs = get_Q_function_from_variant( 56 | variant, training_environment) 57 | policy = self.policy = get_policy_from_variant( 58 | variant, training_environment, Qs) 59 | initial_exploration_policy = self.initial_exploration_policy = ( 60 | get_policy('UniformPolicy', training_environment)) 61 | 62 | #### get termination function 63 | domain = environment_params['training']['domain'] 64 | static_fns = mbpo.static[domain.lower()] 65 | #### 66 | 67 | self.algorithm = get_algorithm_from_variant( 68 | variant=self._variant, 69 | training_environment=training_environment, 70 | evaluation_environment=evaluation_environment, 71 | policy=policy, 72 | initial_exploration_policy=initial_exploration_policy, 73 | Qs=Qs, 74 | pool=replay_pool, 75 | static_fns=static_fns, 76 | sampler=sampler, 77 | session=self._session) 78 | 79 | initialize_tf_variables(self._session, only_uninitialized=True) 80 | 81 | self._built = True 82 | 83 | def _train(self): 84 | if not self._built: 85 | self._build() 86 | 87 | if self.train_generator is None: 88 | self.train_generator = self.algorithm.train() 89 | 90 | diagnostics = next(self.train_generator) 91 | 92 | return diagnostics 93 | 94 | def _pickle_path(self, checkpoint_dir): 95 | return os.path.join(checkpoint_dir, 'checkpoint.pkl') 96 | 97 | def _replay_pool_pickle_path(self, checkpoint_dir): 98 | return os.path.join(checkpoint_dir, 'replay_pool.pkl') 99 | 100 | def _tf_checkpoint_prefix(self, checkpoint_dir): 101 | return os.path.join(checkpoint_dir, 'checkpoint') 102 | 103 | def _get_tf_checkpoint(self): 104 | tf_checkpoint = tf.train.Checkpoint(**self.algorithm.tf_saveables) 105 | 106 | return tf_checkpoint 107 | 108 | @property 109 | def picklables(self): 110 | return { 111 | 'variant': self._variant, 112 | 'training_environment': self.training_environment, 113 | 'evaluation_environment': self.evaluation_environment, 114 | 'sampler': self.sampler, 115 | 'algorithm': self.algorithm, 116 | 'Qs': self.Qs, 117 | 'policy_weights': self.policy.get_weights(), 118 | } 119 | 120 | def _save(self, checkpoint_dir): 121 | """Implements the checkpoint logic. 122 | 123 | TODO(hartikainen): This implementation is currently very hacky. Things 124 | that need to be fixed: 125 | - Figure out how serialize/save tf.keras.Model subclassing. The 126 | current implementation just dumps the weights in a pickle, which 127 | is not optimal. 128 | - Try to unify all the saving and loading into easily 129 | extendable/maintainable interfaces. Currently we use 130 | `tf.train.Checkpoint` and `pickle.dump` in very unorganized way 131 | which makes things not so usable. 
132 | """ 133 | pickle_path = self._pickle_path(checkpoint_dir) 134 | with open(pickle_path, 'wb') as f: 135 | pickle.dump(self.picklables, f) 136 | 137 | if self._variant['run_params'].get('checkpoint_replay_pool', False): 138 | self._save_replay_pool(checkpoint_dir) 139 | 140 | tf_checkpoint = self._get_tf_checkpoint() 141 | 142 | tf_checkpoint.save( 143 | file_prefix=self._tf_checkpoint_prefix(checkpoint_dir), 144 | session=self._session) 145 | 146 | return os.path.join(checkpoint_dir, '') 147 | 148 | def _save_replay_pool(self, checkpoint_dir): 149 | replay_pool_pickle_path = self._replay_pool_pickle_path( 150 | checkpoint_dir) 151 | self.replay_pool.save_latest_experience(replay_pool_pickle_path) 152 | 153 | def _restore_replay_pool(self, current_checkpoint_dir): 154 | experiment_root = os.path.dirname(current_checkpoint_dir) 155 | 156 | experience_paths = [ 157 | self._replay_pool_pickle_path(checkpoint_dir) 158 | for checkpoint_dir in sorted(glob.iglob( 159 | os.path.join(experiment_root, 'checkpoint_*'))) 160 | ] 161 | 162 | for experience_path in experience_paths: 163 | self.replay_pool.load_experience(experience_path) 164 | 165 | def _restore(self, checkpoint_dir): 166 | assert isinstance(checkpoint_dir, str), checkpoint_dir 167 | 168 | checkpoint_dir = checkpoint_dir.rstrip('/') 169 | 170 | with self._session.as_default(): 171 | pickle_path = self._pickle_path(checkpoint_dir) 172 | with open(pickle_path, 'rb') as f: 173 | picklable = pickle.load(f) 174 | 175 | training_environment = self.training_environment = picklable[ 176 | 'training_environment'] 177 | evaluation_environment = self.evaluation_environment = picklable[ 178 | 'evaluation_environment'] 179 | 180 | replay_pool = self.replay_pool = ( 181 | get_replay_pool_from_variant(self._variant, training_environment)) 182 | 183 | if self._variant['run_params'].get('checkpoint_replay_pool', False): 184 | self._restore_replay_pool(checkpoint_dir) 185 | 186 | sampler = self.sampler = picklable['sampler'] 187 | Qs = self.Qs = picklable['Qs'] 188 | # policy = self.policy = picklable['policy'] 189 | policy = self.policy = ( 190 | get_policy_from_variant(self._variant, training_environment, Qs)) 191 | self.policy.set_weights(picklable['policy_weights']) 192 | initial_exploration_policy = self.initial_exploration_policy = ( 193 | get_policy('UniformPolicy', training_environment)) 194 | 195 | self.algorithm = get_algorithm_from_variant( 196 | variant=self._variant, 197 | training_environment=training_environment, 198 | evaluation_environment=evaluation_environment, 199 | policy=policy, 200 | initial_exploration_policy=initial_exploration_policy, 201 | Qs=Qs, 202 | pool=replay_pool, 203 | sampler=sampler, 204 | session=self._session) 205 | self.algorithm.__setstate__(picklable['algorithm'].__getstate__()) 206 | 207 | tf_checkpoint = self._get_tf_checkpoint() 208 | status = tf_checkpoint.restore(tf.train.latest_checkpoint( 209 | os.path.split(self._tf_checkpoint_prefix(checkpoint_dir))[0])) 210 | 211 | status.assert_consumed().run_restore_ops(self._session) 212 | initialize_tf_variables(self._session, only_uninitialized=True) 213 | 214 | # TODO(hartikainen): target Qs should either be checkpointed or pickled. 215 | for Q, Q_target in zip(self.algorithm._Qs, self.algorithm._Q_targets): 216 | Q_target.set_weights(Q.get_weights()) 217 | 218 | self._built = True 219 | 220 | 221 | def main(argv=None): 222 | """Run ExperimentRunner locally on ray. 223 | 224 | To run this example on cloud (e.g. 
gce/ec2), use the setup scripts: 225 | 'softlearning launch_example_{gce,ec2} examples.development '. 226 | 227 | Run 'softlearning launch_example_{gce,ec2} --help' for further 228 | instructions. 229 | """ 230 | # __package__ should be `development.main` 231 | run_example_local(__package__, argv) 232 | 233 | 234 | if __name__ == '__main__': 235 | main(argv=sys.argv[1:]) 236 | -------------------------------------------------------------------------------- /softlearning/softlearning.md: -------------------------------------------------------------------------------- 1 | # Softlearning 2 | 3 | Softlearning is a deep reinforcement learning toolbox for training maximum entropy policies in continuous domains. The implementation is fairly thin and primarily optimized for our own development purposes. It utilizes the tf.keras modules for most of the model classes (e.g. policies and value functions). We use Ray for the experiment orchestration. Ray Tune and Autoscaler implement several neat features that enable us to seamlessly run the same experiment scripts that we use for local prototyping to launch large-scale experiments on any chosen cloud service (e.g. GCP or AWS), and intelligently parallelize and distribute training for effective resource allocation. 4 | 5 | This implementation uses Tensorflow. For a PyTorch implementation of soft actor-critic, take a look at [rlkit](https://github.com/vitchyr/rlkit). 6 | 7 | # Getting Started 8 | 9 | ## Prerequisites 10 | 11 | The environment can be run either locally using conda or inside a docker container. For conda installation, you need to have [Conda](https://conda.io/docs/user-guide/install/index.html) installed. For docker installation you will need to have [Docker](https://docs.docker.com/engine/installation/) and [Docker Compose](https://docs.docker.com/compose/install/) installed. Also, most of our environments currently require a [MuJoCo](https://www.roboti.us/license.html) license. 12 | 13 | ## Conda Installation 14 | 15 | 1. [Download](https://www.roboti.us/index.html) and install MuJoCo 1.50 from the MuJoCo website. We assume that the MuJoCo files are extracted to the default location (`~/.mujoco/mjpro150`). 16 | 17 | 2. Copy your MuJoCo license key (mjkey.txt) to ~/.mujoco/mjkey.txt: 18 | 19 | 3. Clone `softlearning` 20 | ``` 21 | git clone https://github.com/rail-berkeley/softlearning.git ${SOFTLEARNING_PATH} 22 | ``` 23 | 24 | 4. Create and activate conda environment, install softlearning to enable command line interface. 25 | ``` 26 | cd ${SOFTLEARNING_PATH} 27 | conda env create -f environment.yml 28 | conda activate softlearning 29 | pip install -e ${SOFTLEARNING_PATH} 30 | ``` 31 | 32 | The environment should be ready to run. See examples section for examples of how to train and simulate the agents. 33 | 34 | Finally, to deactivate and remove the conda environment: 35 | ``` 36 | conda deactivate 37 | conda remove --name softlearning --all 38 | ``` 39 | 40 | ## Docker Installation 41 | 42 | ### docker-compose 43 | To build the image and run the container: 44 | ``` 45 | export MJKEY="$(cat ~/.mujoco/mjkey.txt)" \ 46 | && docker-compose \ 47 | -f ./docker/docker-compose.dev.cpu.yml \ 48 | up \ 49 | -d \ 50 | --force-recreate 51 | ``` 52 | 53 | You can access the container with the typical Docker [exec](https://docs.docker.com/engine/reference/commandline/exec/)-command, i.e. 54 | 55 | ``` 56 | docker exec -it softlearning bash 57 | ``` 58 | 59 | See examples section for examples of how to train and simulate the agents. 
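For instance, the training example shown in the Examples section below can also be launched non-interactively through `docker exec`. This is a minimal sketch, assuming the container is named `softlearning` (as in the compose setup above) and that the `softlearning` command-line entry point is installed inside the container, just as the Conda instructions install it locally:
```
docker exec -it softlearning \
    softlearning run_example_local examples.development \
    --universe=gym \
    --domain=HalfCheetah \
    --task=v3 \
    --exp-name=my-sac-experiment-1 \
    --checkpoint-frequency=1000
```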
60 | 61 | Finally, to clean up the docker setup: 62 | ``` 63 | docker-compose \ 64 | -f ./docker/docker-compose.dev.cpu.yml \ 65 | down \ 66 | --rmi all \ 67 | --volumes 68 | ``` 69 | 70 | ## Examples 71 | ### Training and simulating an agent 72 | 1. To train the agent 73 | ``` 74 | softlearning run_example_local examples.development \ 75 | --universe=gym \ 76 | --domain=HalfCheetah \ 77 | --task=v3 \ 78 | --exp-name=my-sac-experiment-1 \ 79 | --checkpoint-frequency=1000 # Save the checkpoint to resume training later 80 | ``` 81 | 82 | 2. To simulate the resulting policy: 83 | First, find the path that the checkpoint is saved to. By default (i.e. without specifying the `log-dir` argument to the previous script), the data is saved under `~/ray_results/<universe>/<domain>/<task>/<datetime>-<exp-name>/<trial-id>/`. For example: `~/ray_results/gym/HalfCheetah/v3/2018-12-12T16-48-37-my-sac-experiment-1-0/mujoco-runner_0_seed=7585_2018-12-12_16-48-37xuadh9vd/checkpoint_1000/`. The next command assumes that this path is found in the `${SAC_CHECKPOINT_DIR}` environment variable. 84 | 85 | ``` 86 | python -m examples.development.simulate_policy \ 87 | ${SAC_CHECKPOINT_DIR} \ 88 | --max-path-length=1000 \ 89 | --num-rollouts=1 \ 90 | --render-mode=human 91 | ``` 92 | 93 | `examples.development.main` contains several different environments and there are more example scripts available in the `/examples` folder. For more information about the agents and configurations, run the scripts with the `--help` flag: `python ./examples/development/main.py --help` 94 | ``` 95 | optional arguments: 96 | -h, --help show this help message and exit 97 | --universe {gym} 98 | --domain {...} 99 | --task {...} 100 | --num-samples NUM_SAMPLES 101 | --resources RESOURCES 102 | Resources to allocate to ray process. Passed to 103 | `ray.init`. 104 | --cpus CPUS Cpus to allocate to ray process. Passed to `ray.init`. 105 | --gpus GPUS Gpus to allocate to ray process. Passed to `ray.init`. 106 | --trial-resources TRIAL_RESOURCES 107 | Resources to allocate for each trial. Passed to 108 | `tune.run_experiments`. 109 | --trial-cpus TRIAL_CPUS 110 | Resources to allocate for each trial. Passed to 111 | `tune.run_experiments`. 112 | --trial-gpus TRIAL_GPUS 113 | Resources to allocate for each trial. Passed to 114 | `tune.run_experiments`. 115 | --trial-extra-cpus TRIAL_EXTRA_CPUS 116 | Extra CPUs to reserve in case the trials need to 117 | launch additional Ray actors that use CPUs. 118 | --trial-extra-gpus TRIAL_EXTRA_GPUS 119 | Extra GPUs to reserve in case the trials need to 120 | launch additional Ray actors that use GPUs. 121 | --checkpoint-frequency CHECKPOINT_FREQUENCY 122 | Save the training checkpoint every this many epochs. 123 | If set, takes precedence over 124 | variant['run_params']['checkpoint_frequency']. 125 | --checkpoint-at-end CHECKPOINT_AT_END 126 | Whether a checkpoint should be saved at the end of 127 | training. If set, takes precedence over 128 | variant['run_params']['checkpoint_at_end']. 129 | --restore RESTORE Path to checkpoint. Only makes sense to set if running 130 | 1 trial. Defaults to None. 131 | --policy {gaussian} 132 | --env ENV 133 | --exp-name EXP_NAME 134 | --log-dir LOG_DIR 135 | --upload-dir UPLOAD_DIR 136 | Optional URI to sync training results to (e.g. 137 | s3:// or gs://). 138 | --confirm-remote [CONFIRM_REMOTE] 139 | Whether or not to query yes/no on remote run. 
140 | ``` 141 | 142 | ### Resume training from a saved checkpoint 143 | In order to resume training from previous checkpoint, run the original example main-script, with an additional `--restore` flag. For example, the previous example can be resumed as follows: 144 | 145 | ``` 146 | softlearning run_example_local examples.development \ 147 | --universe=gym \ 148 | --domain=HalfCheetah \ 149 | --task=v3 \ 150 | --exp-name=my-sac-experiment-1 \ 151 | --checkpoint-frequency=1000 \ 152 | --restore=${SAC_CHECKPOINT_PATH} 153 | ``` 154 | 155 | # References 156 | The algorithms are based on the following papers: 157 | 158 | *Soft Actor-Critic Algorithms and Applications*.
159 | Tuomas Haarnoja*, Aurick Zhou*, Kristian Hartikainen*, George Tucker, Sehoon Ha, Jie Tan, Vikash Kumar, Henry Zhu, Abhishek Gupta, Pieter Abbeel, and Sergey Levine. 160 | arXiv preprint, 2018.
161 | [paper](https://arxiv.org/abs/1812.05905) | [videos](https://sites.google.com/view/sac-and-applications) 162 | 163 | *Latent Space Policies for Hierarchical Reinforcement Learning*.
164 | Tuomas Haarnoja*, Kristian Hartikainen*, Pieter Abbeel, and Sergey Levine. 165 | International Conference on Machine Learning (ICML), 2018.
166 | [paper](https://arxiv.org/abs/1804.02808) | [videos](https://sites.google.com/view/latent-space-deep-rl) 167 | 168 | *Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor*.
169 | Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. 170 | International Conference on Machine Learning (ICML), 2018.
171 | [paper](https://arxiv.org/abs/1801.01290) | [videos](https://sites.google.com/view/soft-actor-critic) 172 | 173 | *Composable Deep Reinforcement Learning for Robotic Manipulation*.
174 | Tuomas Haarnoja, Vitchyr Pong, Aurick Zhou, Murtaza Dalal, Pieter Abbeel, Sergey Levine. 175 | International Conference on Robotics and Automation (ICRA), 2018.
176 | [paper](https://arxiv.org/abs/1803.06773) | [videos](https://sites.google.com/view/composing-real-world-policies) 177 | 178 | *Reinforcement Learning with Deep Energy-Based Policies*.
179 | Tuomas Haarnoja*, Haoran Tang*, Pieter Abbeel, Sergey Levine. 180 | International Conference on Machine Learning (ICML), 2017.
181 | [paper](https://arxiv.org/abs/1702.08165) | [videos](https://sites.google.com/view/softqlearning/home) 182 | 183 | If Softlearning helps you in your academic research, you are encouraged to cite our paper. Here is an example bibtex: 184 | ``` 185 | @techreport{haarnoja2018sacapps, 186 | title={Soft Actor-Critic Algorithms and Applications}, 187 | author={Tuomas Haarnoja, Aurick Zhou, Kristian Hartikainen, George Tucker, Sehoon Ha, Jie Tan, Vikash Kumar, Henry Zhu, Abhishek Gupta, Pieter Abbeel, and Sergey Levine}, 188 | journal={arXiv preprint arXiv:1812.05905}, 189 | year={2018} 190 | } 191 | ``` 192 | -------------------------------------------------------------------------------- /softlearning/policies/gaussian_policy.py: -------------------------------------------------------------------------------- 1 | """GaussianPolicy.""" 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow_probability as tfp 8 | from softlearning.distributions.squash_bijector import SquashBijector 9 | from softlearning.models.feedforward import feedforward_model 10 | 11 | from .base_policy import LatentSpacePolicy 12 | 13 | 14 | SCALE_DIAG_MIN_MAX = (-20, 2) 15 | 16 | 17 | class GaussianPolicy(LatentSpacePolicy): 18 | def __init__(self, 19 | input_shapes, 20 | output_shape, 21 | squash=True, 22 | preprocessor=None, 23 | name=None, 24 | *args, 25 | **kwargs): 26 | self._Serializable__initialize(locals()) 27 | 28 | self._input_shapes = input_shapes 29 | self._output_shape = output_shape 30 | self._squash = squash 31 | self._name = name 32 | self._preprocessor = preprocessor 33 | 34 | super(GaussianPolicy, self).__init__(*args, **kwargs) 35 | 36 | self.condition_inputs = [ 37 | tf.keras.layers.Input(shape=input_shape) 38 | for input_shape in input_shapes 39 | ] 40 | 41 | conditions = tf.keras.layers.Lambda( 42 | lambda x: tf.concat(x, axis=-1) 43 | )(self.condition_inputs) 44 | 45 | if preprocessor is not None: 46 | conditions = preprocessor(conditions) 47 | 48 | shift_and_log_scale_diag = self._shift_and_log_scale_diag_net( 49 | input_shapes=(conditions.shape[1:], ), 50 | output_size=output_shape[0] * 2, 51 | )(conditions) 52 | 53 | shift, log_scale_diag = tf.keras.layers.Lambda( 54 | lambda shift_and_log_scale_diag: tf.split( 55 | shift_and_log_scale_diag, 56 | num_or_size_splits=2, 57 | axis=-1) 58 | )(shift_and_log_scale_diag) 59 | 60 | log_scale_diag = tf.keras.layers.Lambda( 61 | lambda log_scale_diag: tf.clip_by_value( 62 | log_scale_diag, *SCALE_DIAG_MIN_MAX) 63 | )(log_scale_diag) 64 | 65 | batch_size = tf.keras.layers.Lambda( 66 | lambda x: tf.shape(x)[0])(conditions) 67 | 68 | base_distribution = tfp.distributions.MultivariateNormalDiag( 69 | loc=tf.zeros(output_shape), 70 | scale_diag=tf.ones(output_shape)) 71 | 72 | latents = tf.keras.layers.Lambda( 73 | lambda batch_size: base_distribution.sample(batch_size) 74 | )(batch_size) 75 | 76 | self.latents_model = tf.keras.Model(self.condition_inputs, latents) 77 | self.latents_input = tf.keras.layers.Input(shape=output_shape) 78 | 79 | def raw_actions_fn(inputs): 80 | shift, log_scale_diag, latents = inputs 81 | bijector = tfp.bijectors.Affine( 82 | shift=shift, 83 | scale_diag=tf.exp(log_scale_diag)) 84 | actions = bijector.forward(latents) 85 | return actions 86 | 87 | raw_actions = tf.keras.layers.Lambda( 88 | raw_actions_fn 89 | )((shift, log_scale_diag, latents)) 90 | 91 | raw_actions_for_fixed_latents = tf.keras.layers.Lambda( 92 | raw_actions_fn 93 | )((shift, log_scale_diag, 
self.latents_input)) 94 | 95 | squash_bijector = ( 96 | SquashBijector() 97 | if self._squash 98 | else tfp.bijectors.Identity()) 99 | 100 | actions = tf.keras.layers.Lambda( 101 | lambda raw_actions: squash_bijector.forward(raw_actions) 102 | )(raw_actions) 103 | self.actions_model = tf.keras.Model(self.condition_inputs, actions) 104 | 105 | actions_for_fixed_latents = tf.keras.layers.Lambda( 106 | lambda raw_actions: squash_bijector.forward(raw_actions) 107 | )(raw_actions_for_fixed_latents) 108 | self.actions_model_for_fixed_latents = tf.keras.Model( 109 | (*self.condition_inputs, self.latents_input), 110 | actions_for_fixed_latents) 111 | 112 | deterministic_actions = tf.keras.layers.Lambda( 113 | lambda shift: squash_bijector.forward(shift) 114 | )(shift) 115 | 116 | self.deterministic_actions_model = tf.keras.Model( 117 | self.condition_inputs, deterministic_actions) 118 | 119 | def log_pis_fn(inputs): 120 | shift, log_scale_diag, actions = inputs 121 | base_distribution = tfp.distributions.MultivariateNormalDiag( 122 | loc=tf.zeros(output_shape), 123 | scale_diag=tf.ones(output_shape)) 124 | bijector = tfp.bijectors.Chain(( 125 | squash_bijector, 126 | tfp.bijectors.Affine( 127 | shift=shift, 128 | scale_diag=tf.exp(log_scale_diag)), 129 | )) 130 | distribution = ( 131 | tfp.distributions.ConditionalTransformedDistribution( 132 | distribution=base_distribution, 133 | bijector=bijector)) 134 | 135 | log_pis = distribution.log_prob(actions)[:, None] 136 | return log_pis 137 | 138 | self.actions_input = tf.keras.layers.Input(shape=output_shape) 139 | 140 | log_pis = tf.keras.layers.Lambda( 141 | log_pis_fn)([shift, log_scale_diag, actions]) 142 | 143 | log_pis_for_action_input = tf.keras.layers.Lambda( 144 | log_pis_fn)([shift, log_scale_diag, self.actions_input]) 145 | 146 | self.log_pis_model = tf.keras.Model( 147 | (*self.condition_inputs, self.actions_input), 148 | log_pis_for_action_input) 149 | 150 | self.diagnostics_model = tf.keras.Model( 151 | self.condition_inputs, 152 | (shift, log_scale_diag, log_pis, raw_actions, actions)) 153 | 154 | def _shift_and_log_scale_diag_net(self, input_shapes, output_size): 155 | raise NotImplementedError 156 | 157 | def get_weights(self): 158 | return self.actions_model.get_weights() 159 | 160 | def set_weights(self, *args, **kwargs): 161 | return self.actions_model.set_weights(*args, **kwargs) 162 | 163 | @property 164 | def trainable_variables(self): 165 | return self.actions_model.trainable_variables 166 | 167 | @property 168 | def non_trainable_weights(self): 169 | """Due to our nested model structure, we need to filter duplicates.""" 170 | return list(set(super(GaussianPolicy, self).non_trainable_weights)) 171 | 172 | def actions(self, conditions): 173 | if self._deterministic: 174 | return self.deterministic_actions_model(conditions) 175 | 176 | return self.actions_model(conditions) 177 | 178 | def log_pis(self, conditions, actions): 179 | assert not self._deterministic, self._deterministic 180 | return self.log_pis_model([*conditions, actions]) 181 | 182 | def actions_np(self, conditions): 183 | return super(GaussianPolicy, self).actions_np(conditions) 184 | 185 | def log_pis_np(self, conditions, actions): 186 | assert not self._deterministic, self._deterministic 187 | return self.log_pis_model.predict([*conditions, actions]) 188 | 189 | def get_diagnostics(self, conditions): 190 | """Return diagnostic information of the policy. 191 | 192 | Returns the mean, min, max, and standard deviation of means and 193 | covariances. 
194 | """ 195 | (shifts_np, 196 | log_scale_diags_np, 197 | log_pis_np, 198 | raw_actions_np, 199 | actions_np) = self.diagnostics_model.predict(conditions) 200 | 201 | return OrderedDict({ 202 | 'shifts-mean': np.mean(shifts_np), 203 | 'shifts-std': np.std(shifts_np), 204 | 205 | 'log_scale_diags-mean': np.mean(log_scale_diags_np), 206 | 'log_scale_diags-std': np.std(log_scale_diags_np), 207 | 208 | '-log-pis-mean': np.mean(-log_pis_np), 209 | '-log-pis-std': np.std(-log_pis_np), 210 | 211 | 'raw-actions-mean': np.mean(raw_actions_np), 212 | 'raw-actions-std': np.std(raw_actions_np), 213 | 214 | 'actions-mean': np.mean(actions_np), 215 | 'actions-std': np.std(actions_np), 216 | }) 217 | 218 | 219 | class FeedforwardGaussianPolicy(GaussianPolicy): 220 | def __init__(self, 221 | hidden_layer_sizes, 222 | activation='relu', 223 | output_activation='linear', 224 | *args, **kwargs): 225 | self._hidden_layer_sizes = hidden_layer_sizes 226 | self._activation = activation 227 | self._output_activation = output_activation 228 | 229 | self._Serializable__initialize(locals()) 230 | super(FeedforwardGaussianPolicy, self).__init__(*args, **kwargs) 231 | 232 | def _shift_and_log_scale_diag_net(self, input_shapes, output_size): 233 | shift_and_log_scale_diag_net = feedforward_model( 234 | input_shapes=input_shapes, 235 | hidden_layer_sizes=self._hidden_layer_sizes, 236 | output_size=output_size, 237 | activation=self._activation, 238 | output_activation=self._output_activation) 239 | 240 | return shift_and_log_scale_diag_net 241 | 242 | def get_distribution(self, conditions): 243 | """Return diagnostic information of the policy. 244 | 245 | Returns the mean, min, max, and standard deviation of means and 246 | covariances. 247 | """ 248 | (shifts_np, 249 | log_scale_diags_np, 250 | log_pis_np, 251 | raw_actions_np, 252 | actions_np) = self.diagnostics_model.predict(conditions) 253 | 254 | return OrderedDict({ 255 | 'shifts': shifts_np, 256 | 'log_scale_diags': log_scale_diags_np, 257 | }) 258 | --------------------------------------------------------------------------------