├── env ├── __init__.py ├── mujoco_env.py ├── ant.py └── humanoid.py ├── requirements.txt ├── config ├── ant.py ├── hopper.py ├── swimmer.py ├── walker2d.py ├── invertedpendulum.py ├── halfcheetah.py ├── humanoid.py └── __init__.py ├── components ├── __init__.py ├── static_fns │ ├── swimmer.py │ ├── humanoid_truncated_obs.py │ ├── inverted_pendulum.py │ ├── ant_truncated_obs.py │ ├── walker2d.py │ ├── hopper.py │ ├── halfcheetah.py │ └── __init__.py ├── critic.py ├── network.py ├── actor.py ├── dynamics_model.py └── dynamics.py ├── README.md ├── utils ├── scaler.py ├── plotter.py └── logger.py ├── main.py ├── trainer ├── base_trainer.py └── mppve_trainer.py ├── buffer.py └── mppve.py /env/__init__.py: -------------------------------------------------------------------------------- 1 | from .mujoco_env import make_mujoco_env 2 | 3 | ENV = { 4 | "mujoco": make_mujoco_env 5 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.23.1 2 | matplotlib==3.5.2 3 | numpy==1.22.3 4 | setproctitle==1.2.3 5 | tqdm==4.64.0 6 | -------------------------------------------------------------------------------- /config/ant.py: -------------------------------------------------------------------------------- 1 | ant_config = { 2 | "target_entropy": -4, 3 | "plan_length": 2, 4 | "rollout_schedule": [20000, 150000, 1, 20], 5 | "n_steps": 300000 6 | } -------------------------------------------------------------------------------- /config/hopper.py: -------------------------------------------------------------------------------- 1 | hopper_config = { 2 | "target_entropy": -1, 3 | "plan_length": 3, 4 | "rollout_schedule": [20000, 50000, 1, 4], 5 | "n_steps": 100000 6 | } -------------------------------------------------------------------------------- /config/swimmer.py: -------------------------------------------------------------------------------- 1 | swimmer_config = { 2 | "target_entropy": -1, 3 | "plan_length": 3, 4 | "rollout_schedule": [20000, 100000, 1, 1], 5 | "n_steps": 200000 6 | } -------------------------------------------------------------------------------- /config/walker2d.py: -------------------------------------------------------------------------------- 1 | walker2d_config={ 2 | "target_entropy": -3, 3 | "plan_length": 2, 4 | "rollout_schedule": [20000, 100000, 1, 1], 5 | "n_steps": 200000 6 | } -------------------------------------------------------------------------------- /config/invertedpendulum.py: -------------------------------------------------------------------------------- 1 | inverted_pendulum_config = { 2 | "target_entropy": -0.05, 3 | "plan_length": 3, 4 | "rollout_schedule": [0, 1000, 1, 5], 5 | "n_steps": 100000 6 | } -------------------------------------------------------------------------------- /config/halfcheetah.py: -------------------------------------------------------------------------------- 1 | halfcheetah_config = { 2 | "ac_hidden_dims": [512, 512], 3 | "target_entropy": -3, 4 | "plan_length": 2, 5 | "rollout_schedule": [20000, 80000, 1, 4], 6 | "n_steps": 200000 7 | } -------------------------------------------------------------------------------- /components/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import ProbActor, DeterActor 2 | from .critic import Critic 3 | 4 | ACTOR = { 5 | "prob": ProbActor, 6 | "deter": DeterActor 7 | } 8 | 9 | CRITIC = { 10 | "q": 
Critic, 11 | "v": None 12 | } -------------------------------------------------------------------------------- /config/humanoid.py: -------------------------------------------------------------------------------- 1 | humanoid_config = { 2 | "target_entropy": -8, 3 | "plan_length": 2, 4 | "dynamics_hidden_dims": [400, 400, 400, 400], 5 | "model_update_interval": 1000, 6 | "model_retain_steps": 5000, 7 | "rollout_schedule": [20000, 300000, 1, 15], 8 | "n_steps": 300000 9 | } -------------------------------------------------------------------------------- /components/static_fns/swimmer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | done = np.array([False]).repeat(len(obs)) 10 | done = done[:,None] 11 | return done 12 | -------------------------------------------------------------------------------- /components/static_fns/humanoid_truncated_obs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pdb 4 | 5 | class StaticFns: 6 | 7 | @staticmethod 8 | def termination_fn(obs, act, next_obs): 9 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 10 | 11 | z = next_obs[:,0] 12 | done = (z < 1.0) + (z > 2.0) 13 | 14 | done = done[:,None] 15 | return done -------------------------------------------------------------------------------- /components/static_fns/inverted_pendulum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pdb 4 | 5 | class StaticFns: 6 | 7 | @staticmethod 8 | def termination_fn(obs, act, next_obs): 9 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 10 | 11 | notdone = np.isfinite(next_obs).all(axis=-1) \ 12 | * (np.abs(next_obs[:,1]) <= .2) 13 | done = ~notdone 14 | 15 | done = done[:,None] 16 | 17 | return done -------------------------------------------------------------------------------- /components/static_fns/ant_truncated_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | x = next_obs[:, 0] 10 | not_done = np.isfinite(next_obs).all(axis=-1) \ 11 | * (x >= 0.2) \ 12 | * (x <= 1.0) 13 | 14 | done = ~not_done 15 | done = done[:,None] 16 | return done -------------------------------------------------------------------------------- /components/static_fns/walker2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | height = next_obs[:, 0] 10 | angle = next_obs[:, 1] 11 | not_done = (height > 0.8) \ 12 | * (height < 2.0) \ 13 | * (angle > -1.0) \ 14 | * (angle < 1.0) 15 | done = ~not_done 16 | done = done[:,None] 17 | return done 18 | -------------------------------------------------------------------------------- /env/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | MBPO_ENVIRONMENT_SPECS = ( 4 | { 5 | "id": "AntTruncatedObs-v3", 6 | 
"entry_point": (f"env.ant:AntTruncatedObsEnv"), 7 | "max_episode_steps": 1000 8 | }, 9 | 10 | { 11 | "id": "HumanoidTruncatedObs-v3", 12 | "entry_point": (f"env.humanoid:HumanoidTruncatedObsEnv"), 13 | "max_episode_steps": 1000 14 | }, 15 | ) 16 | 17 | # register XxxTruncatedObs 18 | for mbpo_environment in MBPO_ENVIRONMENT_SPECS: 19 | gym.register(**mbpo_environment) 20 | 21 | make_mujoco_env = lambda env_name: gym.make(env_name) 22 | -------------------------------------------------------------------------------- /components/static_fns/hopper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | height = next_obs[:, 0] 10 | angle = next_obs[:, 1] 11 | not_done = np.isfinite(next_obs).all(axis=-1) \ 12 | * np.abs(next_obs[:,1:] < 100).all(axis=-1) \ 13 | * (height > .7) \ 14 | * (np.abs(angle) < .2) 15 | 16 | done = ~not_done 17 | done = done[:,None] 18 | return done 19 | -------------------------------------------------------------------------------- /components/static_fns/halfcheetah.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class StaticFns: 4 | 5 | @staticmethod 6 | def termination_fn(obs, act, next_obs): 7 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 8 | 9 | done = np.array([False]).repeat(len(obs)) 10 | done = done[:,None] 11 | return done 12 | 13 | @staticmethod 14 | def recompute_reward_fn(obs, act, next_obs, rew): 15 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 16 | 17 | new_rew = -(rew + 0.1 * np.sum(np.square(act))) - 0.1 * np.sum(np.square(act)) 18 | return new_rew 19 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from config.invertedpendulum import inverted_pendulum_config 2 | from config.hopper import hopper_config 3 | from config.swimmer import swimmer_config 4 | from config.walker2d import walker2d_config 5 | from config.halfcheetah import halfcheetah_config 6 | from config.ant import ant_config 7 | from config.humanoid import humanoid_config 8 | 9 | 10 | CONFIG = { 11 | "InvertedPendulum": inverted_pendulum_config, 12 | "Hopper": hopper_config, 13 | "Swimmer": swimmer_config, 14 | "Walker2d": walker2d_config, 15 | "HalfCheetah": halfcheetah_config, 16 | "AntTruncatedObs": ant_config, 17 | "HumanoidTruncatedObs": humanoid_config 18 | } -------------------------------------------------------------------------------- /components/static_fns/__init__.py: -------------------------------------------------------------------------------- 1 | from .hopper import StaticFns as HopperStaticFns 2 | from .swimmer import StaticFns as SwimmerStaticFns 3 | from .walker2d import StaticFns as Walker2dStaticFns 4 | from .halfcheetah import StaticFns as HalfcheetahStaticFns 5 | from .inverted_pendulum import StaticFns as InvertedPendulumFns 6 | from .ant_truncated_obs import StaticFns as AntTruncatedObsStaticFns 7 | from .humanoid_truncated_obs import StaticFns as HumanoidTruncatedObsStaticFns 8 | 9 | STATICFUNC = { 10 | "Hopper": HopperStaticFns, 11 | "Swimmer": SwimmerStaticFns, 12 | "Walker2d": Walker2dStaticFns, 13 | "HalfCheetah": HalfcheetahStaticFns, 14 | "InvertedPendulum": InvertedPendulumFns, 15 | 
"AntTruncatedObs": AntTruncatedObsStaticFns, 16 | "HumanoidTruncatedObs": HumanoidTruncatedObsStaticFns 17 | } -------------------------------------------------------------------------------- /components/critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from components.network import MLP 5 | 6 | 7 | class Critic(nn.Module): 8 | """ Q(s,a) """ 9 | def __init__(self, obs_shape, hidden_dims, action_dim, device="cpu"): 10 | super(Critic, self).__init__() 11 | self.device = torch.device(device) 12 | self.backbone = MLP(input_dim=np.prod(obs_shape)+action_dim, hidden_dims=hidden_dims).to(self.device) 13 | latent_dim = getattr(self.backbone, "output_dim") 14 | self.last = nn.Linear(latent_dim, 1).to(self.device) 15 | 16 | def forward(self, obs, actions): 17 | """ return Q(s,a) """ 18 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device) 19 | actions = torch.as_tensor(actions, dtype=torch.float32, device=self.device) 20 | net_in = torch.cat([obs, actions], dim=1) 21 | logits = self.backbone(net_in) 22 | values = self.last(logits) 23 | return values 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MPPVE: Model-based Planning Policy Learning with Multi-step Plan Value Estimation 2 | 3 | This is the code for the paper "Model-based Reinforcement Learning with Multi-step Plan Value Estimation". 4 | 5 | ## Installation instructions 6 | 7 | Install Python environment with: 8 | 9 | ```bash 10 | conda create -n mppve python=3.9 -y 11 | conda activate mppve 12 | conda install pytorch cudatoolkit=11.3 -c pytorch -y 13 | pip install -r ./requirements.txt 14 | ``` 15 | 16 | ## Run an experiment 17 | 18 | ```shell 19 | python3 main.py --env-name=[Env name] 20 | ``` 21 | 22 | The config files located in `config` act as defaults for a task. `env-name` refers to the config files in `config/` including Hopper-v3, Walker2d-v3, Swimmer-v3, HalfCheetah-v3, AntTruncatedObs-v3 and HumanoidTruncatedObs-v3. 23 | 24 | All results will be stored in the `result` folder. 
25 | 26 | For example, run MPPVE on Hopper: 27 | 28 | ```bash 29 | python main.py --env-name=Hopper-v3 30 | ``` 31 | 32 | 33 | ## Citation 34 | If you find this repository useful for your research, please cite: 35 | ``` 36 | @inproceedings{ 37 | mppve, 38 | title={Model-based Reinforcement Learning with Multi-step Plan Value Estimation}, 39 | author={Haoxin Lin and Yihao Sun and Jiaji Zhang and Yang Yu}, 40 | booktitle={Proceedings of the 26th European Conference on Artificial Intelligence (ECAI'23)}, 41 | address={Kraków, Poland}, 42 | year=2023 43 | } 44 | ``` -------------------------------------------------------------------------------- /env/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco.ant_v3 import AntEnv 3 | 4 | DEFAULT_CAMERA_CONFIG = { 5 | "distance": 4.0, 6 | } 7 | 8 | class AntTruncatedObsEnv(AntEnv): 9 | """ External forces (sim.data.cfrc_ext) are removed from the observation """ 10 | def __init__( 11 | self, 12 | xml_file="ant.xml", 13 | ctrl_cost_weight=0.5, 14 | contact_cost_weight=5e-4, 15 | healthy_reward=1.0, 16 | terminate_when_unhealthy=True, 17 | healthy_z_range=(0.2, 1.0), 18 | contact_force_range=(-1.0, 1.0), 19 | reset_noise_scale=0.1, 20 | exclude_current_positions_from_observation=True, 21 | ): 22 | super(AntTruncatedObsEnv, self).__init__( 23 | xml_file, 24 | ctrl_cost_weight, 25 | contact_cost_weight, 26 | healthy_reward, 27 | terminate_when_unhealthy, 28 | healthy_z_range, 29 | contact_force_range, 30 | reset_noise_scale, 31 | exclude_current_positions_from_observation 32 | ) 33 | 34 | def _get_obs(self): 35 | position = self.sim.data.qpos.flat.copy() 36 | velocity = self.sim.data.qvel.flat.copy() 37 | # contact_force = self.contact_forces.flat.copy() 38 | 39 | if self._exclude_current_positions_from_observation: 40 | position = position[2:] 41 | 42 | # observations = np.concatenate((position, velocity, contact_force)) 43 | observations = np.concatenate((position, velocity)) 44 | 45 | return observations 46 | -------------------------------------------------------------------------------- /env/humanoid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco.humanoid_v3 import HumanoidEnv 3 | 4 | class HumanoidTruncatedObsEnv(HumanoidEnv): 5 | """ 6 | COM inertia (cinert), COM velocity (cvel), actuator forces (qfrc_actuator), 7 | and external forces (cfrc_ext) are removed from the observation. 
8 | """ 9 | def __init__( 10 | self, 11 | xml_file="humanoid.xml", 12 | forward_reward_weight=1.25, 13 | ctrl_cost_weight=0.1, 14 | contact_cost_weight=5e-7, 15 | contact_cost_range=(-np.inf, 10.0), 16 | healthy_reward=5.0, 17 | terminate_when_unhealthy=True, 18 | healthy_z_range=(1.0, 2.0), 19 | reset_noise_scale=1e-2, 20 | exclude_current_positions_from_observation=True, 21 | ): 22 | super(HumanoidTruncatedObsEnv, self).__init__( 23 | xml_file, 24 | forward_reward_weight, 25 | ctrl_cost_weight, 26 | contact_cost_weight, 27 | contact_cost_range, 28 | healthy_reward, 29 | terminate_when_unhealthy, 30 | healthy_z_range, 31 | reset_noise_scale, 32 | exclude_current_positions_from_observation 33 | ) 34 | 35 | def _get_obs(self): 36 | position = self.sim.data.qpos.flat.copy() 37 | velocity = self.sim.data.qvel.flat.copy() 38 | 39 | # com_inertia = self.sim.data.cinert.flat.copy() 40 | # com_velocity = self.sim.data.cvel.flat.copy() 41 | 42 | # actuator_forces = self.sim.data.qfrc_actuator.flat.copy() 43 | # external_contact_forces = self.sim.data.cfrc_ext.flat.copy() 44 | 45 | if self._exclude_current_positions_from_observation: 46 | position = position[2:] 47 | 48 | return np.concatenate( 49 | ( 50 | position, 51 | velocity, 52 | # com_inertia, 53 | # com_velocity, 54 | # actuator_forces, 55 | # external_contact_forces, 56 | ) 57 | ) 58 | -------------------------------------------------------------------------------- /utils/scaler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as path 3 | import torch 4 | 5 | 6 | class StandardScaler(object): 7 | def __init__(self): 8 | pass 9 | 10 | def fit(self, data): 11 | """Runs two ops, one for assigning the mean of the data to the internal mean, and 12 | another for assigning the standard deviation of the data to the internal standard deviation. 13 | This function must be called within a 'with .as_default()' block. 14 | 15 | Arguments: 16 | data (np.ndarray): A numpy array containing the input 17 | 18 | Returns: None. 19 | """ 20 | self.mu = np.mean(data, axis=0, keepdims=True) 21 | self.std = np.std(data, axis=0, keepdims=True) 22 | self.std[self.std < 1e-12] = 1.0 23 | 24 | def transform(self, data): 25 | """Transforms the input matrix data using the parameters of this scaler. 26 | 27 | Arguments: 28 | data (np.array): A numpy array containing the points to be transformed. 29 | 30 | Returns: (np.array) The transformed dataset. 31 | """ 32 | return (data - self.mu) / self.std 33 | 34 | def inverse_transform(self, data): 35 | """Undoes the transformation performed by this scaler. 36 | 37 | Arguments: 38 | data (np.array): A numpy array containing the points to be transformed. 39 | 40 | Returns: (np.array) The transformed dataset. 
41 | """ 42 | return self.std * data + self.mu 43 | 44 | def save_scaler(self, save_path): 45 | mu_path = path.join(save_path, "mu.npy") 46 | std_path = path.join(save_path, "std.npy") 47 | np.save(mu_path, self.mu) 48 | np.save(std_path, self.std) 49 | 50 | def load_scaler(self, load_path): 51 | mu_path = path.join(load_path, "mu.npy") 52 | std_path = path.join(load_path, "std.npy") 53 | self.mu = np.load(mu_path) 54 | self.std = np.load(std_path) 55 | 56 | def transform_tensor(self, obs_action: torch.Tensor, device): 57 | obs_action = obs_action.cpu().numpy() 58 | obs_action = self.transform(obs_action) 59 | obs_action = torch.tensor(obs_action, device=device) 60 | return obs_action -------------------------------------------------------------------------------- /components/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MLP(nn.Module): 6 | """ multi-layer perceptron """ 7 | def __init__(self, input_dim, hidden_dims, output_dim=None, activation=nn.ReLU): 8 | super(MLP, self).__init__() 9 | 10 | # hidden layers 11 | dims = [input_dim] + list(hidden_dims) 12 | layers = [] 13 | for in_dim, out_dim in zip(dims[:-1], dims[1:]): 14 | layers += [nn.Linear(in_dim, out_dim), activation()] 15 | 16 | self.output_dim = dims[-1] 17 | if output_dim is not None: 18 | layers += [nn.Linear(dims[-1], output_dim)] 19 | self.output_dim = output_dim 20 | self.net = nn.Sequential(*layers) 21 | 22 | def forward(self, x): 23 | return self.net(x) 24 | 25 | 26 | class EnsembleLinear(nn.Module): 27 | def __init__( 28 | self, 29 | input_dim, 30 | output_dim, 31 | num_ensemble, 32 | num_elites, 33 | weight_decay=0.0, 34 | load_model=False 35 | ): 36 | super().__init__() 37 | 38 | self.num_ensemble = num_ensemble 39 | self.num_elites = num_elites 40 | 41 | self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim))) 42 | self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim))) 43 | 44 | nn.init.trunc_normal_(self.weight, std=1/(2*input_dim**0.5)) 45 | 46 | self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone())) 47 | self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone())) 48 | 49 | if not load_model: 50 | self.register_parameter("elites", nn.Parameter(torch.tensor(list(range(0, self.num_ensemble))), requires_grad=False)) 51 | else: 52 | self.register_parameter("elites", nn.Parameter(torch.tensor(list(range(0, self.num_elites))), requires_grad=False)) 53 | 54 | self.weight_decay = weight_decay 55 | 56 | def forward(self, x): 57 | weight = self.weight[self.elites] 58 | bias = self.bias[self.elites] 59 | 60 | if len(x.shape) == 2: 61 | x = torch.einsum('ij,bjk->bik', x, weight) 62 | else: 63 | x = torch.einsum('bij,bjk->bik', x, weight) 64 | 65 | x = x + bias 66 | 67 | return x 68 | 69 | def set_elites(self, indexes): 70 | assert len(indexes) <= self.num_ensemble and max(indexes) < self.num_ensemble 71 | self.register_parameter('elites', nn.Parameter(torch.tensor(indexes), requires_grad=False)) 72 | self.weight.data.copy_(self.saved_weight.data) 73 | self.bias.data.copy_(self.saved_bias.data) 74 | 75 | def update_save(self, indexes): 76 | self.saved_weight.data[indexes] = self.weight.data[indexes] 77 | self.saved_bias.data[indexes] = self.bias.data[indexes] 78 | 79 | def reset_elites(self): 80 | self.register_parameter('elites', nn.Parameter(torch.tensor(list(range(0, self.num_ensemble))), 
requires_grad=False)) 81 | 82 | def get_decay_loss(self): 83 | decay_loss = self.weight_decay * (0.5*((self.weight**2).sum())) 84 | return decay_loss 85 | -------------------------------------------------------------------------------- /components/actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from components.network import MLP 5 | 6 | 7 | class NormalWrapper(torch.distributions.Normal): 8 | """ wrapper of normal distribution """ 9 | def log_prob(self, actions): 10 | return super().log_prob(actions).sum(-1, keepdim=True) 11 | 12 | def entropy(self): 13 | return super().entropy().sum(-1) 14 | 15 | def mode(self): 16 | return self.mean 17 | 18 | 19 | class DiagGaussian(nn.Module): 20 | """ independent Gaussian """ 21 | def __init__( 22 | self, 23 | latent_dim, 24 | output_dim, 25 | unbounded=False, 26 | conditioned_sigma=False, 27 | max_mu=1.0, 28 | sigma_min=-20, 29 | sigma_max=2 30 | ): 31 | super(DiagGaussian, self).__init__() 32 | self.mu = nn.Linear(latent_dim, output_dim) 33 | self._c_sigma = conditioned_sigma 34 | if conditioned_sigma: 35 | self.sigma = nn.Linear(latent_dim, output_dim) 36 | else: 37 | self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) 38 | self._unbounded = unbounded 39 | self._max = max_mu 40 | self._sigma_min = sigma_min 41 | self._sigma_max = sigma_max 42 | 43 | def forward(self, logits): 44 | mu = self.mu(logits) 45 | if not self._unbounded: 46 | mu = self._max * torch.tanh(mu) 47 | if self._c_sigma: 48 | sigma = torch.clamp(self.sigma(logits), min=self._sigma_min, max=self._sigma_max).exp() 49 | else: 50 | shape = [1] * len(mu.shape) 51 | shape[1] = -1 52 | sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() 53 | return NormalWrapper(mu, sigma) 54 | 55 | 56 | class ProbActor(nn.Module): 57 | """ stochastic actor for PPO/A2C/SAC """ 58 | def __init__(self, obs_shape, hidden_dims, action_dim, device="cpu"): 59 | super(ProbActor, self).__init__() 60 | self.device = torch.device(device) 61 | self.backbone = MLP(input_dim=np.prod(obs_shape), hidden_dims=hidden_dims).to(self.device) 62 | self.dist_net = DiagGaussian( 63 | latent_dim=getattr(self.backbone, "output_dim"), 64 | output_dim=action_dim, 65 | unbounded=True, 66 | conditioned_sigma=True 67 | ).to(self.device) 68 | 69 | def forward(self, obs): 70 | """ return prob distribution among actions """ 71 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device) 72 | logits = self.backbone(obs) 73 | dist = self.dist_net(logits) 74 | return dist 75 | 76 | 77 | class DeterActor(nn.Module): 78 | """ deterministic actor for DDPG/TD3 """ 79 | def __init__(self, obs_shape, hidden_dims, action_dim, max_action, device="cpu"): 80 | super(DeterActor, self).__init__() 81 | self.device = torch.device(device) 82 | self.backbone = MLP(input_dim=np.prod(obs_shape), hidden_dims=hidden_dims).to(self.device) 83 | self.to_action = nn.Linear(getattr(self.backbone, "output_dim"), action_dim).to(self.device) 84 | self.max_action = max_action 85 | 86 | def forward(self, obs): 87 | """ return deterministic action """ 88 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device) 89 | logits = self.backbone(obs) 90 | a = self.max_action*torch.tanh(self.to_action(logits)) 91 | return a 92 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
random 3 | import argparse 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from config import CONFIG 9 | from trainer.mppve_trainer import MPPVETrainer 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser(description="DRL") 14 | 15 | # environment settings 16 | parser.add_argument("--env", type=str, default="mujoco") 17 | parser.add_argument("--env-name", type=str, default="Hopper-v3") 18 | 19 | # algorithm parameters 20 | parser.add_argument("--algo", type=str, default="mppve") 21 | parser.add_argument("--ac-hidden-dims", type=int, nargs='*', default=[256, 256]) 22 | parser.add_argument("--actor-lr", type=float, default=3e-4) 23 | parser.add_argument("--critic-lr", type=float, default=3e-4) 24 | parser.add_argument("--gamma", type=float, default=0.99) 25 | parser.add_argument("--tau", type=float, default=0.005) 26 | parser.add_argument("--plan-length", type=int, default=3) 27 | # (for sac) 28 | parser.add_argument("--alpha", type=float, default=0.2) 29 | parser.add_argument("--auto-alpha", type=bool, default=True) 30 | parser.add_argument("--alpha-lr", type=float, default=3e-4) 31 | parser.add_argument("--target-entropy", type=int, default=-1) 32 | 33 | # replay-buffer parameters 34 | parser.add_argument("--buffer-size", type=int, default=int(1e6)) 35 | 36 | # dynamics-model parameters 37 | parser.add_argument("--dynamics-hidden-dims", type=int, nargs='*', default=[200, 200, 200, 200]) 38 | parser.add_argument("--dynamics-weight-decay", type=float, nargs='*', default=[2.5e-5, 5e-5, 7.5e-5, 7.5e-5, 1e-4]) 39 | parser.add_argument("--n-ensembles", type=int, default=7) 40 | parser.add_argument("--n-elites", type=int, default=5) 41 | parser.add_argument("--rollout-batch-size", type=int, default=int(1e5)) 42 | parser.add_argument("--rollout-schedule", type=int, nargs='*', default=[int(2e4), int(5e4), 1, 4]) 43 | parser.add_argument("--model-update-interval", type=int, default=250) 44 | parser.add_argument("--model-retain-steps", type=int, default=1000) 45 | parser.add_argument("--real-ratio", type=float, default=0.05) 46 | 47 | # running parameters 48 | parser.add_argument("--n-steps", type=int, default=int(1e5)) 49 | parser.add_argument("--start-learning", type=int, default=int(5e3)) 50 | parser.add_argument("--update-interval", type=int, default=1) 51 | parser.add_argument("--updates-per-step", type=int, default=20) 52 | parser.add_argument("--actor-freq", type=int, default=20) 53 | parser.add_argument("--batch-size", type=int, default=256) 54 | parser.add_argument("--eval-interval", type=int, default=int(1e3)) 55 | parser.add_argument("--eval-n-episodes", type=int, default=10) 56 | parser.add_argument("--render", action="store_true", default=False) 57 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 58 | parser.add_argument("--seed", type=int, default=0) 59 | 60 | args = parser.parse_args() 61 | return args 62 | 63 | def main(): 64 | args = vars(get_args()) 65 | config = CONFIG[args["env_name"].split('-')[0]] 66 | for k, v in config.items(): 67 | args[k] = v 68 | args = argparse.Namespace(**args) 69 | 70 | # set seed 71 | random.seed(args.seed) 72 | np.random.seed(args.seed) 73 | os.environ["PYTHONHASHSEED"] = str(args.seed) 74 | torch.manual_seed(args.seed) 75 | torch.cuda.manual_seed(args.seed) 76 | torch.cuda.manual_seed_all(args.seed) 77 | torch.backends.cudnn.deterministic = True 78 | torch.backends.cudnn.benchmark = False 79 | 80 | trainer = MPPVETrainer(args) 81 | trainer.train() 82 | 83 | if __name__ == 
"__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /components/dynamics_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.functional import F 4 | 5 | from components.network import EnsembleLinear 6 | 7 | 8 | class Swish(nn.Module): 9 | def forward(self, x): 10 | return x * torch.sigmoid(x) 11 | 12 | 13 | def soft_clamp(x : torch.Tensor, _min=None, _max=None): 14 | # clamp tensor values while mataining the gradient 15 | if _max is not None: 16 | x = _max - F.softplus(_max - x) 17 | if _min is not None: 18 | x = _min + F.softplus(x - _min) 19 | return x 20 | 21 | 22 | class EnsembleDynamicsModel(nn.Module): 23 | def __init__( 24 | self, 25 | obs_dim, 26 | action_dim, 27 | hidden_dims, 28 | num_ensemble=7, 29 | num_elites=5, 30 | activation=Swish, 31 | weight_decays=None, 32 | with_reward=True, 33 | load_model=False, 34 | device="cpu" 35 | ) -> None: 36 | super().__init__() 37 | 38 | self.num_ensemble = num_ensemble 39 | self.num_elites = num_elites 40 | self._with_reward = with_reward 41 | self.load_model = load_model 42 | self.device = torch.device(device) 43 | 44 | self.activation = activation() 45 | 46 | assert len(weight_decays) == (len(hidden_dims) + 1) 47 | 48 | module_list = [] 49 | hidden_dims = [obs_dim+action_dim] + list(hidden_dims) 50 | if weight_decays is None: 51 | weight_decays = [0.0] * (len(hidden_dims) + 1) 52 | for in_dim, out_dim, weight_decay in zip(hidden_dims[:-1], hidden_dims[1:], weight_decays[:-1]): 53 | module_list.append(EnsembleLinear(in_dim, out_dim, num_ensemble, num_elites, weight_decay, load_model)) 54 | self.backbones = nn.ModuleList(module_list) 55 | 56 | self.output_layer = EnsembleLinear( 57 | hidden_dims[-1], 58 | 2 * (obs_dim + self._with_reward), 59 | num_ensemble, 60 | num_elites, 61 | weight_decays[-1], 62 | load_model 63 | ) 64 | 65 | self.register_parameter( 66 | "max_logvar", 67 | nn.Parameter(torch.ones(obs_dim + self._with_reward) * 0.5, requires_grad=True) 68 | ) 69 | self.register_parameter( 70 | "min_logvar", 71 | nn.Parameter(torch.ones(obs_dim + self._with_reward) * -10, requires_grad=True) 72 | ) 73 | 74 | self.to(self.device) 75 | 76 | def forward(self, obs_action): 77 | obs_action = torch.as_tensor(obs_action, dtype=torch.float32, device=self.device) 78 | output = obs_action 79 | for layer in self.backbones: 80 | output = self.activation(layer(output)) 81 | mean, logvar = torch.chunk(self.output_layer(output), 2, dim=-1) 82 | logvar = soft_clamp(logvar, self.min_logvar, self.max_logvar) 83 | return mean, logvar 84 | 85 | def set_elites(self, indexes): 86 | for layer in self.backbones: 87 | layer.set_elites(indexes) 88 | self.output_layer.set_elites(indexes) 89 | 90 | def reset_elites(self): 91 | for layer in self.backbones: 92 | layer.reset_elites() 93 | self.output_layer.reset_elites() 94 | 95 | def update_save(self, indexes): 96 | for layer in self.backbones: 97 | layer.update_save(indexes) 98 | self.output_layer.update_save(indexes) 99 | 100 | def get_decay_loss(self): 101 | decay_loss = 0 102 | for layer in self.backbones: 103 | decay_loss += layer.get_decay_loss() 104 | decay_loss += self.output_layer.get_decay_loss() 105 | return decay_loss -------------------------------------------------------------------------------- /trainer/base_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 
4 | from env import ENV 5 | from utils.logger import Logger, make_log_dirs 6 | 7 | 8 | class BASETrainer: 9 | """ base trainer """ 10 | def __init__(self, args): 11 | # init env 12 | self.env = ENV[args.env](args.env_name) 13 | self.env.seed(args.seed) 14 | self.env.action_space.seed(args.seed) 15 | 16 | self.eval_env = ENV[args.env](args.env_name) 17 | self.eval_env.seed(args.seed) 18 | self.eval_env.action_space.seed(args.seed) 19 | 20 | args.obs_shape = self.env.observation_space.shape 21 | args.action_dim = int(np.prod(self.env.action_space.shape)) 22 | 23 | # logger 24 | log_dirs = make_log_dirs(args.env_name, args.algo, args.seed, vars(args)) 25 | # key: output file name, value: output handler type 26 | output_config = { 27 | "consoleout_backup": "stdout", 28 | "progress": "csv", 29 | "tb": "tensorboard" 30 | } 31 | self.logger = Logger(log_dirs, output_config) 32 | self.logger.log_hyperparameters(vars(args)) 33 | 34 | # running parameters 35 | self.n_steps = args.n_steps 36 | self.start_learning = args.start_learning 37 | self.update_interval = args.update_interval 38 | self.batch_size = args.batch_size 39 | self.eval_interval = args.eval_interval 40 | self.eval_n_episodes = args.eval_n_episodes 41 | self.render = args.render 42 | self.device = args.device 43 | self.seed = args.seed 44 | self.args = args 45 | 46 | def _warm_up(self): 47 | """ randomly sample a lot of transitions into buffer before starting learning """ 48 | obs = self.env.reset() 49 | 50 | # step for {self.start_learning} time-steps 51 | pbar = tqdm(range(self.start_learning), desc="Warming up") 52 | for _ in pbar: 53 | action = self.env.action_space.sample() 54 | next_obs, reward, done, info = self.env.step(action) 55 | timeout = info.get("TimeLimit.truncated", False) 56 | self.memory.store(obs, action, reward, next_obs, done, timeout) 57 | 58 | obs = next_obs 59 | if done: obs = self.env.reset() 60 | 61 | return obs 62 | 63 | def _eval_policy(self): 64 | """ evaluate policy """ 65 | episode_rewards = [] 66 | for _ in range(self.eval_n_episodes): 67 | done = False 68 | episode_rewards.append(0) 69 | obs = self.eval_env.reset() 70 | while not done: 71 | action = self.agent.act(obs, deterministic=True) 72 | obs, reward, done, _ = self.eval_env.step(action) 73 | episode_rewards[-1] += reward 74 | return episode_rewards 75 | 76 | def _eval_value_estimation(self): 77 | """ evaluate value estimation""" 78 | value_bias_mean, value_bias_std = [], [] 79 | for _ in range(self.eval_n_episodes): 80 | rewards = [] 81 | log_probs = [] 82 | value_preds = [] 83 | obs = self.eval_env.reset() 84 | done = False 85 | while not done: 86 | action, log_prob = self.agent.act(obs, deterministic=False, return_logprob=True) 87 | value_preds.append(self.agent.value(obs, action)[0]) 88 | obs, reward, done, info = self.eval_env.step(action) 89 | rewards.append(reward) 90 | log_probs.append(log_prob.flatten()[0]) 91 | 92 | timeout = info.get("TimeLimit.truncated", False) 93 | returns = [] 94 | if timeout: 95 | action, log_prob = self.agent.act(obs, deterministic=False, return_logprob=True) 96 | next_value = self.agent.value(obs, action)[0] 97 | returns.append(next_value) 98 | log_probs.append(log_prob.flatten()[0]) 99 | else: 100 | returns.append(0) 101 | log_probs.append(0) 102 | for r in reversed(rewards): 103 | returns.append(r + self.agent._gamma * (returns[-1] - self.agent._alpha.cpu().item()*log_probs[-1])) 104 | log_probs.pop() 105 | 106 | returns = np.array(list(reversed(returns[1:]))).flatten() 107 | value_preds = 
np.array(value_preds).flatten() 108 | 109 | value_bias_mean.append(((value_preds - returns) / (np.abs(returns.mean())+1e-5)).mean()) 110 | value_bias_std.append(((value_preds - returns) / (np.abs(returns.mean())+1e-5)).std()) 111 | 112 | return { 113 | "value_bias_mean": np.mean(value_bias_mean), 114 | "value_bias_std": np.mean(value_bias_std) 115 | } -------------------------------------------------------------------------------- /utils/plotter.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | 11 | COLORS = ( 12 | [ 13 | '#318DE9', # blue 14 | '#FF7D00', # orange 15 | '#E52B50', # red 16 | '#7B68EE', # purple 17 | '#00CD66', # green 18 | '#FFD700', # yellow 19 | ] 20 | ) 21 | 22 | 23 | def merge_csv(root_dir, query_file, query_x, query_y): 24 | """Merge result in csv_files into a single csv file.""" 25 | csv_files = [] 26 | for dirname, _, files in os.walk(root_dir): 27 | for f in files: 28 | if f == query_file: 29 | csv_files.append(os.path.join(dirname, f)) 30 | results = {} 31 | for csv_file in csv_files: 32 | content = [[query_x, query_y]] 33 | df = pd.read_csv(csv_file) 34 | values = df[[query_x, query_y]].values 35 | for line in values: 36 | if np.isnan(line[1]): continue 37 | content.append(line) 38 | results[csv_file] = content 39 | assert len(results) > 0 40 | sorted_keys = sorted(results.keys()) 41 | sorted_values = [results[k][1:] for k in sorted_keys] 42 | content = [ 43 | [query_x, query_y+'_mean', query_y+'_std'] 44 | ] 45 | for rows in zip(*sorted_values): 46 | array = np.array(rows) 47 | assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) 48 | line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)] 49 | content.append(line) 50 | output_path = os.path.join(root_dir, query_y.replace('/', '_')+".csv") 51 | print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.") 52 | csv.writer(open(output_path, "w")).writerows(content) 53 | return output_path 54 | 55 | 56 | def csv2numpy(file_path): 57 | df = pd.read_csv(file_path) 58 | step = df.iloc[:,0].to_numpy() 59 | mean = df.iloc[:,1].to_numpy() 60 | std = df.iloc[:,2].to_numpy() 61 | return step, mean, std 62 | 63 | 64 | def smooth(y, radius=0): 65 | convkernel = np.ones(2 * radius + 1) 66 | out = np.convolve(y, convkernel, mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same') 67 | return out 68 | 69 | 70 | def plot_figure( 71 | results, 72 | x_label, 73 | y_label, 74 | title=None, 75 | smooth_radius=10, 76 | figsize=None, 77 | dpi=None, 78 | color_list=None 79 | ): 80 | fig, ax = plt.subplots(figsize=figsize, dpi=dpi) 81 | if color_list == None: 82 | color_list = [COLORS[i] for i in range(len(results))] 83 | else: 84 | assert len(color_list) == len(results) 85 | for i, (algo_name, csv_file) in enumerate(results.items()): 86 | x, y, shaded = csv2numpy(csv_file) 87 | y = smooth(y, smooth_radius) 88 | shaded = smooth(shaded, smooth_radius) 89 | ax.plot(x, y, color=color_list[i], label=algo_name) 90 | ax.fill_between(x, y-shaded, y+shaded, color=color_list[i], alpha=0.2) 91 | ax.set_title(title) 92 | ax.set_xlabel(x_label) 93 | ax.set_ylabel(y_label) 94 | ax.legend() 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser(description="plotter") 99 | parser.add_argument("--root-dir", default="log") 100 | parser.add_argument("--task", default="Hopper-v3") 101 
| parser.add_argument("--algos", type=str, nargs='*', default=["mppve"]) 102 | parser.add_argument("--query-file", default="progress.csv") 103 | parser.add_argument("--query-x", default="timestep") 104 | parser.add_argument("--query-y", default="eval/episode_rewards") 105 | parser.add_argument("--title", default=None) 106 | parser.add_argument("--xlabel", default="timestep") 107 | parser.add_argument("--ylabel", default="Test Reward Mean") 108 | parser.add_argument("--smooth", type=int, default=10) 109 | parser.add_argument("--colors", type=str, nargs='*', default=None) 110 | parser.add_argument("--show", action='store_true') 111 | parser.add_argument("--output-path", default="./figure.png") 112 | parser.add_argument("--figsize", type=float, nargs=2, default=(5, 5)) 113 | parser.add_argument("--dpi", type=int, default=200) 114 | args = parser.parse_args() 115 | 116 | results = {} 117 | for algo in args.algos: 118 | path = os.path.join(args.root_dir, args.task, algo) 119 | csv_file = merge_csv(path, args.query_file, args.query_x, args.query_y) 120 | results[algo] = csv_file 121 | 122 | plt.style.use('seaborn') 123 | plot_figure( 124 | results=results, 125 | x_label=args.xlabel, 126 | y_label=args.ylabel, 127 | title=args.title, 128 | smooth_radius=args.smooth, 129 | figsize=args.figsize, 130 | dpi=args.dpi, 131 | color_list=args.colors 132 | ) 133 | if args.output_path: 134 | plt.savefig(args.output_path, bbox_inches='tight') 135 | if args.show: 136 | plt.show() -------------------------------------------------------------------------------- /components/dynamics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from utils.scaler import StandardScaler 7 | 8 | 9 | class Dynamics: 10 | def __init__(self, model, static_fn): 11 | self.model = model 12 | self.optim = torch.optim.Adam(self.model.parameters(), lr=1e-3) 13 | self.scaler = StandardScaler() 14 | self.static_fn = static_fn 15 | 16 | @ torch.no_grad() 17 | def step(self, obs, action): 18 | obs_act = np.concatenate([obs, action], axis=-1) 19 | obs_act = self.scaler.transform(obs_act) 20 | mean, logvar = self.model(obs_act) 21 | mean = mean.cpu().numpy() 22 | logvar = logvar.cpu().numpy() 23 | mean[..., :-1] += obs 24 | std = np.sqrt(np.exp(logvar)) 25 | samples = (mean + np.random.normal(size=mean.shape) * std).astype(np.float32) 26 | next_obses = samples[..., :-1] 27 | rewards = samples[..., -1:] 28 | 29 | select_indexes = np.random.randint(0, next_obses.shape[0], size=(obs.shape[0])) 30 | next_obs = next_obses[select_indexes, np.arange(obs.shape[0])] 31 | reward = rewards[select_indexes, np.arange(obs.shape[0])] 32 | terminal = self.static_fn.termination_fn(obs, action, next_obs) 33 | 34 | return next_obs, reward, terminal, {} 35 | 36 | def train(self, inputs, targets, batch_size=256): 37 | self.model.reset_elites() 38 | data_size = inputs.shape[0] 39 | holdout_size = min(int(data_size * 0.2), 5000) 40 | train_size = data_size - holdout_size 41 | train_splits, holdout_splits = torch.utils.data.random_split(range(data_size), (train_size, holdout_size)) 42 | train_inputs, train_targets = inputs[train_splits.indices], targets[train_splits.indices] 43 | holdout_inputs, holdout_targets = inputs[holdout_splits.indices], targets[holdout_splits.indices] 44 | 45 | self.scaler.fit(train_inputs) 46 | train_inputs = self.scaler.transform(train_inputs) 47 | holdout_inputs = self.scaler.transform(holdout_inputs) 48 | holdout_losses = [1e10 for i in 
range(self.model.num_ensemble)] 49 | 50 | data_idxes = np.random.randint(train_size, size=[self.model.num_ensemble, train_size]) 51 | def shuffle_rows(arr): 52 | idxes = np.argsort(np.random.uniform(size=arr.shape), axis=-1) 53 | return arr[np.arange(arr.shape[0])[:, None], idxes] 54 | 55 | epoch = 0 56 | cnt = 0 57 | num_elites = self.model.num_elites 58 | 59 | while True: 60 | epoch += 1 61 | self.learn_batch(train_inputs[data_idxes], train_targets[data_idxes], batch_size) 62 | new_holdout_losses = self.validate(holdout_inputs, holdout_targets) 63 | holdout_loss = (np.sort(new_holdout_losses)[:num_elites]).mean() 64 | 65 | # shuffle data for each base learner 66 | data_idxes = shuffle_rows(data_idxes) 67 | 68 | indexes = [] 69 | for i, new_loss, old_loss in zip(range(len(holdout_losses)), new_holdout_losses, holdout_losses): 70 | improvement = (old_loss - new_loss) / old_loss 71 | if improvement > 0.01: 72 | indexes.append(i) 73 | holdout_losses[i] = new_loss 74 | 75 | if len(indexes) > 0: 76 | self.model.update_save(indexes) 77 | cnt = 0 78 | else: 79 | cnt += 1 80 | 81 | if cnt >= 5: 82 | break 83 | 84 | indexes = self.select_elites(holdout_losses) 85 | self.model.set_elites(indexes) 86 | return { 87 | "num_epochs": epoch, 88 | "elites": indexes, 89 | "holdout_loss": (np.sort(holdout_losses)[:num_elites]).mean() 90 | } 91 | 92 | def learn_batch(self, inputs, targets, batch_size): 93 | self.model.train() 94 | train_size = inputs.shape[1] 95 | 96 | for batch_num in range(int(np.ceil(train_size / batch_size))): 97 | inputs_batch = inputs[:, batch_num * batch_size:(batch_num + 1) * batch_size] 98 | targets_batch = targets[:, batch_num * batch_size:(batch_num + 1) * batch_size] 99 | targets_batch = torch.as_tensor(targets_batch).to(self.model.device) 100 | 101 | mean, logvar = self.model(inputs_batch) 102 | inv_var = torch.exp(-logvar) 103 | # Average over batch and dim, sum over ensembles. 
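# Together, the inverse-variance-weighted squared error and the log-variance term below
# form the Gaussian negative log-likelihood (up to constants) of each ensemble member;
# the logvar term keeps a member from shrinking its weighted error simply by predicting huge variance.
# The small max_logvar/min_logvar terms added a few lines further down gently tighten the
# learned soft bounds that soft_clamp applies to the predicted log-variance.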
104 | mse_loss_inv = (torch.pow(mean - targets_batch, 2) * inv_var).mean(dim=(1, 2)) 105 | var_loss = logvar.mean(dim=(1, 2)) 106 | loss = mse_loss_inv.sum() + var_loss.sum() 107 | loss = loss + self.model.get_decay_loss() 108 | loss = loss + 0.01 * self.model.max_logvar.sum() - 0.01 * self.model.min_logvar.sum() 109 | 110 | self.optim.zero_grad() 111 | loss.backward() 112 | self.optim.step() 113 | 114 | @ torch.no_grad() 115 | def validate(self, inputs, targets): 116 | self.model.eval() 117 | targets = torch.as_tensor(targets).to(self.model.device) 118 | mean, _ = self.model(inputs) 119 | loss = ((mean - targets) ** 2).mean(dim=(1, 2)) 120 | val_loss = list(loss.cpu().numpy()) 121 | return val_loss 122 | 123 | def select_elites(self, metrics): 124 | pairs = [(metric, index) for metric, index in zip(metrics, range(len(metrics)))] 125 | pairs = sorted(pairs, key=lambda x: x[0]) 126 | elites = [pairs[i][1] for i in range(self.model.num_elites)] 127 | return elites 128 | 129 | def save(self, save_path): 130 | torch.save(self.model.state_dict(), os.path.join(save_path, "dynamics.pth")) 131 | self.scaler.save_scaler(save_path) 132 | 133 | def load(self, load_path): 134 | self.model.load_state_dict(torch.load(os.path.join(load_path, "dynamics.pth"), map_location=self.model.device)) 135 | self.scaler.load_scaler(load_path) 136 | 137 | 138 | def format_samples_for_training(samples): 139 | obs = samples["s"] 140 | act = samples["a"] 141 | next_obs = samples["s_"] 142 | rew = samples["r"] 143 | delta_obs = next_obs - obs 144 | inputs = np.concatenate((obs, act), axis=-1) 145 | targets = np.concatenate((delta_obs, rew), axis=-1) 146 | return inputs, targets 147 | -------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ReplayBuffer: 5 | """ replay buffer """ 6 | def __init__(self, buffer_size, obs_shape, action_dim): 7 | self.obs_shape = obs_shape 8 | self.action_dim = action_dim 9 | self.memory = { 10 | "s": np.zeros((buffer_size, *self.obs_shape), dtype=np.float32), 11 | "a": np.zeros((buffer_size, self.action_dim), dtype=np.float32), 12 | "r": np.zeros((buffer_size, 1), dtype=np.float32), 13 | "s_": np.zeros((buffer_size, *self.obs_shape), dtype=np.float32), 14 | "done": np.zeros((buffer_size, 1), dtype=np.float32), 15 | } 16 | 17 | self.capacity = buffer_size 18 | self.size = 0 19 | self.cnt = 0 20 | 21 | def store(self, s, a, r, s_, done, timeout): 22 | """ store transition (s, a, r, s_, done) """ 23 | done *= (1-timeout) 24 | self.memory["s"][self.cnt] = s 25 | self.memory["a"][self.cnt] = a 26 | self.memory["r"][self.cnt] = r 27 | self.memory["s_"][self.cnt] = s_ 28 | self.memory["done"][self.cnt] = done 29 | 30 | self.cnt = (self.cnt+1)%self.capacity 31 | self.size = min(self.size+1, self.capacity) 32 | 33 | def store_batch(self, s, a, r, s_, done): 34 | """ store batch transitions (s, a, r, s_, done) """ 35 | batch_size = len(s) 36 | 37 | indices = np.arange(self.cnt, self.cnt+batch_size)%self.capacity 38 | self.memory["s"][indices] = s 39 | self.memory["a"][indices] = a 40 | self.memory["r"][indices] = r 41 | self.memory["s_"][indices] = s_ 42 | self.memory["done"][indices] = done 43 | 44 | self.cnt = (self.cnt+batch_size)%self.capacity 45 | self.size = min(self.size+batch_size, self.capacity) 46 | 47 | def sample(self, batch_size): 48 | """ sample a batch of transitions """ 49 | indices = np.random.randint(0, self.size, batch_size) 50 
| return { 51 | "s": self.memory["s"][indices].copy(), 52 | "a": self.memory["a"][indices].copy(), 53 | "r": self.memory["r"][indices].copy(), 54 | "s_": self.memory["s_"][indices].copy(), 55 | "done": self.memory["done"][indices].copy() 56 | } 57 | 58 | def sample_all(self): 59 | """ sample all transitions """ 60 | indices = np.arange(self.size) 61 | return { 62 | "s": self.memory["s"][indices].copy(), 63 | "a": self.memory["a"][indices].copy(), 64 | "r": self.memory["r"][indices].copy(), 65 | "s_": self.memory["s_"][indices].copy(), 66 | "done": self.memory["done"][indices].copy() 67 | } 68 | 69 | 70 | class ReplayBufferForSeqSampling(ReplayBuffer): 71 | """ replay buffer for sequential actions sampling """ 72 | def __init__(self, buffer_size, obs_shape, action_dim, plan_length, gamma): 73 | super().__init__(buffer_size, obs_shape, action_dim) 74 | # used for mbpc-based policy 75 | self.endpoint = np.zeros(buffer_size, dtype=np.float32) # whether the step is an endpoint (end ≠ done) 76 | self.sample_sign = np.zeros(buffer_size, dtype=np.float32) # whether the step can be sampled 77 | self.sample_mask = np.zeros((buffer_size, plan_length), dtype=np.float32) 78 | self.sample_end = np.zeros(buffer_size, dtype=np.int64) 79 | self.plan_length = plan_length 80 | self.gammas = gamma**np.arange(plan_length).reshape((plan_length, 1)) 81 | 82 | def store(self, s, a, r, s_, done, timeout): 83 | self.endpoint[self.cnt] = done 84 | self.sample_sign[self.cnt] = 0 85 | self.sample_mask[self.cnt] = 0 86 | self.sample_end[self.cnt] = 0 87 | super().store(s, a, r, s_, done, timeout) 88 | 89 | if self.size >= self.plan_length: 90 | if self.endpoint[np.arange(self.cnt-self.plan_length, self.cnt-1)].sum() == 0: 91 | self.sample_sign[self.cnt-self.plan_length] = 1 92 | self.sample_mask[self.cnt-self.plan_length] = 1 93 | self.sample_end[self.cnt-self.plan_length] = self.plan_length - 1 94 | 95 | elif self.memory["done"][np.arange(self.cnt-self.plan_length, self.cnt-1)].sum() == 1: 96 | for i in range(self.plan_length-1): 97 | if self.memory["done"][self.cnt-self.plan_length+i]: 98 | self.sample_sign[self.cnt-self.plan_length] = 1 99 | self.sample_mask[self.cnt-self.plan_length, :i+1] = 1 100 | self.sample_end[self.cnt-self.plan_length] = i 101 | break 102 | 103 | def sample_nstep(self, batch_size): 104 | """ sample a batch of {plan_length}-step transitions """ 105 | all_start_indices = np.arange(self.size)[self.sample_sign[:self.size]==1] 106 | start_indices = np.random.choice(all_start_indices, batch_size) 107 | indices = (start_indices.reshape(-1, 1) + np.arange(self.plan_length))%self.size 108 | sample_mask = self.sample_mask[start_indices] 109 | sample_end = self.sample_end[start_indices] 110 | 111 | return { 112 | "s": self.memory["s"][start_indices].copy(), 113 | "a": (self.memory["a"][indices].reshape((batch_size, -1))*sample_mask.repeat(self.action_dim, axis=-1)).copy(), 114 | "r": (self.memory["r"][indices].reshape((batch_size, -1))*sample_mask).dot(self.gammas).copy(), 115 | "s_": self.memory["s_"][indices[np.arange(batch_size), sample_end]].copy(), 116 | "done": self.memory["done"][indices].sum(axis=1).clip(None, 1).copy() 117 | } 118 | 119 | def sample_all_nstep(self): 120 | """ sample all {plan_length}-step transitions """ 121 | start_indices = np.arange(self.size)[self.sample_sign[:self.size]==1] 122 | indices = (start_indices.reshape(-1, 1) + np.arange(self.plan_length))%self.size 123 | sample_mask = self.sample_mask[start_indices] 124 | sample_end = self.sample_end[start_indices] 125 | 126 | 
return { 127 | "s": self.memory["s"][start_indices].copy(), 128 | "a": (self.memory["a"][indices].reshape((indices.shape[0], -1))*sample_mask.repeat(self.action_dim, axis=-1)).copy(), 129 | "r": (self.memory["r"][indices].reshape((indices.shape[0], -1))*sample_mask).dot(self.gammas).copy(), 130 | "s_": self.memory["s_"][indices[np.arange(indices.shape[0]), sample_end]].copy(), 131 | "done": self.memory["done"][indices].sum(axis=1).clip(None, 1).copy() 132 | } 133 | 134 | def sample_nstep4rollout(self, batch_size): 135 | """ sample a batch of {plan_length-1}-step transitions for rollout """ 136 | all_start_indices = np.arange(self.size)[self.sample_end[:self.size]==self.plan_length-1] 137 | start_indices = np.random.choice(all_start_indices, batch_size) 138 | indices = (start_indices.reshape(-1, 1) + np.arange(self.plan_length-1))%self.size 139 | 140 | return { 141 | "s": self.memory["s"][indices].reshape((batch_size, -1)).copy(), 142 | "a": self.memory["a"][indices].reshape((batch_size, -1)).copy(), 143 | "r": self.memory["r"][indices].reshape((batch_size, -1)).copy(), 144 | "s_": self.memory["s_"][indices].reshape((batch_size, -1)).copy(), 145 | "done": self.memory["done"][indices].reshape((batch_size, -1)).copy() 146 | } 147 | -------------------------------------------------------------------------------- /trainer/mppve_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from mppve import MPPVEAgent 5 | from .base_trainer import BASETrainer 6 | from buffer import ReplayBuffer, ReplayBufferForSeqSampling 7 | 8 | from components.dynamics import Dynamics 9 | from components.static_fns import STATICFUNC 10 | from components.dynamics_model import EnsembleDynamicsModel 11 | 12 | 13 | class MPPVETrainer(BASETrainer): 14 | """ model-based planning policy learning with multi-step plan-value estimation """ 15 | def __init__(self, args): 16 | super().__init__(args) 17 | 18 | # init dynamics 19 | dynamics_model = EnsembleDynamicsModel( 20 | obs_dim=int(np.prod(args.obs_shape)), 21 | action_dim=args.action_dim, 22 | hidden_dims=args.dynamics_hidden_dims, 23 | num_ensemble=args.n_ensembles, 24 | num_elites=args.n_elites, 25 | weight_decays=args.dynamics_weight_decay, 26 | load_model=False, 27 | device=args.device 28 | ) 29 | task = args.env_name.split('-')[0] 30 | static_fns = STATICFUNC[task] 31 | dynamics = Dynamics(dynamics_model, static_fns) 32 | 33 | self.model_update_interval = args.model_update_interval 34 | self.actor_freq = args.actor_freq 35 | 36 | # init mppve-agent 37 | self.agent = MPPVEAgent( 38 | obs_shape=args.obs_shape, 39 | hidden_dims=args.ac_hidden_dims, 40 | action_dim=args.action_dim, 41 | action_space=self.env.action_space, 42 | dynamics=dynamics, 43 | plan_length=args.plan_length, 44 | actor_lr=args.actor_lr, 45 | critic_lr=args.critic_lr, 46 | batch_size=args.batch_size, 47 | tau=args.tau, 48 | gamma=args.gamma, 49 | alpha=args.alpha, 50 | auto_alpha=args.auto_alpha, 51 | alpha_lr=args.alpha_lr, 52 | target_entropy=args.target_entropy, 53 | device=args.device 54 | ) 55 | self.agent.train() 56 | 57 | # planning actions queue 58 | self.plan_length = args.plan_length 59 | self.plan_actions = [] 60 | 61 | # init replay buffer 62 | self.memory = ReplayBufferForSeqSampling( 63 | args.buffer_size, args.obs_shape, args.action_dim, args.plan_length, args.gamma) 64 | 65 | # create memory to store imaginary transitions 66 | model_rollout_size = 
args.rollout_batch_size*args.rollout_schedule[2] 67 | model_buffer_size = int(model_rollout_size*args.model_retain_steps/args.model_update_interval) 68 | self.model_memory = ReplayBuffer( 69 | buffer_size=model_buffer_size, 70 | obs_shape=args.obs_shape, 71 | action_dim=args.action_dim*self.plan_length 72 | ) 73 | 74 | # func 4 calculate new rollout length (x->y over steps a->b) 75 | a, b, x, y = args.rollout_schedule 76 | self.make_rollout_len = lambda it: int(min(max(x+(it-a)/(b-a)*(y-x), x), y)) 77 | # func 4 calculate new model buffer size 78 | self.make_model_buffer_size = lambda it: \ 79 | int(args.rollout_batch_size*self.make_rollout_len(it) * \ 80 | args.model_retain_steps/args.model_update_interval) 81 | 82 | # other parameters 83 | self.model_update_interval = args.model_update_interval 84 | self.rollout_batch_size = args.rollout_batch_size 85 | self.real_ratio = args.real_ratio 86 | self.updates_per_step = args.updates_per_step 87 | 88 | def train(self): 89 | """ train {args.algo} on {args.env} for {args.n_steps} steps""" 90 | 91 | # init 92 | obs = self._warm_up() 93 | 94 | pbar = tqdm(range(self.n_steps), desc="Training {} on {}.{} (seed: {})".format( 95 | self.args.algo.upper(), self.args.env.title(), self.args.env_name, self.seed)) 96 | 97 | for it in pbar: 98 | # update (one-step) dynamics model 99 | if it%self.model_update_interval == 0: 100 | transitions = self.memory.sample_all() 101 | model_loss = self.agent.learn_dynamics(transitions) 102 | self.agent.valid_plan_length = self.make_rollout_len(it) 103 | 104 | # update imaginary memory 105 | new_model_buffer_size = self.make_model_buffer_size(it) 106 | if self.model_memory.capacity != new_model_buffer_size: 107 | new_buffer = ReplayBuffer( 108 | buffer_size=new_model_buffer_size, 109 | obs_shape=self.model_memory.obs_shape, 110 | action_dim=self.model_memory.action_dim 111 | ) 112 | old_transitions = self.model_memory.sample_all() 113 | new_buffer.store_batch(**old_transitions) 114 | self.model_memory = new_buffer 115 | 116 | # rollout 117 | init_transitions = self.memory.sample_nstep4rollout(self.rollout_batch_size) 118 | rollout_len = self.make_rollout_len(it) 119 | fake_transitions = self.agent.rollout_transitions(init_transitions, rollout_len) 120 | self.model_memory.store_batch(**fake_transitions) 121 | 122 | self.logger.log(f"rollout length: {rollout_len},"+ 123 | f"model buffer capacity: {new_model_buffer_size},"+ 124 | f"model buffer size: {self.model_memory.size}") 125 | 126 | # step in env 127 | action = self.agent.act(obs) 128 | next_obs, reward, done, info = self.env.step(action) 129 | timeout = info.get("TimeLimit.truncated", False) 130 | self.memory.store(obs, action, reward, next_obs, done, timeout) 131 | 132 | obs = next_obs 133 | if done: obs = self.env.reset(); self.plan_actions = [] 134 | 135 | # render 136 | if self.render: self.env.render() 137 | 138 | # update policy 139 | if it%self.update_interval == 0: 140 | real_states = [] 141 | update_num = int(self.update_interval*self.updates_per_step) 142 | update_cnt = 0 143 | for _ in range(update_num): 144 | # sample transitions 145 | real_sample_size = int(self.batch_size*self.real_ratio) 146 | fake_sample_size = self.batch_size - real_sample_size 147 | real_batch = self.memory.sample_nstep(batch_size=real_sample_size) 148 | fake_batch = self.model_memory.sample(batch_size=fake_sample_size) 149 | transitions = {key: np.concatenate( 150 | (real_batch[key], fake_batch[key]), axis=0) for key in real_batch.keys()} 151 | 152 | 
real_states.append(real_batch["s"]) 153 | 154 | # update 155 | critic_learning_info = self.agent.learn_critic(**transitions) 156 | critic_loss = critic_learning_info["critic_loss"] 157 | update_cnt += 1 158 | 159 | if update_cnt % self.actor_freq == 0: 160 | real_states = np.concatenate(real_states, axis=0) 161 | actor_learning_info = self.agent.learn_actor(real_states) 162 | actor_loss = actor_learning_info["actor_loss"] 163 | alpha = actor_learning_info["alpha"] 164 | real_states = [] 165 | 166 | # evaluate policy 167 | if it%self.eval_interval == 0: 168 | episode_rewards = np.mean(self._eval_policy()) 169 | self.logger.logkv("loss/model", model_loss) 170 | self.logger.logkv("loss/actor", actor_loss) 171 | self.logger.logkv("loss/critic", critic_loss) 172 | self.logger.logkv("alpha", alpha) 173 | self.logger.logkv("eval/episode_rewards", episode_rewards) 174 | 175 | value_bias_info = self._eval_value_estimation() 176 | self.logger.logkv("eval/value_bias_mean", value_bias_info["value_bias_mean"]) 177 | self.logger.logkv("eval/value_bias_std", value_bias_info["value_bias_std"]) 178 | 179 | self.logger.set_timestep(it) 180 | self.logger.dumpkvs() 181 | 182 | pbar.set_postfix( 183 | alpha=alpha, 184 | model_loss=model_loss, 185 | actor_loss=actor_loss, 186 | critic_loss=critic_loss, 187 | eval_reward=episode_rewards 188 | ) 189 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pprint 5 | import argparse 6 | import datetime 7 | import warnings 8 | import numpy as np 9 | 10 | from collections import defaultdict, deque 11 | from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union 12 | from numbers import Number 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | DEBUG = 10 17 | INFO = 20 18 | WARN = 30 19 | ERROR = 40 20 | BACKUP = 60 21 | 22 | DEFAULT_X_NAME = "timestep" 23 | ROOT_DIR = "log" 24 | 25 | 26 | class KVWriter(object): 27 | """ 28 | Key Value writer 29 | """ 30 | def writekvs(self, kvs: Dict) -> None: 31 | """ 32 | write a dictionary to file 33 | """ 34 | raise NotImplementedError 35 | 36 | 37 | class StrWriter(object): 38 | """ 39 | string writer 40 | """ 41 | def writestr(self, s: str) -> None: 42 | """ 43 | write a string to file 44 | """ 45 | raise NotImplementedError 46 | 47 | 48 | class StandardOutputHandler(KVWriter, StrWriter): 49 | def __init__(self, filename_or_textio: Union[str, TextIO]) -> None: 50 | """ 51 | log to a file, in a human readable format 52 | 53 | :param filename_or_textio: (str or TextIO) the file to write the log to 54 | """ 55 | if isinstance(filename_or_textio, str): 56 | self.file = open(filename_or_textio+".txt", 'at') 57 | self.own_file = True 58 | self.handler_name = os.path.basename(filename_or_textio) 59 | else: 60 | assert hasattr(filename_or_textio, 'write'), 'Expected file or str, got {}'.format(filename_or_textio) 61 | self.file = filename_or_textio 62 | self.own_file = False 63 | self.handler_name = "stdio" 64 | super().__init__() 65 | 66 | def writekvs(self, kvs: Dict) -> None: 67 | # Create strings for printing 68 | key2str = {} 69 | for (key, val) in sorted(kvs.items()): 70 | if isinstance(val, float): 71 | valstr = '%-8.3g' % (val,) 72 | else: 73 | valstr = str(val) 74 | key2str[self._truncate(key)] = self._truncate(valstr) 75 | 76 | # Find max widths 77 | if len(key2str) == 0: 78 | warnings.warn('Tried to
write empty key-value dict') 79 | return 80 | else: 81 | keywidth = max(map(len, key2str.keys())) 82 | valwidth = max(map(len, key2str.values())) 83 | 84 | # Write out the data 85 | dashes = '-' * (keywidth + valwidth + 40) 86 | lines = [dashes] 87 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 88 | lines.append('| %s%s | %s%s |' % ( 89 | key, 90 | ' ' * (keywidth - len(key)), 91 | val, 92 | ' ' * (valwidth - len(val)), 93 | )) 94 | lines.append(dashes) 95 | self.file.write('\n'.join(lines) + '\n') 96 | 97 | # Flush the output to the file 98 | self.file.flush() 99 | 100 | def _truncate(self, s: str) -> str: 101 | return s[:40] + '...' if len(s) > 80 else s 102 | 103 | def writestr(self, s: str) -> None: 104 | self.file.write(s) 105 | self.file.write('\n') 106 | self.file.flush() 107 | 108 | def close(self) -> None: 109 | """ 110 | closes the file 111 | """ 112 | if self.own_file: 113 | self.file.close() 114 | 115 | 116 | class JSONOutputHandler(KVWriter): 117 | def __init__(self, filename: str) -> None: 118 | """ 119 | log to a file in the JSON format 120 | """ 121 | self.file = open(filename+".json", 'at') 122 | self.handler_name = os.path.basename(filename) 123 | super().__init__() 124 | 125 | def writekvs(self, kvs: Dict) -> None: 126 | for key, value in sorted(kvs.items()): 127 | if hasattr(value, 'dtype'): 128 | if value.shape == () or len(value) == 1: 129 | # if value is a dimensionless numpy array or of length 1, serialize as a float 130 | kvs[key] = float(value) 131 | else: 132 | # otherwise, a value is a numpy array, serialize as a list or nested lists 133 | kvs[key] = value.tolist() 134 | self.file.write(json.dumps(kvs) + '\n') 135 | self.file.flush() 136 | 137 | def close(self) -> None: 138 | """ 139 | closes the file 140 | """ 141 | self.file.close() 142 | 143 | 144 | class CSVOutputHandler(KVWriter): 145 | def __init__(self, filename: str) -> None: 146 | """ 147 | log to a file in the CSV format 148 | """ 149 | filename += ".csv" 150 | self.filename = filename 151 | self.file = open(filename, 'a+t') 152 | self.file.seek(0) 153 | keys = self.file.readline() 154 | if keys != '': 155 | keys = keys[:-1] # skip '\n' 156 | keys = keys.split(',') 157 | self.keys = keys 158 | else: 159 | self.keys = [] 160 | self.file = open(filename, 'a+t') 161 | self.sep = ',' 162 | self.handler_name = os.path.splitext(os.path.basename(filename))[0] 163 | super().__init__() 164 | 165 | def writekvs(self, kvs: Dict) -> None: 166 | # Add our current row to the history 167 | extra_keys = list(kvs.keys() - self.keys) 168 | extra_keys.sort() 169 | if extra_keys: 170 | self.keys.extend(extra_keys) 171 | self.file.seek(0) 172 | lines = self.file.readlines() 173 | self.file = open(self.filename, 'w+t') 174 | self.file.seek(0) 175 | for (i, key) in enumerate(self.keys): 176 | if i > 0: 177 | self.file.write(',') 178 | self.file.write(key) 179 | self.file.write('\n') 180 | for line in lines[1:]: 181 | self.file.write(line[:-1]) 182 | self.file.write(self.sep * len(extra_keys)) 183 | self.file.write('\n') 184 | self.file = open(self.filename, 'a+t') 185 | for i, key in enumerate(self.keys): 186 | if i > 0: 187 | self.file.write(',') 188 | value = kvs.get(key) 189 | if value is not None: 190 | self.file.write(str(value)) 191 | self.file.write('\n') 192 | self.file.flush() 193 | 194 | def close(self) -> None: 195 | """ 196 | closes the file 197 | """ 198 | self.file.close() 199 | 200 | 201 | class TensorBoardOutputHandler(KVWriter): 202 | """ 203 | Dumps key/value pairs into 
TensorBoard's numeric format. 204 | """ 205 | def __init__(self, filename: str) -> None: 206 | self.step = 1 207 | self.tb_writer = SummaryWriter(filename) 208 | self.handler_name = os.path.basename(filename) 209 | super().__init__() 210 | 211 | @property 212 | def writer(self) -> SummaryWriter: 213 | return self.tb_writer 214 | 215 | def add_hyper_params_to_tb(self, hyper_param: Dict, metric_dict=None) -> None: 216 | if metric_dict is None: 217 | pp = pprint.PrettyPrinter(indent=4) 218 | self.writer.add_text('hyperparameters', pp.pformat(hyper_param)) 219 | else: 220 | self.writer.add_hparams(hyper_param, metric_dict) 221 | 222 | def writekvs(self, kvs: Dict) -> None: 223 | def summary_val(k, v): 224 | kwargs = {'tag': k, 'scalar_value': float(v), 'global_step': self.step} 225 | self.writer.add_scalar(**kwargs) 226 | 227 | for k, v in kvs.items(): 228 | if k == DEFAULT_X_NAME: continue 229 | summary_val(k, v) 230 | 231 | def set_step(self, step: int) -> None: 232 | self.step = step 233 | 234 | def close(self) -> None: 235 | if self.writer: 236 | self.writer.close() 237 | 238 | 239 | HANDLER = { 240 | "stdout": StandardOutputHandler, 241 | "csv": CSVOutputHandler, 242 | "tensorboard": TensorBoardOutputHandler 243 | } 244 | 245 | 246 | class Logger(object): 247 | def __init__(self, dir: str, ouput_config: Dict) -> None: 248 | self._dir = dir 249 | self._init_dirs() 250 | self._init_ouput_handlers(ouput_config) 251 | self._name2val = defaultdict(float) 252 | self._name2cnt = defaultdict(int) 253 | self._level = INFO 254 | self._timestep = 0 255 | 256 | def _init_dirs(self) -> None: 257 | self._record_dir = os.path.join(self._dir, "record") 258 | self._checkpoint_dir = os.path.join(self._dir, "checkpoint") 259 | self._model_dir = os.path.join(self._dir, "model") 260 | self._result_dir = os.path.join(self._dir, "result") 261 | os.mkdir(self._record_dir) 262 | os.mkdir(self._checkpoint_dir) 263 | os.mkdir(self._model_dir) 264 | os.mkdir(self._result_dir) 265 | 266 | def _init_ouput_handlers(self, output_config: Dict) -> None: 267 | self._output_handlers = [] 268 | for file_name, fmt in output_config.items(): 269 | try: 270 | self._output_handlers.append(HANDLER[fmt](os.path.join(self._record_dir, file_name))) 271 | except KeyError: 272 | warnings.warn("Invalid output type, Valid types: stdout, csv, tensorboard", DeprecationWarning) 273 | # default output to console 274 | self._output_handlers.append(StandardOutputHandler(sys.stdout)) 275 | 276 | def log_hyperparameters(self, hyper_param: Dict) -> None: 277 | json_output_handler = JSONOutputHandler(os.path.join(self._record_dir, "hyper_param")) 278 | json_output_handler.writekvs(hyper_param) 279 | json_output_handler.close() 280 | for handler in self._output_handlers: 281 | if isinstance(handler, TensorBoardOutputHandler): 282 | handler.add_hyper_params_to_tb(hyper_param) 283 | 284 | def logkv(self, key: Any, val: Any) -> None: 285 | """ 286 | Log a value of some diagnostic 287 | Call this once for each diagnostic quantity, each iteration 288 | If called many times, last value will be used. 289 | """ 290 | self._name2val[key] = val 291 | 292 | def logkv_mean(self, key: Any, val: Number) -> None: 293 | """ 294 | The same as logkv(), but if called many times, values averaged. 
295 | """ 296 | oldval, cnt = self._name2val[key], self._name2cnt[key] 297 | self._name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1) 298 | self._name2cnt[key] = cnt + 1 299 | 300 | def dumpkvs(self, exclude: Optional[Union[str, Tuple[str, ...]]]=None) -> None: 301 | # log timestep 302 | self.logkv(DEFAULT_X_NAME, self._timestep) 303 | for handler in self._output_handlers: 304 | if isinstance(handler, KVWriter): 305 | if exclude is not None and handler.handler_name in exclude: 306 | continue 307 | handler.writekvs(self._name2val) 308 | self._name2val.clear() 309 | self._name2cnt.clear() 310 | 311 | def log(self, s: str, level=INFO) -> None: 312 | for handler in self._output_handlers: 313 | if isinstance(handler, StandardOutputHandler): 314 | handler.writestr(s) 315 | 316 | def set_timestep(self, timestep: int) -> None: 317 | self._timestep = timestep 318 | for handler in self._output_handlers: 319 | if isinstance(handler, TensorBoardOutputHandler): 320 | handler.set_step(timestep) 321 | 322 | def set_level(self, level) -> None: 323 | self._level = level 324 | 325 | @property 326 | def record_dir(self) -> str: 327 | return self._record_dir 328 | 329 | @property 330 | def checkpoint_dir(self) -> str: 331 | return self._checkpoint_dir 332 | 333 | @property 334 | def model_dir(self) -> str: 335 | return self._model_dir 336 | 337 | @property 338 | def result_dir(self) -> str: 339 | return self._result_dir 340 | 341 | def close(self) -> None: 342 | for handler in self._output_handlers: 343 | handler.close() 344 | 345 | 346 | def make_log_dirs( 347 | task_name: str, 348 | algo_name: str, 349 | seed: int, 350 | args: Dict, 351 | record_params: Optional[List]=None 352 | ) -> str: 353 | if record_params is not None: 354 | for param_name in record_params: 355 | algo_name += f"&{param_name}={args[param_name]}" 356 | timestamp = datetime.datetime.now().strftime("%y-%m%d-%H%M%S") 357 | exp_name = f"seed_{seed}&timestamp_{timestamp}" 358 | log_dirs = os.path.join(ROOT_DIR, task_name, algo_name, exp_name) 359 | os.makedirs(log_dirs) 360 | return log_dirs 361 | 362 | 363 | def load_args(load_path: str) -> argparse.Namespace: 364 | args_dict = {} 365 | with open(load_path, 'r') as f: 366 | args_dict.update(json.load(f)) 367 | return argparse.Namespace(**args_dict) 368 | -------------------------------------------------------------------------------- /mppve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from components import ACTOR, CRITIC 7 | from components.dynamics import format_samples_for_training 8 | 9 | 10 | class MPPVEAgent: 11 | """ Learning Model-based Predictive Policy with Sequence-value Estimation """ 12 | def __init__( 13 | self, 14 | obs_shape, 15 | hidden_dims, 16 | action_dim, 17 | action_space, 18 | dynamics, 19 | plan_length, 20 | actor_lr, 21 | critic_lr, 22 | batch_size, 23 | tau=0.005, 24 | gamma=0.99, 25 | alpha=0.2, 26 | auto_alpha=True, 27 | alpha_lr=3e-4, 28 | target_entropy=-1, 29 | device="cpu" 30 | ): 31 | # actor 32 | self.actor = ACTOR["prob"](obs_shape, hidden_dims, action_dim, device) 33 | 34 | # critic 35 | self.valid_plan_length = 1 36 | self.plan_length = plan_length 37 | self.critic1 = CRITIC["q"](obs_shape, hidden_dims, action_dim*self.plan_length, device) 38 | self.critic2 = CRITIC["q"](obs_shape, hidden_dims, action_dim*self.plan_length, device) 39 | # target critic 40 | self.critic1_trgt = deepcopy(self.critic1) 41 | self.critic2_trgt =
deepcopy(self.critic2) 42 | self.critic1_trgt.eval() 43 | self.critic2_trgt.eval() 44 | 45 | # optimizer 46 | self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr) 47 | self.critic1_optim = torch.optim.Adam(self.critic1.parameters(), lr=critic_lr) 48 | self.critic2_optim = torch.optim.Adam(self.critic2.parameters(), lr=critic_lr) 49 | 50 | # env space 51 | self.obs_dim = np.prod(obs_shape) 52 | self.action_dim = action_dim 53 | self.action_space = action_space 54 | 55 | # alpha: weight of entropy 56 | self._auto_alpha = auto_alpha 57 | if self._auto_alpha: 58 | if not target_entropy: 59 | target_entropy = -np.prod(self.action_space.shape)*self.plan_length 60 | self._target_entropy = target_entropy*self.plan_length 61 | self._log_alpha = torch.zeros(1, requires_grad=True, device=device) 62 | self._alpha = self._log_alpha.detach().exp() 63 | self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=alpha_lr) 64 | else: 65 | self._alpha = alpha 66 | 67 | # dynamics model 68 | self.dynamics = dynamics 69 | 70 | # other parameters 71 | self._tau = tau 72 | self._gamma = gamma 73 | self._eps = np.finfo(np.float32).eps.item() 74 | self.batch_size = batch_size 75 | self.device = device 76 | 77 | def train(self): 78 | self.actor.train() 79 | self.critic1.train() 80 | self.critic2.train() 81 | 82 | def eval(self): 83 | self.actor.eval() 84 | self.critic1.eval() 85 | self.critic2.eval() 86 | 87 | def _sync_weight(self): 88 | """ synchronize weight """ 89 | for trgt, src in zip(self.critic1_trgt.parameters(), self.critic1.parameters()): 90 | trgt.data.copy_(trgt.data*(1.0-self._tau) + src.data*self._tau) 91 | for trgt, src in zip(self.critic2_trgt.parameters(), self.critic2.parameters()): 92 | trgt.data.copy_(trgt.data*(1.0-self._tau) + src.data*self._tau) 93 | 94 | def actor4ward(self, obs, deterministic=False): 95 | """ forward propagation of actor """ 96 | dist = self.actor(obs) 97 | if deterministic: 98 | action = dist.mode() 99 | else: 100 | action = dist.rsample() 101 | log_prob = dist.log_prob(action) 102 | 103 | action_scale = torch.tensor((self.action_space.high-self.action_space.low)/2, device=self.device) 104 | squashed_action = torch.tanh(action) 105 | log_prob = log_prob - torch.log(action_scale*(1-squashed_action.pow(2))+self._eps).sum(-1, keepdim=True) 106 | 107 | return action_scale*squashed_action, log_prob 108 | 109 | def actor4ward_plan(self, obs, deterministic=False): 110 | """ forward propagation of actor (planning version) """ 111 | bs = obs.size(0) 112 | mask = torch.ones((bs, 1), device=self.device) 113 | plan_actions = torch.zeros((bs, self.plan_length, self.action_dim), device=self.device) 114 | plan_log_prob = torch.zeros((bs, self.plan_length, 1), device=self.device) 115 | obs = obs.cpu().detach().numpy() 116 | 117 | for t in range(self.plan_length): 118 | # plan one step 119 | obs_torch = torch.as_tensor(obs, dtype=torch.float32, device=self.device) 120 | action, log_prob = self.actor4ward(obs_torch, deterministic) 121 | plan_actions[:, t] = action*mask.clone().expand(-1, self.action_dim) 122 | plan_log_prob[:, t] = log_prob*mask.clone() 123 | if t > self.valid_plan_length: 124 | plan_actions[:, t] = plan_actions[:, t].detach() 125 | plan_log_prob[:, t] = plan_log_prob[:, t].detach() 126 | 127 | # imaginary step 128 | obs, _, done, _ = self.dynamics.step(obs, action.cpu().detach().numpy()) 129 | mask[done.flatten()==1] = 0 130 | 131 | plan_actions = plan_actions.view((bs, -1)) 132 | plan_log_prob = plan_log_prob.sum(1) 133 | return plan_actions, 
plan_log_prob 134 | 135 | def act(self, obs, deterministic=False, return_logprob=False): 136 | """ sample action """ 137 | with torch.no_grad(): 138 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device) 139 | action, log_prob = self.actor4ward(obs, deterministic) 140 | action = action.cpu().detach().numpy() 141 | log_prob = log_prob.cpu().detach().numpy() 142 | 143 | if return_logprob: 144 | return action, log_prob 145 | else: 146 | return action 147 | 148 | def value(self, obs, action): 149 | with torch.no_grad(): 150 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device) 151 | if len(obs.shape) == 1: 152 | obs = obs.reshape(1, -1) 153 | na, log_prob = self.actor4ward_plan(obs) 154 | q1, q2 = self.critic1(obs, na), self.critic2(obs, na) 155 | value = torch.cat((q1, q2), dim=-1).mean(1).flatten().cpu().numpy() 156 | return value 157 | 158 | def plan(self, obs, deterministic=False): 159 | """ planning """ 160 | plan_actions = [] 161 | 162 | for _ in range(self.plan_length): 163 | with torch.no_grad(): 164 | action = self.act(obs, deterministic) 165 | plan_actions.append(action) 166 | obs, _, _, _ = self.dynamics.step(obs, action) 167 | return plan_actions 168 | 169 | def rollout_transitions(self, init_transitions, rollout_len): 170 | if not hasattr(self, "gammas"): 171 | self.gammas = self._gamma**np.arange(self.plan_length).reshape((self.plan_length, 1)) 172 | 173 | """ rollout to generate {plan_length}-steps transitions """ 174 | obs = init_transitions["s_"][:, -self.obs_dim:] 175 | transitions = init_transitions 176 | 177 | nstep_transitions = {"s": [], "a": [], "r": [], "s_": [], "done": []} 178 | for _ in range(rollout_len): 179 | # imaginary step 180 | actions = self.act(obs) 181 | next_obs, rewards, dones, _ = self.dynamics.step(obs, actions) 182 | 183 | # update 184 | transitions["s"] = np.concatenate((transitions["s"], obs), axis=-1) 185 | transitions["a"] = np.concatenate((transitions["a"], actions), axis=-1) 186 | transitions["r"] = np.concatenate((transitions["r"], rewards), axis=-1) 187 | transitions["s_"] = np.concatenate((transitions["s_"], next_obs), axis=-1) 188 | transitions["done"] = np.concatenate((transitions["done"], dones), axis=-1) 189 | 190 | # store 191 | nstep_transitions["s"].append(transitions["s"][:, :self.obs_dim]) 192 | nstep_transitions["a"].append(transitions["a"]) 193 | nstep_transitions["r"].append(transitions["r"].dot(self.gammas)) 194 | nstep_transitions["s_"].append(transitions["s_"][:, -self.obs_dim:]) 195 | nstep_transitions["done"].append(transitions["done"].sum(-1, keepdims=True).clip(None, 1)) 196 | 197 | # to next step 198 | nonterm_mask = (~dones).flatten() 199 | if nonterm_mask.sum() == 0: break 200 | obs = next_obs[nonterm_mask] 201 | 202 | # mask 203 | transitions["s"] = transitions["s"][nonterm_mask, self.obs_dim:] 204 | transitions["a"] = transitions["a"][nonterm_mask, self.action_dim:] 205 | transitions["r"] = transitions["r"][nonterm_mask, 1:] 206 | transitions["s_"] = transitions["s_"][nonterm_mask, self.obs_dim:] 207 | transitions["done"] = transitions["done"][nonterm_mask, 1:] 208 | 209 | nstep_transitions = {key: np.concatenate(nstep_transitions[key], axis=0) for key in nstep_transitions.keys()} 210 | return nstep_transitions 211 | 212 | def learn_dynamics(self, transitions): 213 | """ learn dynamics model """ 214 | inputs, targets = format_samples_for_training(transitions) 215 | loss = self.dynamics.train( 216 | inputs, 217 | targets, 218 | batch_size=self.batch_size 219 | ) 220 | return 
loss["holdout_loss"].item() 221 | 222 | def learn_actor(self, s): 223 | """ learn predictive policy from {plan_length}-steps transitions """ 224 | s = torch.as_tensor(s, device=self.device) 225 | # update actor 226 | na, log_prob = self.actor4ward_plan(s) 227 | q1, q2 = self.critic1(s, na).flatten(), self.critic2(s, na).flatten() 228 | actor_loss = (self._alpha*log_prob.flatten() - torch.min(q1, q2)).mean() 229 | self.actor_optim.zero_grad() 230 | actor_loss.backward() 231 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 10) 232 | self.actor_optim.step() 233 | actor_loss = actor_loss.item() 234 | 235 | # update alpha 236 | if self._auto_alpha: 237 | log_prob = log_prob.detach() + self._target_entropy 238 | alpha_loss = -(self._log_alpha*log_prob).mean() 239 | self._alpha_optim.zero_grad() 240 | alpha_loss.backward() 241 | self._alpha_optim.step() 242 | alpha_loss = alpha_loss.item() 243 | self._alpha = self._log_alpha.detach().exp() 244 | 245 | 246 | info = { 247 | "actor_loss": actor_loss 248 | } 249 | 250 | if self._auto_alpha: 251 | info["alpha_loss"] = alpha_loss 252 | info["alpha"] = self._alpha.item() 253 | else: 254 | info["alpha"] = self._alpha.item() 255 | 256 | return info 257 | 258 | def learn_critic(self, s, a, r, s_, done): 259 | """ learn predictive policy from {plan_length}-steps transitions """ 260 | s = torch.as_tensor(s, device=self.device) 261 | na = torch.as_tensor(a, device=self.device) 262 | nr = torch.as_tensor(r, device=self.device) 263 | nth_s_ = torch.as_tensor(s_, device=self.device) 264 | done = torch.as_tensor(done, device=self.device) 265 | 266 | # update critic 267 | q1, q2 = self.critic1(s, na), self.critic2(s, na) 268 | with torch.no_grad(): 269 | na_, log_prob_ = self.actor4ward_plan(nth_s_) 270 | nth_q_ = torch.min( 271 | self.critic1_trgt(nth_s_, na_), 272 | self.critic2_trgt(nth_s_, na_)) - self._alpha*log_prob_ 273 | q_trgt = nr + self._gamma**self.plan_length*(1-done)*nth_q_ 274 | 275 | critic1_loss = ((q1-q_trgt).pow(2)).mean() 276 | self.critic1_optim.zero_grad() 277 | critic1_loss.backward() 278 | self.critic1_optim.step() 279 | 280 | critic2_loss = ((q2-q_trgt).pow(2)).mean() 281 | self.critic2_optim.zero_grad() 282 | critic2_loss.backward() 283 | self.critic2_optim.step() 284 | 285 | # synchronize weight 286 | self._sync_weight() 287 | 288 | info = { 289 | "critic_loss": (critic1_loss.item() + critic2_loss.item()) / 2, 290 | "value": torch.cat((q1, q2), dim=-1).mean(1).mean().item() 291 | } 292 | return info 293 | 294 | def save_model(self, filepath): 295 | """ save model """ 296 | # save policy 297 | state_dict = { 298 | "actor": self.actor.state_dict(), 299 | "critic1": self.critic1.state_dict(), 300 | "critic2": self.critic2.state_dict(), 301 | "alpha": self._alpha 302 | } 303 | torch.save(state_dict, filepath) 304 | 305 | # save dynamics 306 | dynamics_dir = filepath.split(".pth")[0] 307 | if not os.path.exists(dynamics_dir): 308 | os.makedirs(dynamics_dir) 309 | self.dynamics.save(dynamics_dir) 310 | 311 | def load_model(self, filepath): 312 | """ load model """ 313 | # load policy 314 | state_dict = torch.load(filepath, map_location=torch.device(self.device)) 315 | self.actor.load_state_dict(state_dict["actor"]) 316 | self.critic1.load_state_dict(state_dict["critic1"]) 317 | self.critic2.load_state_dict(state_dict["critic2"]) 318 | self._alpha = state_dict["alpha"] 319 | --------------------------------------------------------------------------------