├── offlinerlkit ├── __init__.py ├── utils │ ├── __init__.py │ ├── noise.py │ ├── scaler.py │ ├── plotter.py │ ├── termination_fns.py │ └── load_dataset.py ├── policy │ ├── model_based │ │ ├── __init__.py │ │ ├── mopo.py │ │ ├── mobile.py │ │ └── combo.py │ ├── model_free │ │ ├── __init__.py │ │ ├── bc.py │ │ ├── td3bc.py │ │ ├── td3.py │ │ ├── sac.py │ │ ├── iql.py │ │ ├── mcq.py │ │ ├── edac.py │ │ └── cql.py │ ├── base_policy.py │ └── __init__.py ├── buffer │ ├── __init__.py │ └── buffer.py ├── policy_trainer │ ├── __init__.py │ ├── mf_policy_trainer.py │ └── mb_policy_trainer.py ├── nets │ ├── __init__.py │ ├── mlp.py │ ├── ensemble_linear.py │ ├── vae.py │ └── rnn.py ├── dynamics │ ├── __init__.py │ ├── base_dynamics.py │ ├── mujoco_oracle_dynamics.py │ ├── rnn_dynamics.py │ └── ensemble_dynamics.py └── modules │ ├── __init__.py │ ├── critic_module.py │ ├── actor_module.py │ ├── ensemble_critic_module.py │ ├── dist_module.py │ └── dynamics_module.py ├── assets └── logo.png ├── .gitignore ├── setup.py ├── LICENSE ├── run_example ├── run_bc.py ├── run_td3bc.py ├── run_mcq.py ├── plotter.py ├── run_edac.py ├── run_cql.py ├── run_iql.py └── run_mopo.py ├── tune_example └── tune_mopo.py └── README.md /offlinerlkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /offlinerlkit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /offlinerlkit/policy/model_based/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihaosun1124/OfflineRL-Kit/HEAD/assets/logo.png -------------------------------------------------------------------------------- /offlinerlkit/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | from offlinerlkit.buffer.buffer import ReplayBuffer 2 | 3 | 4 | __all__ = [ 5 | "ReplayBuffer" 6 | ] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | .idea/ 3 | **/.DS_STORE 4 | **/log 5 | **/build 6 | **/dist 7 | **/*.egg-info 8 | **/*.txt 9 | **/.vscode 10 | **/_log -------------------------------------------------------------------------------- /offlinerlkit/policy_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from offlinerlkit.policy_trainer.mf_policy_trainer import MFPolicyTrainer 2 | from offlinerlkit.policy_trainer.mb_policy_trainer import MBPolicyTrainer 3 | 4 | __all__ = [ 5 | "MFPolicyTrainer", 6 | "MBPolicyTrainer" 7 | ] -------------------------------------------------------------------------------- /offlinerlkit/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from offlinerlkit.nets.mlp import MLP 2 | from offlinerlkit.nets.vae import VAE 3 | from 
offlinerlkit.nets.ensemble_linear import EnsembleLinear 4 | from offlinerlkit.nets.rnn import RNNModel 5 | 6 | 7 | __all__ = [ 8 | "MLP", 9 | "VAE", 10 | "EnsembleLinear", 11 | "RNNModel" 12 | ] -------------------------------------------------------------------------------- /offlinerlkit/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from offlinerlkit.dynamics.base_dynamics import BaseDynamics 2 | from offlinerlkit.dynamics.ensemble_dynamics import EnsembleDynamics 3 | from offlinerlkit.dynamics.rnn_dynamics import RNNDynamics 4 | from offlinerlkit.dynamics.mujoco_oracle_dynamics import MujocoOracleDynamics 5 | 6 | 7 | __all__ = [ 8 | "BaseDynamics", 9 | "EnsembleDynamics", 10 | "RNNDynamics", 11 | "MujocoOracleDynamics" 12 | ] -------------------------------------------------------------------------------- /offlinerlkit/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from offlinerlkit.modules.actor_module import Actor, ActorProb 2 | from offlinerlkit.modules.critic_module import Critic 3 | from offlinerlkit.modules.ensemble_critic_module import EnsembleCritic 4 | from offlinerlkit.modules.dist_module import DiagGaussian, TanhDiagGaussian 5 | from offlinerlkit.modules.dynamics_module import EnsembleDynamicsModel 6 | 7 | 8 | __all__ = [ 9 | "Actor", 10 | "ActorProb", 11 | "Critic", 12 | "EnsembleCritic", 13 | "DiagGaussian", 14 | "TanhDiagGaussian", 15 | "EnsembleDynamicsModel" 16 | ] -------------------------------------------------------------------------------- /offlinerlkit/dynamics/base_dynamics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from typing import Callable, List, Tuple, Dict 6 | 7 | 8 | class BaseDynamics(object): 9 | def __init__( 10 | self, 11 | model: nn.Module, 12 | optim: torch.optim.Optimizer 13 | ) -> None: 14 | super().__init__() 15 | self.model = model 16 | self.optim = optim 17 | 18 | def step( 19 | self, 20 | obs: np.ndarray, 21 | action: np.ndarray 22 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict]: 23 | raise NotImplementedError 24 | -------------------------------------------------------------------------------- /offlinerlkit/policy/base_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from typing import Dict, Union 6 | 7 | 8 | class BasePolicy(nn.Module): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | 12 | def train() -> None: 13 | raise NotImplementedError 14 | 15 | def eval() -> None: 16 | raise NotImplementedError 17 | 18 | def select_action( 19 | self, 20 | obs: np.ndarray, 21 | deterministic: bool = False 22 | ) -> np.ndarray: 23 | raise NotImplementedError 24 | 25 | def learn(self, batch: Dict) -> Dict[str, float]: 26 | raise NotImplementedError -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='offlinerlkit', 5 | version="0.0.1", 6 | description=( 7 | 'OfflineRL-kit' 8 | ), 9 | author='Yihao Sun', 10 | author_email='sunyh@lamda.nju.edu.cn', 11 | maintainer='yihaosun1124', 12 | packages=find_packages(), 13 | platforms=["all"], 14 | install_requires=[ 15 | "gym>=0.15.4,<=0.24.1", 16 | "matplotlib", 
17 | "numpy", 18 | "pandas", 19 | # "ray==1.13.0", 20 | "torch", 21 | "tensorboard", 22 | "tqdm", 23 | ] 24 | ) 25 | -------------------------------------------------------------------------------- /offlinerlkit/modules/critic_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Union, Optional 6 | 7 | 8 | class Critic(nn.Module): 9 | def __init__(self, backbone: nn.Module, device: str = "cpu") -> None: 10 | super().__init__() 11 | 12 | self.device = torch.device(device) 13 | self.backbone = backbone.to(device) 14 | latent_dim = getattr(backbone, "output_dim") 15 | self.last = nn.Linear(latent_dim, 1).to(device) 16 | 17 | def forward( 18 | self, 19 | obs: Union[np.ndarray, torch.Tensor], 20 | actions: Optional[Union[np.ndarray, torch.Tensor]] = None 21 | ) -> torch.Tensor: 22 | obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) 23 | if actions is not None: 24 | actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32).flatten(1) 25 | obs = torch.cat([obs, actions], dim=1) 26 | logits = self.backbone(obs) 27 | values = self.last(logits) 28 | return values -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yi-hao Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /offlinerlkit/dynamics/mujoco_oracle_dynamics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym.envs.mujoco import mujoco_env 4 | from typing import Callable, List, Tuple, Dict 5 | 6 | 7 | class MujocoOracleDynamics(object): 8 | def __init__(self, env: mujoco_env.MujocoEnv) -> None: 9 | self.env = env 10 | 11 | def _set_state_from_obs(self, obs:np.ndarray) -> None: 12 | self.env.reset() 13 | if len(obs) == (self.env.model.nq + self.env.model.nv - 1): 14 | xpos = np.zeros(1) 15 | obs = np.concatenate([xpos, obs]) 16 | qpos = obs[:self.env.model.nq] 17 | qvel = obs[self.env.model.nq:] 18 | self.env._elapsed_steps = 0 19 | self.env.set_state(qpos, qvel) 20 | 21 | def step( 22 | self, 23 | obs: np.ndarray, 24 | action: np.ndarray 25 | ) -> Tuple[np.ndarray, float, bool, Dict]: 26 | if (len(obs.shape) > 1) or (len(action.shape) > 1): 27 | raise ValueError 28 | self.env.reset() 29 | self._set_state_from_obs(obs) 30 | next_obs, reward, terminal, info = self.env.step(action) 31 | return next_obs, reward, terminal, info -------------------------------------------------------------------------------- /offlinerlkit/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from offlinerlkit.policy.base_policy import BasePolicy 2 | 3 | # model free 4 | from offlinerlkit.policy.model_free.bc import BCPolicy 5 | from offlinerlkit.policy.model_free.sac import SACPolicy 6 | from offlinerlkit.policy.model_free.td3 import TD3Policy 7 | from offlinerlkit.policy.model_free.cql import CQLPolicy 8 | from offlinerlkit.policy.model_free.iql import IQLPolicy 9 | from offlinerlkit.policy.model_free.mcq import MCQPolicy 10 | from offlinerlkit.policy.model_free.td3bc import TD3BCPolicy 11 | from offlinerlkit.policy.model_free.edac import EDACPolicy 12 | 13 | # model based 14 | from offlinerlkit.policy.model_based.mopo import MOPOPolicy 15 | from offlinerlkit.policy.model_based.mobile import MOBILEPolicy 16 | from offlinerlkit.policy.model_based.rambo import RAMBOPolicy 17 | from offlinerlkit.policy.model_based.combo import COMBOPolicy 18 | 19 | 20 | __all__ = [ 21 | "BasePolicy", 22 | "BCPolicy", 23 | "SACPolicy", 24 | "TD3Policy", 25 | "CQLPolicy", 26 | "IQLPolicy", 27 | "MCQPolicy", 28 | "TD3BCPolicy", 29 | "EDACPolicy", 30 | "MOPOPolicy", 31 | "MOBILEPolicy", 32 | "RAMBOPolicy", 33 | "COMBOPolicy" 34 | ] -------------------------------------------------------------------------------- /offlinerlkit/nets/mlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn import functional as F 6 | from typing import Dict, List, Union, Tuple, Optional 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim: int, 13 | hidden_dims: Union[List[int], Tuple[int]], 14 | output_dim: Optional[int] = None, 15 | activation: nn.Module = nn.ReLU, 16 | dropout_rate: Optional[float] = None 17 | ) -> None: 18 | super().__init__() 19 | hidden_dims = [input_dim] + list(hidden_dims) 20 | model = [] 21 | for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]): 22 | model += [nn.Linear(in_dim, out_dim), activation()] 23 | if dropout_rate is not None: 24 | model += [nn.Dropout(p=dropout_rate)] 25 | 26 | self.output_dim = hidden_dims[-1] 27 | if output_dim is not None: 28 | model += 
[nn.Linear(hidden_dims[-1], output_dim)] 29 | self.output_dim = output_dim 30 | self.model = nn.Sequential(*model) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | return self.model(x) -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/bc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn import functional as F 6 | from typing import Dict, Union, Tuple, Callable 7 | from offlinerlkit.policy import BasePolicy 8 | 9 | 10 | class BCPolicy(BasePolicy): 11 | 12 | def __init__( 13 | self, 14 | actor: nn.Module, 15 | actor_optim: torch.optim.Optimizer 16 | ) -> None: 17 | 18 | super().__init__() 19 | self.actor = actor 20 | self.actor_optim = actor_optim 21 | 22 | def train(self) -> None: 23 | self.actor.train() 24 | 25 | def eval(self) -> None: 26 | self.actor.eval() 27 | 28 | def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: 29 | with torch.no_grad(): 30 | action = self.actor(obs).cpu().numpy() 31 | return action 32 | 33 | def learn(self, batch: Dict) -> Dict[str, float]: 34 | obss, actions = batch["observations"], batch["actions"] 35 | 36 | a = self.actor(obss) 37 | actor_loss = ((a - actions).pow(2)).mean() 38 | self.actor_optim.zero_grad() 39 | actor_loss.backward() 40 | self.actor_optim.step() 41 | 42 | return { 43 | "loss/actor": actor_loss.item() 44 | } -------------------------------------------------------------------------------- /offlinerlkit/modules/actor_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Union, Optional 6 | 7 | 8 | # for SAC 9 | class ActorProb(nn.Module): 10 | def __init__( 11 | self, 12 | backbone: nn.Module, 13 | dist_net: nn.Module, 14 | device: str = "cpu" 15 | ) -> None: 16 | super().__init__() 17 | 18 | self.device = torch.device(device) 19 | self.backbone = backbone.to(device) 20 | self.dist_net = dist_net.to(device) 21 | 22 | def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.distributions.Normal: 23 | obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) 24 | logits = self.backbone(obs) 25 | dist = self.dist_net(logits) 26 | return dist 27 | 28 | 29 | # for TD3 30 | class Actor(nn.Module): 31 | def __init__( 32 | self, 33 | backbone: nn.Module, 34 | action_dim: int, 35 | max_action: float = 1.0, 36 | device: str = "cpu" 37 | ) -> None: 38 | super().__init__() 39 | 40 | self.device = torch.device(device) 41 | self.backbone = backbone.to(device) 42 | latent_dim = getattr(backbone, "output_dim") 43 | output_dim = action_dim 44 | self.last = nn.Linear(latent_dim, output_dim).to(device) 45 | self._max = max_action 46 | 47 | def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: 48 | obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) 49 | logits = self.backbone(obs) 50 | actions = self._max * torch.tanh(self.last(logits)) 51 | return actions -------------------------------------------------------------------------------- /offlinerlkit/modules/ensemble_critic_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Union, Optional, List, 
Tuple 6 | 7 | from offlinerlkit.nets import EnsembleLinear 8 | 9 | 10 | class EnsembleCritic(nn.Module): 11 | def __init__( 12 | self, 13 | obs_dim: int, 14 | action_dim: int, 15 | hidden_dims: Union[List[int], Tuple[int]], 16 | activation: nn.Module = nn.ReLU, 17 | num_ensemble: int = 10, 18 | device: str = "cpu" 19 | ) -> None: 20 | super().__init__() 21 | input_dim = obs_dim + action_dim 22 | hidden_dims = [input_dim] + list(hidden_dims) 23 | model = [] 24 | for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]): 25 | model += [EnsembleLinear(in_dim, out_dim, num_ensemble), activation()] 26 | model.append(EnsembleLinear(hidden_dims[-1], 1, num_ensemble)) 27 | self.model = nn.Sequential(*model) 28 | 29 | self.device = torch.device(device) 30 | self.model = self.model.to(device) 31 | self._num_ensemble = num_ensemble 32 | 33 | def forward( 34 | self, 35 | obs: Union[np.ndarray, torch.Tensor], 36 | actions: Optional[Union[np.ndarray, torch.Tensor]] = None 37 | ) -> torch.Tensor: 38 | obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) 39 | if actions is not None: 40 | actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32) 41 | obs = torch.cat([obs, actions], dim=-1) 42 | values = self.model(obs) 43 | # values: [num_ensemble, batch_size, 1] 44 | return values -------------------------------------------------------------------------------- /offlinerlkit/utils/noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class GaussianNoise: 5 | """The vanilla Gaussian process, for exploration in DDPG by default.""" 6 | 7 | def __init__(self, mu=0.0, sigma=1.0): 8 | self._mu = mu 9 | assert 0 <= sigma, "Noise std should not be negative." 10 | self._sigma = sigma 11 | 12 | def __call__(self, size): 13 | return np.random.normal(self._mu, self._sigma, size) 14 | 15 | 16 | class OUNoise: 17 | """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG. 18 | 19 | Usage: 20 | :: 21 | 22 | # init 23 | self.noise = OUNoise() 24 | # generate noise 25 | noise = self.noise(logits.shape, eps) 26 | 27 | For required parameters, you can refer to the stackoverflow page. However, 28 | our experiment result shows that (similar to OpenAI SpinningUp) using 29 | vanilla Gaussian process has little difference from using the 30 | Ornstein-Uhlenbeck process. 31 | """ 32 | 33 | def __init__(self, mu=0.0, sigma=0.3, theta=0.15, dt=1e-2, x0=None): 34 | self._mu = mu 35 | self._alpha = theta * dt 36 | self._beta = sigma * np.sqrt(dt) 37 | self._x0 = x0 38 | self.reset() 39 | 40 | def reset(self) -> None: 41 | """Reset to the initial state.""" 42 | self._x = self._x0 43 | 44 | def __call__(self, size, mu=None): 45 | """Generate new noise. 46 | 47 | Return an numpy array which size is equal to ``size``. 
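If the internal state has not been initialized yet, or its shape no longer matches ``size``, it is reset to zero before the Ornstein-Uhlenbeck update is applied.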
48 | """ 49 | if self._x is None or isinstance( 50 | self._x, np.ndarray 51 | ) and self._x.shape != size: 52 | self._x = 0.0 53 | if mu is None: 54 | mu = self._mu 55 | r = self._beta * np.random.normal(size=size) 56 | self._x = self._x + self._alpha * (mu - self._x) + r 57 | return self._x # type: ignore -------------------------------------------------------------------------------- /offlinerlkit/nets/ensemble_linear.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Dict, List, Union, Tuple, Optional 6 | 7 | 8 | class EnsembleLinear(nn.Module): 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | output_dim: int, 13 | num_ensemble: int, 14 | weight_decay: float = 0.0 15 | ) -> None: 16 | super().__init__() 17 | 18 | self.num_ensemble = num_ensemble 19 | 20 | self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim))) 21 | self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim))) 22 | 23 | nn.init.trunc_normal_(self.weight, std=1/(2*input_dim**0.5)) 24 | 25 | self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone())) 26 | self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone())) 27 | 28 | self.weight_decay = weight_decay 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | weight = self.weight 32 | bias = self.bias 33 | 34 | if len(x.shape) == 2: 35 | x = torch.einsum('ij,bjk->bik', x, weight) 36 | else: 37 | x = torch.einsum('bij,bjk->bik', x, weight) 38 | 39 | x = x + bias 40 | 41 | return x 42 | 43 | def load_save(self) -> None: 44 | self.weight.data.copy_(self.saved_weight.data) 45 | self.bias.data.copy_(self.saved_bias.data) 46 | 47 | def update_save(self, indexes: List[int]) -> None: 48 | self.saved_weight.data[indexes] = self.weight.data[indexes] 49 | self.saved_bias.data[indexes] = self.bias.data[indexes] 50 | 51 | def get_decay_loss(self) -> torch.Tensor: 52 | decay_loss = self.weight_decay * (0.5*((self.weight**2).sum())) 53 | return decay_loss -------------------------------------------------------------------------------- /offlinerlkit/nets/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from typing import Dict, List, Union, Tuple, Optional 5 | 6 | 7 | # Vanilla Variational Auto-Encoder 8 | class VAE(nn.Module): 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | output_dim: int, 13 | hidden_dim: int, 14 | latent_dim: int, 15 | max_action: Union[int, float], 16 | device: str = "cpu" 17 | ) -> None: 18 | super(VAE, self).__init__() 19 | self.e1 = nn.Linear(input_dim + output_dim, hidden_dim) 20 | self.e2 = nn.Linear(hidden_dim, hidden_dim) 21 | 22 | self.mean = nn.Linear(hidden_dim, latent_dim) 23 | self.log_std = nn.Linear(hidden_dim, latent_dim) 24 | 25 | self.d1 = nn.Linear(input_dim + latent_dim, hidden_dim) 26 | self.d2 = nn.Linear(hidden_dim, hidden_dim) 27 | self.d3 = nn.Linear(hidden_dim, output_dim) 28 | 29 | self.max_action = max_action 30 | self.latent_dim = latent_dim 31 | self.device = torch.device(device) 32 | 33 | self.to(device=self.device) 34 | 35 | 36 | def forward( 37 | self, 38 | obs: torch.Tensor, 39 | action: torch.Tensor 40 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 41 | z = F.relu(self.e1(torch.cat([obs, action], 1))) 42 
| z = F.relu(self.e2(z)) 43 | 44 | mean = self.mean(z) 45 | # Clamped for numerical stability 46 | log_std = self.log_std(z).clamp(-4, 15) 47 | std = torch.exp(log_std) 48 | z = mean + std * torch.randn_like(std) 49 | 50 | u = self.decode(obs, z) 51 | 52 | return u, mean, std 53 | 54 | def decode(self, obs: torch.Tensor, z: Optional[torch.Tensor] = None) -> torch.Tensor: 55 | # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5] 56 | if z is None: 57 | z = torch.randn((obs.shape[0], self.latent_dim)).to(self.device).clamp(-0.5,0.5) 58 | 59 | a = F.relu(self.d1(torch.cat([obs, z], 1))) 60 | a = F.relu(self.d2(a)) 61 | return self.max_action * torch.tanh(self.d3(a)) -------------------------------------------------------------------------------- /offlinerlkit/utils/scaler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as path 3 | import torch 4 | 5 | 6 | class StandardScaler(object): 7 | def __init__(self, mu=None, std=None): 8 | self.mu = mu 9 | self.std = std 10 | 11 | def fit(self, data): 12 | """Runs two ops, one for assigning the mean of the data to the internal mean, and 13 | another for assigning the standard deviation of the data to the internal standard deviation. 14 | This function must be called within a 'with .as_default()' block. 15 | 16 | Arguments: 17 | data (np.ndarray): A numpy array containing the input 18 | 19 | Returns: None. 20 | """ 21 | self.mu = np.mean(data, axis=0, keepdims=True) 22 | self.std = np.std(data, axis=0, keepdims=True) 23 | self.std[self.std < 1e-12] = 1.0 24 | 25 | def transform(self, data): 26 | """Transforms the input matrix data using the parameters of this scaler. 27 | 28 | Arguments: 29 | data (np.array): A numpy array containing the points to be transformed. 30 | 31 | Returns: (np.array) The transformed dataset. 32 | """ 33 | return (data - self.mu) / self.std 34 | 35 | def inverse_transform(self, data): 36 | """Undoes the transformation performed by this scaler. 37 | 38 | Arguments: 39 | data (np.array): A numpy array containing the points to be transformed. 40 | 41 | Returns: (np.array) The transformed dataset. 
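Equivalent to ``data * std + mu``, i.e., the data is mapped back to its original, un-normalized scale.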
42 | """ 43 | return self.std * data + self.mu 44 | 45 | def save_scaler(self, save_path): 46 | mu_path = path.join(save_path, "mu.npy") 47 | std_path = path.join(save_path, "std.npy") 48 | np.save(mu_path, self.mu) 49 | np.save(std_path, self.std) 50 | 51 | def load_scaler(self, load_path): 52 | mu_path = path.join(load_path, "mu.npy") 53 | std_path = path.join(load_path, "std.npy") 54 | self.mu = np.load(mu_path) 55 | self.std = np.load(std_path) 56 | 57 | def transform_tensor(self, data: torch.Tensor): 58 | device = data.device 59 | data = self.transform(data.cpu().numpy()) 60 | data = torch.tensor(data, device=device) 61 | return data -------------------------------------------------------------------------------- /offlinerlkit/dynamics/rnn_dynamics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from typing import Callable, List, Tuple, Dict 7 | from torch.utils.data.dataloader import DataLoader 8 | from offlinerlkit.dynamics import BaseDynamics 9 | from offlinerlkit.utils.scaler import StandardScaler 10 | from offlinerlkit.utils.logger import Logger 11 | 12 | 13 | class RNNDynamics(BaseDynamics): 14 | def __init__( 15 | self, 16 | model: nn.Module, 17 | optim: torch.optim.Optimizer, 18 | scaler: StandardScaler, 19 | terminal_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray], 20 | ) -> None: 21 | super().__init__(model, optim) 22 | self.scaler = scaler 23 | self.terminal_fn = terminal_fn 24 | 25 | @ torch.no_grad() 26 | def step( 27 | self, 28 | obss: np.ndarray, 29 | actions: np.ndarray 30 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict]: 31 | "imagine single forward step" 32 | inputs = np.concatenate([obss, actions], axis=-1) 33 | inputs = self.scaler.transform(inputs) 34 | preds, _ = self.model(inputs) 35 | # get last timestep pred 36 | preds = preds[:, -1] 37 | next_obss = preds[..., :-1].cpu().numpy() + obss[:, -1] 38 | rewards = preds[..., -1:].cpu().numpy() 39 | 40 | terminals = self.terminal_fn(obss[:, -1], actions[:, -1], next_obss) 41 | info = {} 42 | 43 | return next_obss, rewards, terminals, info 44 | 45 | def train(self, data: Dict, batch_size: int, max_iters: int, logger: Logger) -> None: 46 | self.model.train() 47 | loader = DataLoader(data, shuffle=True, batch_size=batch_size) 48 | for iter in range(max_iters): 49 | for batch in loader: 50 | train_loss = self.learn(batch) 51 | logger.logkv_mean("loss/model", train_loss) 52 | 53 | logger.set_timestep(iter) 54 | logger.dumpkvs(exclude=["policy_training_progress"]) 55 | self.save(logger.model_dir) 56 | self.model.eval() 57 | 58 | def learn(self, batch) -> float: 59 | inputs, targets, masks = batch 60 | preds, _ = self.model.forward(inputs) 61 | 62 | loss = (((preds - targets) ** 2).mean(-1) * masks).mean() 63 | 64 | self.optim.zero_grad() 65 | loss.backward() 66 | self.optim.step() 67 | 68 | return loss.item() 69 | 70 | def save(self, save_path: str) -> None: 71 | torch.save(self.model.state_dict(), os.path.join(save_path, "dynamics.pth")) 72 | self.scaler.save_scaler(save_path) 73 | 74 | def load(self, load_path: str) -> None: 75 | self.model.load_state_dict(torch.load(os.path.join(load_path, "dynamics.pth"), map_location=self.model.device)) 76 | self.scaler.load_scaler(load_path) -------------------------------------------------------------------------------- /offlinerlkit/policy/model_based/mopo.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gym 5 | 6 | from torch.nn import functional as F 7 | from typing import Dict, Union, Tuple 8 | from collections import defaultdict 9 | from offlinerlkit.policy import SACPolicy 10 | from offlinerlkit.dynamics import BaseDynamics 11 | 12 | 13 | class MOPOPolicy(SACPolicy): 14 | """ 15 | Model-based Offline Policy Optimization 16 | """ 17 | 18 | def __init__( 19 | self, 20 | dynamics: BaseDynamics, 21 | actor: nn.Module, 22 | critic1: nn.Module, 23 | critic2: nn.Module, 24 | actor_optim: torch.optim.Optimizer, 25 | critic1_optim: torch.optim.Optimizer, 26 | critic2_optim: torch.optim.Optimizer, 27 | tau: float = 0.005, 28 | gamma: float = 0.99, 29 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2 30 | ) -> None: 31 | super().__init__( 32 | actor, 33 | critic1, 34 | critic2, 35 | actor_optim, 36 | critic1_optim, 37 | critic2_optim, 38 | tau=tau, 39 | gamma=gamma, 40 | alpha=alpha 41 | ) 42 | 43 | self.dynamics = dynamics 44 | 45 | def rollout( 46 | self, 47 | init_obss: np.ndarray, 48 | rollout_length: int 49 | ) -> Tuple[Dict[str, np.ndarray], Dict]: 50 | 51 | num_transitions = 0 52 | rewards_arr = np.array([]) 53 | rollout_transitions = defaultdict(list) 54 | 55 | # rollout 56 | observations = init_obss 57 | for _ in range(rollout_length): 58 | actions = self.select_action(observations) 59 | next_observations, rewards, terminals, info = self.dynamics.step(observations, actions) 60 | rollout_transitions["obss"].append(observations) 61 | rollout_transitions["next_obss"].append(next_observations) 62 | rollout_transitions["actions"].append(actions) 63 | rollout_transitions["rewards"].append(rewards) 64 | rollout_transitions["terminals"].append(terminals) 65 | 66 | num_transitions += len(observations) 67 | rewards_arr = np.append(rewards_arr, rewards.flatten()) 68 | 69 | nonterm_mask = (~terminals).flatten() 70 | if nonterm_mask.sum() == 0: 71 | break 72 | 73 | observations = next_observations[nonterm_mask] 74 | 75 | for k, v in rollout_transitions.items(): 76 | rollout_transitions[k] = np.concatenate(v, axis=0) 77 | 78 | return rollout_transitions, \ 79 | {"num_transitions": num_transitions, "reward_mean": rewards_arr.mean()} 80 | 81 | def learn(self, batch: Dict) -> Dict[str, float]: 82 | real_batch, fake_batch = batch["real"], batch["fake"] 83 | mix_batch = {k: torch.cat([real_batch[k], fake_batch[k]], 0) for k in real_batch.keys()} 84 | return super().learn(mix_batch) 85 | -------------------------------------------------------------------------------- /run_example/run_bc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gym 5 | import d4rl 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | from offlinerlkit.nets import MLP 12 | from offlinerlkit.modules import Actor 13 | from offlinerlkit.buffer import ReplayBuffer 14 | from offlinerlkit.utils.logger import Logger, make_log_dirs 15 | from offlinerlkit.policy_trainer import MFPolicyTrainer 16 | from offlinerlkit.policy import BCPolicy 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--algo-name", type=str, default="bc") 22 | parser.add_argument("--task", type=str, default="hopper-medium-v2") 23 | parser.add_argument("--seed", type=int, default=0) 24 | parser.add_argument("--actor-lr", type=float, default=3e-4) 25 | 
parser.add_argument("--epoch", type=int, default=200) 26 | parser.add_argument("--step-per-epoch", type=int, default=1000) 27 | parser.add_argument("--eval_episodes", type=int, default=20) 28 | parser.add_argument("--batch-size", type=int, default=256) 29 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | return parser.parse_args() 32 | 33 | 34 | def train(args=get_args()): 35 | # create env and dataset 36 | env = gym.make(args.task) 37 | dataset = d4rl.qlearning_dataset(env) 38 | args.obs_shape = env.observation_space.shape 39 | args.action_dim = np.prod(env.action_space.shape) 40 | args.max_action = env.action_space.high[0] 41 | 42 | # create buffer 43 | buffer = ReplayBuffer( 44 | buffer_size=len(dataset["observations"]), 45 | obs_shape=args.obs_shape, 46 | obs_dtype=np.float32, 47 | action_dim=args.action_dim, 48 | action_dtype=np.float32, 49 | device=args.device 50 | ) 51 | buffer.load_dataset(dataset) 52 | 53 | # seed 54 | random.seed(args.seed) 55 | np.random.seed(args.seed) 56 | torch.manual_seed(args.seed) 57 | torch.cuda.manual_seed_all(args.seed) 58 | torch.backends.cudnn.deterministic = True 59 | env.seed(args.seed) 60 | 61 | # create policy model 62 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=[256, 256]) 63 | actor = Actor(actor_backbone, args.action_dim, max_action=args.max_action, device=args.device) 64 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 65 | 66 | # create policy 67 | policy = BCPolicy(actor, actor_optim) 68 | 69 | # log 70 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 71 | # key: output file name, value: output handler type 72 | output_config = { 73 | "consoleout_backup": "stdout", 74 | "policy_training_progress": "csv", 75 | "tb": "tensorboard" 76 | } 77 | logger = Logger(log_dirs, output_config) 78 | logger.log_hyperparameters(vars(args)) 79 | 80 | # create policy trainer 81 | policy_trainer = MFPolicyTrainer( 82 | policy=policy, 83 | eval_env=env, 84 | buffer=buffer, 85 | logger=logger, 86 | epoch=args.epoch, 87 | step_per_epoch=args.step_per_epoch, 88 | batch_size=args.batch_size, 89 | eval_episodes=args.eval_episodes 90 | ) 91 | 92 | # train 93 | policy_trainer.train() 94 | 95 | 96 | if __name__ == "__main__": 97 | train() -------------------------------------------------------------------------------- /offlinerlkit/nets/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Swish(nn.Module): 7 | def __init__(self): 8 | super(Swish, self).__init__() 9 | 10 | def forward(self, x): 11 | x = x * torch.sigmoid(x) 12 | return x 13 | 14 | 15 | def soft_clamp(x : torch.Tensor, _min=None, _max=None): 16 | # clamp tensor values while mataining the gradient 17 | if _max is not None: 18 | x = _max - F.softplus(_max - x) 19 | if _min is not None: 20 | x = _min + F.softplus(x - _min) 21 | return x 22 | 23 | 24 | class ResBlock(nn.Module): 25 | def __init__( 26 | self, 27 | input_dim, 28 | output_dim, 29 | activation=Swish(), 30 | layer_norm=True, 31 | with_residual=True, 32 | dropout=0.1 33 | ): 34 | super().__init__() 35 | 36 | self.linear = nn.Linear(input_dim, output_dim) 37 | self.activation = activation 38 | self.layer_norm = nn.LayerNorm(output_dim) if layer_norm else None 39 | self.dropout = nn.Dropout(dropout) if dropout else None 40 | self.with_residual = with_residual 41 | 42 | def 
forward(self, x): 43 | y = self.activation(self.linear(x)) 44 | if self.dropout is not None: 45 | y = self.dropout(y) 46 | if self.with_residual: 47 | y = x + y 48 | if self.layer_norm is not None: 49 | y = self.layer_norm(y) 50 | return y 51 | 52 | 53 | class RNNModel(nn.Module): 54 | def __init__( 55 | self, 56 | input_dim, 57 | output_dim, 58 | hidden_dims=[200, 200, 200, 200], 59 | rnn_num_layers=3, 60 | dropout_rate=0.1, 61 | device="cpu" 62 | ): 63 | super().__init__() 64 | self.input_dim = input_dim 65 | self.hidden_dims = hidden_dims 66 | self.output_dim = output_dim 67 | self.device = torch.device(device) 68 | 69 | self.activation = Swish() 70 | self.rnn_layer = nn.GRU( 71 | input_size=input_dim, 72 | hidden_size=hidden_dims[0], 73 | num_layers=rnn_num_layers, 74 | batch_first=True 75 | ) 76 | module_list = [] 77 | self.input_layer = ResBlock(input_dim, hidden_dims[0], dropout=dropout_rate, with_residual=False) 78 | dims = list(hidden_dims) 79 | for in_dim, out_dim in zip(dims[:-1], dims[1:]): 80 | module_list.append(ResBlock(in_dim, out_dim, dropout=dropout_rate)) 81 | self.backbones = nn.ModuleList(module_list) 82 | self.merge_layer = nn.Linear(dims[0] + dims[-1], hidden_dims[0]) 83 | self.output_layer = nn.Linear(hidden_dims[-1], output_dim) 84 | 85 | self.to(self.device) 86 | 87 | def forward(self, input, h_state=None): 88 | batch_size, num_timesteps, _ = input.shape 89 | input = torch.as_tensor(input, dtype=torch.float32).to(self.device) 90 | rnn_output, h_state = self.rnn_layer(input, h_state) 91 | rnn_output = rnn_output.reshape(-1, self.hidden_dims[0]) 92 | input = input.view(-1, self.input_dim) 93 | output = self.input_layer(input) 94 | output = torch.cat([output, rnn_output], dim=-1) 95 | output = self.activation(self.merge_layer(output)) 96 | for layer in self.backbones: 97 | output = layer(output) 98 | output = self.output_layer(output) 99 | output = output.view(batch_size, num_timesteps, -1) 100 | return output, h_state 101 | 102 | 103 | if __name__ == "__main__": 104 | model = RNNModel(14, 12) 105 | x = torch.randn(64, 20, 14) 106 | y, _ = model(x) 107 | print(y.shape) -------------------------------------------------------------------------------- /offlinerlkit/modules/dist_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class NormalWrapper(torch.distributions.Normal): 7 | def log_prob(self, actions): 8 | return super().log_prob(actions).sum(-1, keepdim=True) 9 | 10 | def entropy(self): 11 | return super().entropy().sum(-1) 12 | 13 | def mode(self): 14 | return self.mean 15 | 16 | 17 | class TanhNormalWrapper(torch.distributions.Normal): 18 | def __init__(self, loc, scale, max_action): 19 | super().__init__(loc, scale) 20 | self._max_action = max_action 21 | 22 | def log_prob(self, action, raw_action=None): 23 | squashed_action = action/self._max_action 24 | if raw_action is None: 25 | raw_action = self.arctanh(squashed_action) 26 | log_prob = super().log_prob(raw_action).sum(-1, keepdim=True) 27 | eps = 1e-6 28 | log_prob = log_prob - torch.log(self._max_action*(1 - squashed_action.pow(2)) + eps).sum(-1, keepdim=True) 29 | return log_prob 30 | 31 | def mode(self): 32 | raw_action = self.mean 33 | action = self._max_action * torch.tanh(self.mean) 34 | return action, raw_action 35 | 36 | def arctanh(self, x): 37 | one_plus_x = (1 + x).clamp(min=1e-6) 38 | one_minus_x = (1 - x).clamp(min=1e-6) 39 | return 0.5 * torch.log(one_plus_x / one_minus_x) 
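# Note: log_prob above applies the change-of-variables correction for the squashing
# a = max_action * tanh(u): the Gaussian log-density of the raw action u is reduced by
# sum_i log(max_action * (1 - tanh(u_i)^2) + eps). rsample below returns both the squashed
# action and the raw pre-tanh sample, so callers can pass raw_action back into log_prob
# and skip the numerically delicate arctanh inversion.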
40 | 41 | def rsample(self): 42 | raw_action = super().rsample() 43 | action = self._max_action * torch.tanh(raw_action) 44 | return action, raw_action 45 | 46 | 47 | class DiagGaussian(nn.Module): 48 | def __init__( 49 | self, 50 | latent_dim, 51 | output_dim, 52 | unbounded=False, 53 | conditioned_sigma=False, 54 | max_mu=1.0, 55 | sigma_min=-5.0, 56 | sigma_max=2.0 57 | ): 58 | super().__init__() 59 | self.mu = nn.Linear(latent_dim, output_dim) 60 | self._c_sigma = conditioned_sigma 61 | if conditioned_sigma: 62 | self.sigma = nn.Linear(latent_dim, output_dim) 63 | else: 64 | self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) 65 | self._unbounded = unbounded 66 | self._max = max_mu 67 | self._sigma_min = sigma_min 68 | self._sigma_max = sigma_max 69 | 70 | def forward(self, logits): 71 | mu = self.mu(logits) 72 | if not self._unbounded: 73 | mu = self._max * torch.tanh(mu) 74 | if self._c_sigma: 75 | sigma = torch.clamp(self.sigma(logits), min=self._sigma_min, max=self._sigma_max).exp() 76 | else: 77 | shape = [1] * len(mu.shape) 78 | shape[1] = -1 79 | sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() 80 | return NormalWrapper(mu, sigma) 81 | 82 | 83 | class TanhDiagGaussian(DiagGaussian): 84 | def __init__( 85 | self, 86 | latent_dim, 87 | output_dim, 88 | unbounded=False, 89 | conditioned_sigma=False, 90 | max_mu=1.0, 91 | sigma_min=-5.0, 92 | sigma_max=2.0 93 | ): 94 | super().__init__( 95 | latent_dim=latent_dim, 96 | output_dim=output_dim, 97 | unbounded=unbounded, 98 | conditioned_sigma=conditioned_sigma, 99 | max_mu=max_mu, 100 | sigma_min=sigma_min, 101 | sigma_max=sigma_max 102 | ) 103 | 104 | def forward(self, logits): 105 | mu = self.mu(logits) 106 | if not self._unbounded: 107 | mu = self._max * torch.tanh(mu) 108 | if self._c_sigma: 109 | sigma = torch.clamp(self.sigma(logits), min=self._sigma_min, max=self._sigma_max).exp() 110 | else: 111 | shape = [1] * len(mu.shape) 112 | shape[1] = -1 113 | sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() 114 | return TanhNormalWrapper(mu, sigma, self._max) -------------------------------------------------------------------------------- /offlinerlkit/modules/dynamics_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Dict, List, Union, Tuple, Optional 6 | from offlinerlkit.nets import EnsembleLinear 7 | 8 | 9 | class Swish(nn.Module): 10 | def __init__(self) -> None: 11 | super(Swish, self).__init__() 12 | 13 | def forward(self, x: torch.Tensor) -> torch.Tensor: 14 | x = x * torch.sigmoid(x) 15 | return x 16 | 17 | 18 | def soft_clamp( 19 | x : torch.Tensor, 20 | _min: Optional[torch.Tensor] = None, 21 | _max: Optional[torch.Tensor] = None 22 | ) -> torch.Tensor: 23 | # clamp tensor values while mataining the gradient 24 | if _max is not None: 25 | x = _max - F.softplus(_max - x) 26 | if _min is not None: 27 | x = _min + F.softplus(x - _min) 28 | return x 29 | 30 | 31 | class EnsembleDynamicsModel(nn.Module): 32 | def __init__( 33 | self, 34 | obs_dim: int, 35 | action_dim: int, 36 | hidden_dims: Union[List[int], Tuple[int]], 37 | num_ensemble: int = 7, 38 | num_elites: int = 5, 39 | activation: nn.Module = Swish, 40 | weight_decays: Optional[Union[List[float], Tuple[float]]] = None, 41 | with_reward: bool = True, 42 | device: str = "cpu" 43 | ) -> None: 44 | super().__init__() 45 | 46 | self.num_ensemble = num_ensemble 
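# The output_layer defined below predicts a mean and a log-variance for every target
# dimension (next observation, plus reward when with_reward=True), which is why its
# width is 2 * (obs_dim + with_reward); forward() soft-clamps the log-variance between
# min_logvar and max_logvar. random_elite_idxs() only draws from the member indices
# stored in `elites` (the first num_elites members by default, replaceable via set_elites).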
47 | self.num_elites = num_elites 48 | self._with_reward = with_reward 49 | self.device = torch.device(device) 50 | 51 | self.activation = activation() 52 | 53 | assert len(weight_decays) == (len(hidden_dims) + 1) 54 | 55 | module_list = [] 56 | hidden_dims = [obs_dim+action_dim] + list(hidden_dims) 57 | if weight_decays is None: 58 | weight_decays = [0.0] * (len(hidden_dims) + 1) 59 | for in_dim, out_dim, weight_decay in zip(hidden_dims[:-1], hidden_dims[1:], weight_decays[:-1]): 60 | module_list.append(EnsembleLinear(in_dim, out_dim, num_ensemble, weight_decay)) 61 | self.backbones = nn.ModuleList(module_list) 62 | 63 | self.output_layer = EnsembleLinear( 64 | hidden_dims[-1], 65 | 2 * (obs_dim + self._with_reward), 66 | num_ensemble, 67 | weight_decays[-1] 68 | ) 69 | 70 | self.register_parameter( 71 | "max_logvar", 72 | nn.Parameter(torch.ones(obs_dim + self._with_reward) * 0.5, requires_grad=True) 73 | ) 74 | self.register_parameter( 75 | "min_logvar", 76 | nn.Parameter(torch.ones(obs_dim + self._with_reward) * -10, requires_grad=True) 77 | ) 78 | 79 | self.register_parameter( 80 | "elites", 81 | nn.Parameter(torch.tensor(list(range(0, self.num_elites))), requires_grad=False) 82 | ) 83 | 84 | self.to(self.device) 85 | 86 | def forward(self, obs_action: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]: 87 | obs_action = torch.as_tensor(obs_action, dtype=torch.float32).to(self.device) 88 | output = obs_action 89 | for layer in self.backbones: 90 | output = self.activation(layer(output)) 91 | mean, logvar = torch.chunk(self.output_layer(output), 2, dim=-1) 92 | logvar = soft_clamp(logvar, self.min_logvar, self.max_logvar) 93 | return mean, logvar 94 | 95 | def load_save(self) -> None: 96 | for layer in self.backbones: 97 | layer.load_save() 98 | self.output_layer.load_save() 99 | 100 | def update_save(self, indexes: List[int]) -> None: 101 | for layer in self.backbones: 102 | layer.update_save(indexes) 103 | self.output_layer.update_save(indexes) 104 | 105 | def get_decay_loss(self) -> torch.Tensor: 106 | decay_loss = 0 107 | for layer in self.backbones: 108 | decay_loss += layer.get_decay_loss() 109 | decay_loss += self.output_layer.get_decay_loss() 110 | return decay_loss 111 | 112 | def set_elites(self, indexes: List[int]) -> None: 113 | assert len(indexes) <= self.num_ensemble and max(indexes) < self.num_ensemble 114 | self.register_parameter('elites', nn.Parameter(torch.tensor(indexes), requires_grad=False)) 115 | 116 | def random_elite_idxs(self, batch_size: int) -> np.ndarray: 117 | idxs = np.random.choice(self.elites.data.cpu().numpy(), size=batch_size) 118 | return idxs -------------------------------------------------------------------------------- /offlinerlkit/policy_trainer/mf_policy_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import gym 7 | 8 | from typing import Optional, Dict, List 9 | from tqdm import tqdm 10 | from collections import deque 11 | from offlinerlkit.buffer import ReplayBuffer 12 | from offlinerlkit.utils.logger import Logger 13 | from offlinerlkit.policy import BasePolicy 14 | 15 | 16 | # model-free policy trainer 17 | class MFPolicyTrainer: 18 | def __init__( 19 | self, 20 | policy: BasePolicy, 21 | eval_env: gym.Env, 22 | buffer: ReplayBuffer, 23 | logger: Logger, 24 | epoch: int = 1000, 25 | step_per_epoch: int = 1000, 26 | batch_size: int = 256, 27 | eval_episodes: int = 10, 28 | lr_scheduler: 
Optional[torch.optim.lr_scheduler._LRScheduler] = None 29 | ) -> None: 30 | self.policy = policy 31 | self.eval_env = eval_env 32 | self.buffer = buffer 33 | self.logger = logger 34 | 35 | self._epoch = epoch 36 | self._step_per_epoch = step_per_epoch 37 | self._batch_size = batch_size 38 | self._eval_episodes = eval_episodes 39 | self.lr_scheduler = lr_scheduler 40 | 41 | def train(self) -> Dict[str, float]: 42 | start_time = time.time() 43 | 44 | num_timesteps = 0 45 | last_10_performance = deque(maxlen=10) 46 | # train loop 47 | for e in range(1, self._epoch + 1): 48 | 49 | self.policy.train() 50 | 51 | pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}") 52 | for it in pbar: 53 | batch = self.buffer.sample(self._batch_size) 54 | loss = self.policy.learn(batch) 55 | pbar.set_postfix(**loss) 56 | 57 | for k, v in loss.items(): 58 | self.logger.logkv_mean(k, v) 59 | 60 | num_timesteps += 1 61 | 62 | if self.lr_scheduler is not None: 63 | self.lr_scheduler.step() 64 | 65 | # evaluate current policy 66 | eval_info = self._evaluate() 67 | ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"]) 68 | ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"]) 69 | norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100 70 | norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100 71 | last_10_performance.append(norm_ep_rew_mean) 72 | self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean) 73 | self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std) 74 | self.logger.logkv("eval/episode_length", ep_length_mean) 75 | self.logger.logkv("eval/episode_length_std", ep_length_std) 76 | self.logger.set_timestep(num_timesteps) 77 | self.logger.dumpkvs() 78 | 79 | # save checkpoint 80 | torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth")) 81 | 82 | self.logger.log("total time: {:.2f}s".format(time.time() - start_time)) 83 | torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth")) 84 | self.logger.close() 85 | 86 | return {"last_10_performance": np.mean(last_10_performance)} 87 | 88 | def _evaluate(self) -> Dict[str, List[float]]: 89 | self.policy.eval() 90 | obs = self.eval_env.reset() 91 | eval_ep_info_buffer = [] 92 | num_episodes = 0 93 | episode_reward, episode_length = 0, 0 94 | 95 | while num_episodes < self._eval_episodes: 96 | action = self.policy.select_action(obs.reshape(1,-1), deterministic=True) 97 | next_obs, reward, terminal, _ = self.eval_env.step(action.flatten()) 98 | episode_reward += reward 99 | episode_length += 1 100 | 101 | obs = next_obs 102 | 103 | if terminal: 104 | eval_ep_info_buffer.append( 105 | {"episode_reward": episode_reward, "episode_length": episode_length} 106 | ) 107 | num_episodes +=1 108 | episode_reward, episode_length = 0, 0 109 | obs = self.eval_env.reset() 110 | 111 | return { 112 | "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer], 113 | "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer] 114 | } 115 | -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/td3bc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn import functional as F 6 | from 
typing import Dict, Union, Tuple, Callable 7 | from offlinerlkit.policy import TD3Policy 8 | from offlinerlkit.utils.noise import GaussianNoise 9 | from offlinerlkit.utils.scaler import StandardScaler 10 | 11 | 12 | class TD3BCPolicy(TD3Policy): 13 | """ 14 | TD3+BC 15 | """ 16 | 17 | def __init__( 18 | self, 19 | actor: nn.Module, 20 | critic1: nn.Module, 21 | critic2: nn.Module, 22 | actor_optim: torch.optim.Optimizer, 23 | critic1_optim: torch.optim.Optimizer, 24 | critic2_optim: torch.optim.Optimizer, 25 | tau: float = 0.005, 26 | gamma: float = 0.99, 27 | max_action: float = 1.0, 28 | exploration_noise: Callable = GaussianNoise, 29 | policy_noise: float = 0.2, 30 | noise_clip: float = 0.5, 31 | update_actor_freq: int = 2, 32 | alpha: float = 2.5, 33 | scaler: StandardScaler = None 34 | ) -> None: 35 | 36 | super().__init__( 37 | actor, 38 | critic1, 39 | critic2, 40 | actor_optim, 41 | critic1_optim, 42 | critic2_optim, 43 | tau=tau, 44 | gamma=gamma, 45 | max_action=max_action, 46 | exploration_noise=exploration_noise, 47 | policy_noise=policy_noise, 48 | noise_clip=noise_clip, 49 | update_actor_freq=update_actor_freq 50 | ) 51 | 52 | self._alpha = alpha 53 | self.scaler = scaler 54 | 55 | def train(self) -> None: 56 | self.actor.train() 57 | self.critic1.train() 58 | self.critic2.train() 59 | 60 | def eval(self) -> None: 61 | self.actor.eval() 62 | self.critic1.eval() 63 | self.critic2.eval() 64 | 65 | def _sync_weight(self) -> None: 66 | for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): 67 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 68 | for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): 69 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 70 | for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()): 71 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 72 | 73 | def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: 74 | if self.scaler is not None: 75 | obs = self.scaler.transform(obs) 76 | with torch.no_grad(): 77 | action = self.actor(obs).cpu().numpy() 78 | if not deterministic: 79 | action = action + self.exploration_noise(action.shape) 80 | action = np.clip(action, -self._max_action, self._max_action) 81 | return action 82 | 83 | def learn(self, batch: Dict) -> Dict[str, float]: 84 | obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \ 85 | batch["next_observations"], batch["rewards"], batch["terminals"] 86 | 87 | # update critic 88 | q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions) 89 | with torch.no_grad(): 90 | noise = (torch.randn_like(actions) * self._policy_noise).clamp(-self._noise_clip, self._noise_clip) 91 | next_actions = (self.actor_old(next_obss) + noise).clamp(-self._max_action, self._max_action) 92 | next_q = torch.min(self.critic1_old(next_obss, next_actions), self.critic2_old(next_obss, next_actions)) 93 | target_q = rewards + self._gamma * (1 - terminals) * next_q 94 | 95 | critic1_loss = ((q1 - target_q).pow(2)).mean() 96 | critic2_loss = ((q2 - target_q).pow(2)).mean() 97 | 98 | self.critic1_optim.zero_grad() 99 | critic1_loss.backward() 100 | self.critic1_optim.step() 101 | 102 | self.critic2_optim.zero_grad() 103 | critic2_loss.backward() 104 | self.critic2_optim.step() 105 | 106 | # update actor 107 | if self._cnt % self._freq == 0: 108 | a = self.actor(obss) 109 | q = self.critic1(obss, a) 110 | lmbda = self._alpha / q.abs().mean().detach() 111 | actor_loss = 
-lmbda * q.mean() + ((a - actions).pow(2)).mean() 112 | self.actor_optim.zero_grad() 113 | actor_loss.backward() 114 | self.actor_optim.step() 115 | self._last_actor_loss = actor_loss.item() 116 | self._sync_weight() 117 | 118 | self._cnt += 1 119 | 120 | return { 121 | "loss/actor": self._last_actor_loss, 122 | "loss/critic1": critic1_loss.item(), 123 | "loss/critic2": critic2_loss.item() 124 | } -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/td3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from copy import deepcopy 6 | from typing import Callable, Dict, Union, Tuple 7 | from offlinerlkit.policy import BasePolicy 8 | from offlinerlkit.utils.noise import GaussianNoise 9 | 10 | 11 | class TD3Policy(BasePolicy): 12 | """ 13 | Twin Delayed Deep Deterministic policy gradient 14 | """ 15 | 16 | def __init__( 17 | self, 18 | actor: nn.Module, 19 | critic1: nn.Module, 20 | critic2: nn.Module, 21 | actor_optim: torch.optim.Optimizer, 22 | critic1_optim: torch.optim.Optimizer, 23 | critic2_optim: torch.optim.Optimizer, 24 | tau: float = 0.005, 25 | gamma: float = 0.99, 26 | max_action: float = 1.0, 27 | exploration_noise: Callable = GaussianNoise, 28 | policy_noise: float = 0.2, 29 | noise_clip: float = 0.5, 30 | update_actor_freq: int = 2, 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.actor = actor 35 | self.actor_old = deepcopy(actor) 36 | self.actor_old.eval() 37 | self.actor_optim = actor_optim 38 | 39 | self.critic1 = critic1 40 | self.critic1_old = deepcopy(critic1) 41 | self.critic1_old.eval() 42 | self.critic1_optim = critic1_optim 43 | 44 | self.critic2 = critic2 45 | self.critic2_old = deepcopy(critic2) 46 | self.critic2_old.eval() 47 | self.critic2_optim = critic2_optim 48 | 49 | self._tau = tau 50 | self._gamma = gamma 51 | 52 | self._max_action = max_action 53 | self.exploration_noise = exploration_noise 54 | self._policy_noise = policy_noise 55 | self._noise_clip = noise_clip 56 | self._freq = update_actor_freq 57 | 58 | self._cnt = 0 59 | self._last_actor_loss = 0 60 | 61 | def train(self) -> None: 62 | self.actor.train() 63 | self.critic1.train() 64 | self.critic2.train() 65 | 66 | def eval(self) -> None: 67 | self.actor.eval() 68 | self.critic1.eval() 69 | self.critic2.eval() 70 | 71 | def _sync_weight(self) -> None: 72 | for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): 73 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 74 | for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): 75 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 76 | for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()): 77 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 78 | 79 | def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: 80 | with torch.no_grad(): 81 | action = self.actor(obs).cpu().numpy() 82 | if not deterministic: 83 | action = action + self.exploration_noise(action.shape) 84 | action = np.clip(action, -self._max_action, self._max_action) 85 | return action 86 | 87 | def learn(self, batch: Dict) -> Dict[str, float]: 88 | obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \ 89 | batch["next_observations"], batch["rewards"], batch["terminals"] 90 | 91 | # update critic 92 | q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions) 93 
| with torch.no_grad(): 94 | noise = (torch.randn_like(actions) * self._policy_noise).clamp(-self._noise_clip, self._noise_clip) 95 | next_actions = (self.actor_old(next_obss) + noise).clamp(-self._max_action, self._max_action) 96 | next_q = torch.min(self.critic1_old(next_obss, next_actions), self.critic2_old(next_obss, next_actions)) 97 | target_q = rewards + self._gamma * (1 - terminals) * next_q 98 | 99 | critic1_loss = ((q1 - target_q).pow(2)).mean() 100 | critic2_loss = ((q2 - target_q).pow(2)).mean() 101 | 102 | self.critic1_optim.zero_grad() 103 | critic1_loss.backward() 104 | self.critic1_optim.step() 105 | 106 | self.critic2_optim.zero_grad() 107 | critic2_loss.backward() 108 | self.critic2_optim.step() 109 | 110 | # update actor 111 | if self._cnt % self._freq == 0: 112 | a = self.actor(obss) 113 | q = self.critic1(obss, a) 114 | actor_loss = -q.mean() 115 | self.actor_optim.zero_grad() 116 | actor_loss.backward() 117 | self.actor_optim.step() 118 | self._last_actor_loss = actor_loss.item() 119 | self._sync_weight() 120 | 121 | self._cnt += 1 122 | 123 | return { 124 | "loss/actor": self._last_actor_loss, 125 | "loss/critic1": critic1_loss.item(), 126 | "loss/critic2": critic2_loss.item() 127 | } -------------------------------------------------------------------------------- /offlinerlkit/buffer/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from typing import Optional, Union, Tuple, Dict 5 | 6 | 7 | class ReplayBuffer: 8 | def __init__( 9 | self, 10 | buffer_size: int, 11 | obs_shape: Tuple, 12 | obs_dtype: np.dtype, 13 | action_dim: int, 14 | action_dtype: np.dtype, 15 | device: str = "cpu" 16 | ) -> None: 17 | self._max_size = buffer_size 18 | self.obs_shape = obs_shape 19 | self.obs_dtype = obs_dtype 20 | self.action_dim = action_dim 21 | self.action_dtype = action_dtype 22 | 23 | self._ptr = 0 24 | self._size = 0 25 | 26 | self.observations = np.zeros((self._max_size,) + self.obs_shape, dtype=obs_dtype) 27 | self.next_observations = np.zeros((self._max_size,) + self.obs_shape, dtype=obs_dtype) 28 | self.actions = np.zeros((self._max_size, self.action_dim), dtype=action_dtype) 29 | self.rewards = np.zeros((self._max_size, 1), dtype=np.float32) 30 | self.terminals = np.zeros((self._max_size, 1), dtype=np.float32) 31 | 32 | self.device = torch.device(device) 33 | 34 | def add( 35 | self, 36 | obs: np.ndarray, 37 | next_obs: np.ndarray, 38 | action: np.ndarray, 39 | reward: np.ndarray, 40 | terminal: np.ndarray 41 | ) -> None: 42 | # Copy to avoid modification by reference 43 | self.observations[self._ptr] = np.array(obs).copy() 44 | self.next_observations[self._ptr] = np.array(next_obs).copy() 45 | self.actions[self._ptr] = np.array(action).copy() 46 | self.rewards[self._ptr] = np.array(reward).copy() 47 | self.terminals[self._ptr] = np.array(terminal).copy() 48 | 49 | self._ptr = (self._ptr + 1) % self._max_size 50 | self._size = min(self._size + 1, self._max_size) 51 | 52 | def add_batch( 53 | self, 54 | obss: np.ndarray, 55 | next_obss: np.ndarray, 56 | actions: np.ndarray, 57 | rewards: np.ndarray, 58 | terminals: np.ndarray 59 | ) -> None: 60 | batch_size = len(obss) 61 | indexes = np.arange(self._ptr, self._ptr + batch_size) % self._max_size 62 | 63 | self.observations[indexes] = np.array(obss).copy() 64 | self.next_observations[indexes] = np.array(next_obss).copy() 65 | self.actions[indexes] = np.array(actions).copy() 66 | self.rewards[indexes] = np.array(rewards).copy() 67 | 
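# Note: `indexes` computed above wraps modulo the buffer size, so a batch that runs past the
# end of the storage circles back to index 0 and overwrites the oldest entries in place.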
self.terminals[indexes] = np.array(terminals).copy() 68 | 69 | self._ptr = (self._ptr + batch_size) % self._max_size 70 | self._size = min(self._size + batch_size, self._max_size) 71 | 72 | def load_dataset(self, dataset: Dict[str, np.ndarray]) -> None: 73 | observations = np.array(dataset["observations"], dtype=self.obs_dtype) 74 | next_observations = np.array(dataset["next_observations"], dtype=self.obs_dtype) 75 | actions = np.array(dataset["actions"], dtype=self.action_dtype) 76 | rewards = np.array(dataset["rewards"], dtype=np.float32).reshape(-1, 1) 77 | terminals = np.array(dataset["terminals"], dtype=np.float32).reshape(-1, 1) 78 | 79 | self.observations = observations 80 | self.next_observations = next_observations 81 | self.actions = actions 82 | self.rewards = rewards 83 | self.terminals = terminals 84 | 85 | self._ptr = len(observations) 86 | self._size = len(observations) 87 | 88 | def normalize_obs(self, eps: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]: 89 | mean = self.observations.mean(0, keepdims=True) 90 | std = self.observations.std(0, keepdims=True) + eps 91 | self.observations = (self.observations - mean) / std 92 | self.next_observations = (self.next_observations - mean) / std 93 | obs_mean, obs_std = mean, std 94 | return obs_mean, obs_std 95 | 96 | def sample(self, batch_size: int) -> Dict[str, torch.Tensor]: 97 | 98 | batch_indexes = np.random.randint(0, self._size, size=batch_size) 99 | 100 | return { 101 | "observations": torch.tensor(self.observations[batch_indexes]).to(self.device), 102 | "actions": torch.tensor(self.actions[batch_indexes]).to(self.device), 103 | "next_observations": torch.tensor(self.next_observations[batch_indexes]).to(self.device), 104 | "terminals": torch.tensor(self.terminals[batch_indexes]).to(self.device), 105 | "rewards": torch.tensor(self.rewards[batch_indexes]).to(self.device) 106 | } 107 | 108 | def sample_all(self) -> Dict[str, np.ndarray]: 109 | return { 110 | "observations": self.observations[:self._size].copy(), 111 | "actions": self.actions[:self._size].copy(), 112 | "next_observations": self.next_observations[:self._size].copy(), 113 | "terminals": self.terminals[:self._size].copy(), 114 | "rewards": self.rewards[:self._size].copy() 115 | } -------------------------------------------------------------------------------- /offlinerlkit/utils/plotter.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | 11 | COLORS = ( 12 | [ 13 | '#318DE9', # blue 14 | '#FF7D00', # orange 15 | '#E52B50', # red 16 | '#7B68EE', # purple 17 | '#00CD66', # green 18 | '#FFD700', # yellow 19 | ] 20 | ) 21 | 22 | 23 | def merge_csv(root_dir, query_file, query_x, query_y): 24 | """Merge result in csv_files into a single csv file.""" 25 | csv_files = [] 26 | for dirname, _, files in os.walk(root_dir): 27 | for f in files: 28 | if f == query_file: 29 | csv_files.append(os.path.join(dirname, f)) 30 | results = {} 31 | for csv_file in csv_files: 32 | content = [[query_x, query_y]] 33 | df = pd.read_csv(csv_file) 34 | values = df[[query_x, query_y]].values 35 | for line in values: 36 | if np.isnan(line[1]): continue 37 | content.append(line) 38 | results[csv_file] = content 39 | assert len(results) > 0 40 | sorted_keys = sorted(results.keys()) 41 | sorted_values = [results[k][1:] for k in sorted_keys] 42 | content = [ 43 | [query_x, query_y+'_mean', query_y+'_std'] 44 | ] 45 | 
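# zip(*sorted_values) below walks the runs in lockstep: the i-th logged row of every run is
# grouped together, the assert checks they all share the same x value (e.g. the same timestep),
# and the mean/std across runs is recorded for that x.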
for rows in zip(*sorted_values): 46 | array = np.array(rows) 47 | assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) 48 | line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)] 49 | content.append(line) 50 | output_path = os.path.join(root_dir, query_y.replace('/', '_')+".csv") 51 | print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.") 52 | csv.writer(open(output_path, "w")).writerows(content) 53 | return output_path 54 | 55 | 56 | def csv2numpy(file_path): 57 | df = pd.read_csv(file_path) 58 | step = df.iloc[:,0].to_numpy() 59 | mean = df.iloc[:,1].to_numpy() 60 | std = df.iloc[:,2].to_numpy() 61 | return step, mean, std 62 | 63 | 64 | def smooth(y, radius=0): 65 | convkernel = np.ones(2 * radius + 1) 66 | out = np.convolve(y, convkernel, mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same') 67 | return out 68 | 69 | 70 | def plot_figure( 71 | results, 72 | x_label, 73 | y_label, 74 | title=None, 75 | smooth_radius=10, 76 | figsize=None, 77 | dpi=None, 78 | color_list=None 79 | ): 80 | fig, ax = plt.subplots(figsize=figsize, dpi=dpi) 81 | if color_list == None: 82 | color_list = [COLORS[i] for i in range(len(results))] 83 | else: 84 | assert len(color_list) == len(results) 85 | for i, (algo_name, csv_file) in enumerate(results.items()): 86 | x, y, shaded = csv2numpy(csv_file) 87 | y = smooth(y, smooth_radius) 88 | shaded = smooth(shaded, smooth_radius) 89 | ax.plot(x, y, color=color_list[i], label=algo_name) 90 | ax.fill_between(x, y-shaded, y+shaded, color=color_list[i], alpha=0.2) 91 | ax.set_title(title) 92 | ax.set_xlabel(x_label) 93 | ax.set_ylabel(y_label) 94 | ax.legend() 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser(description="plotter") 99 | parser.add_argument("--root-dir", default="log") 100 | parser.add_argument("--task", default="walker2d-medium-v2") 101 | parser.add_argument("--algos", type=str, nargs='*', default=["mopo&penalty_coef=0.5&rollout_length=5"]) 102 | parser.add_argument("--query-file", default="policy_training_progress.csv") 103 | parser.add_argument("--query-x", default="timestep") 104 | parser.add_argument("--query-y", default="eval/normalized_episode_reward") 105 | parser.add_argument("--title", default=None) 106 | parser.add_argument("--xlabel", default="timestep") 107 | parser.add_argument("--ylabel", default="normalized_episode_reward") 108 | parser.add_argument("--smooth", type=int, default=10) 109 | parser.add_argument("--colors", type=str, nargs='*', default=None) 110 | parser.add_argument("--show", action='store_true') 111 | parser.add_argument("--output-path", default="./figure.png") 112 | parser.add_argument("--figsize", type=float, nargs=2, default=(5, 5)) 113 | parser.add_argument("--dpi", type=int, default=200) 114 | args = parser.parse_args() 115 | 116 | results = {} 117 | for algo in args.algos: 118 | path = os.path.join(args.root_dir, args.task, algo) 119 | csv_file = merge_csv(path, args.query_file, args.query_x, args.query_y) 120 | results[algo] = csv_file 121 | 122 | plt.style.use('seaborn') 123 | plot_figure( 124 | results=results, 125 | x_label=args.xlabel, 126 | y_label=args.ylabel, 127 | title=args.title, 128 | smooth_radius=args.smooth, 129 | figsize=args.figsize, 130 | dpi=args.dpi, 131 | color_list=args.colors 132 | ) 133 | if args.output_path: 134 | plt.savefig(args.output_path, bbox_inches='tight') 135 | if args.show: 136 | plt.show() -------------------------------------------------------------------------------- 
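The plotting module above can also be driven from Python rather than through its CLI. The following is a minimal sketch: the log/<task>/<algo_tag> directory layout, the task name and the algorithm tag are assumptions mirroring the argparse defaults above, so point them at wherever your runs actually wrote policy_training_progress.csv.

import matplotlib.pyplot as plt

from offlinerlkit.utils.plotter import merge_csv, plot_figure

# Assumed layout: log/<task>/<algo_tag>/<seed&params>/policy_training_progress.csv
root = "log/walker2d-medium-v2"
algos = ["mopo&penalty_coef=0.5&rollout_length=5"]  # illustrative run tag

results = {}
for algo in algos:
    # merge_csv collects every matching csv under the directory (one per seed) and
    # writes a merged csv with <query_y>_mean / <query_y>_std columns into it.
    results[algo] = merge_csv(
        f"{root}/{algo}",
        "policy_training_progress.csv",
        "timestep",
        "eval/normalized_episode_reward",
    )

plot_figure(
    results,
    x_label="timestep",
    y_label="normalized episode reward",
    smooth_radius=10,
    figsize=(5, 5),
    dpi=200,
)
plt.savefig("figure.png", bbox_inches="tight")

The equivalent CLI call is python offlinerlkit/utils/plotter.py --root-dir log --task walker2d-medium-v2 --algos "mopo&penalty_coef=0.5&rollout_length=5" --show.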
/offlinerlkit/utils/termination_fns.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def obs_unnormalization(termination_fn, obs_mean, obs_std): 4 | def thunk(obs, act, next_obs): 5 | obs = obs*obs_std + obs_mean 6 | next_obs = next_obs*obs_std + obs_mean 7 | return termination_fn(obs, act, next_obs) 8 | return thunk 9 | 10 | def termination_fn_halfcheetah(obs, act, next_obs): 11 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 12 | 13 | not_done = np.logical_and(np.all(next_obs > -100, axis=-1), np.all(next_obs < 100, axis=-1)) 14 | done = ~not_done 15 | done = done[:, None] 16 | return done 17 | 18 | def termination_fn_hopper(obs, act, next_obs): 19 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 20 | 21 | height = next_obs[:, 0] 22 | angle = next_obs[:, 1] 23 | not_done = np.isfinite(next_obs).all(axis=-1) \ 24 | * (np.abs(next_obs[:,1:]) < 100).all(axis=-1) \ 25 | * (height > .7) \ 26 | * (np.abs(angle) < .2) 27 | 28 | done = ~not_done 29 | done = done[:,None] 30 | return done 31 | 32 | def termination_fn_halfcheetahveljump(obs, act, next_obs): 33 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 34 | 35 | done = np.array([False]).repeat(len(obs)) 36 | done = done[:,None] 37 | return done 38 | 39 | def termination_fn_antangle(obs, act, next_obs): 40 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 41 | 42 | x = next_obs[:, 0] 43 | not_done = np.isfinite(next_obs).all(axis=-1) \ 44 | * (x >= 0.2) \ 45 | * (x <= 1.0) 46 | 47 | done = ~not_done 48 | done = done[:,None] 49 | return done 50 | 51 | def termination_fn_ant(obs, act, next_obs): 52 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 53 | 54 | x = next_obs[:, 0] 55 | not_done = np.isfinite(next_obs).all(axis=-1) \ 56 | * (x >= 0.2) \ 57 | * (x <= 1.0) 58 | 59 | done = ~not_done 60 | done = done[:,None] 61 | return done 62 | 63 | def termination_fn_walker2d(obs, act, next_obs): 64 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 65 | 66 | height = next_obs[:, 0] 67 | angle = next_obs[:, 1] 68 | not_done = np.logical_and(np.all(next_obs > -100, axis=-1), np.all(next_obs < 100, axis=-1)) \ 69 | * (height > 0.8) \ 70 | * (height < 2.0) \ 71 | * (angle > -1.0) \ 72 | * (angle < 1.0) 73 | done = ~not_done 74 | done = done[:,None] 75 | return done 76 | 77 | def termination_fn_point2denv(obs, act, next_obs): 78 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 79 | 80 | done = np.array([False]).repeat(len(obs)) 81 | done = done[:,None] 82 | return done 83 | 84 | def termination_fn_point2dwallenv(obs, act, next_obs): 85 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 86 | 87 | done = np.array([False]).repeat(len(obs)) 88 | done = done[:,None] 89 | return done 90 | 91 | def termination_fn_pendulum(obs, act, next_obs): 92 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 93 | 94 | done = np.zeros((len(obs), 1)) 95 | return done 96 | 97 | def termination_fn_humanoid(obs, act, next_obs): 98 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 99 | 100 | z = next_obs[:,0] 101 | done = (z < 1.0) + (z > 2.0) 102 | 103 | done = done[:,None] 104 | return done 105 | 106 | def termination_fn_pen(obs, act, next_obs): 107 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 108 | 109 | obj_pos = next_obs[:, 24:27] 110 | done = obj_pos[:, 2] < 0.075 111 | 112 | done = 
done[:,None] 113 | return done 114 | 115 | def termination_fn_door(obs, act, next_obs): 116 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 117 | 118 | done = np.array([False] * obs.shape[0]) 119 | 120 | done = done[:, None] 121 | return done 122 | 123 | def get_termination_fn(task): 124 | if 'halfcheetahvel' in task: 125 | return termination_fn_halfcheetahveljump 126 | elif 'halfcheetah' in task: 127 | return termination_fn_halfcheetah 128 | elif 'hopper' in task: 129 | return termination_fn_hopper 130 | elif 'antangle' in task: 131 | return termination_fn_antangle 132 | elif 'ant' in task: 133 | return termination_fn_ant 134 | elif 'walker2d' in task: 135 | return termination_fn_walker2d 136 | elif 'point2denv' in task: 137 | return termination_fn_point2denv 138 | elif 'point2dwallenv' in task: 139 | return termination_fn_point2dwallenv 140 | elif 'pendulum' in task: 141 | return termination_fn_pendulum 142 | elif 'humanoid' in task: 143 | return termination_fn_humanoid 144 | elif 'pen' in task: 145 | return termination_fn_pen 146 | elif 'door' in task: 147 | return termination_fn_door 148 | else: 149 | raise NotImplementedError 150 | -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from copy import deepcopy 6 | from typing import Dict, Union, Tuple 7 | from offlinerlkit.policy import BasePolicy 8 | 9 | 10 | class SACPolicy(BasePolicy): 11 | """ 12 | Soft Actor Critic 13 | """ 14 | 15 | def __init__( 16 | self, 17 | actor: nn.Module, 18 | critic1: nn.Module, 19 | critic2: nn.Module, 20 | actor_optim: torch.optim.Optimizer, 21 | critic1_optim: torch.optim.Optimizer, 22 | critic2_optim: torch.optim.Optimizer, 23 | tau: float = 0.005, 24 | gamma: float = 0.99, 25 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2 26 | ) -> None: 27 | super().__init__() 28 | 29 | self.actor = actor 30 | self.critic1, self.critic1_old = critic1, deepcopy(critic1) 31 | self.critic1_old.eval() 32 | self.critic2, self.critic2_old = critic2, deepcopy(critic2) 33 | self.critic2_old.eval() 34 | 35 | self.actor_optim = actor_optim 36 | self.critic1_optim = critic1_optim 37 | self.critic2_optim = critic2_optim 38 | 39 | self._tau = tau 40 | self._gamma = gamma 41 | 42 | self._is_auto_alpha = False 43 | if isinstance(alpha, tuple): 44 | self._is_auto_alpha = True 45 | self._target_entropy, self._log_alpha, self.alpha_optim = alpha 46 | self._alpha = self._log_alpha.detach().exp() 47 | else: 48 | self._alpha = alpha 49 | 50 | def train(self) -> None: 51 | self.actor.train() 52 | self.critic1.train() 53 | self.critic2.train() 54 | 55 | def eval(self) -> None: 56 | self.actor.eval() 57 | self.critic1.eval() 58 | self.critic2.eval() 59 | 60 | def _sync_weight(self) -> None: 61 | for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): 62 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 63 | for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()): 64 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 65 | 66 | def actforward( 67 | self, 68 | obs: torch.Tensor, 69 | deterministic: bool = False 70 | ) -> Tuple[torch.Tensor, torch.Tensor]: 71 | dist = self.actor(obs) 72 | if deterministic: 73 | squashed_action, raw_action = dist.mode() 74 | else: 75 | squashed_action, raw_action = dist.rsample() 76 | 
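# The log-probability computed below must include the tanh-squashing correction,
#   log pi(a|s) = log N(u | mu, sigma) - sum_i log(1 - tanh(u_i)^2 + eps),  with a = tanh(u),
# which is why log_prob receives both the squashed and the raw (pre-tanh) action; the
# correction itself is assumed to be handled inside the distribution returned by the actor
# (TanhDiagGaussian in the run scripts below).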
log_prob = dist.log_prob(squashed_action, raw_action) 77 | return squashed_action, log_prob 78 | 79 | def select_action( 80 | self, 81 | obs: np.ndarray, 82 | deterministic: bool = False 83 | ) -> np.ndarray: 84 | with torch.no_grad(): 85 | action, _ = self.actforward(obs, deterministic) 86 | return action.cpu().numpy() 87 | 88 | def learn(self, batch: Dict) -> Dict[str, float]: 89 | obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \ 90 | batch["next_observations"], batch["rewards"], batch["terminals"] 91 | 92 | # update critic 93 | q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions) 94 | with torch.no_grad(): 95 | next_actions, next_log_probs = self.actforward(next_obss) 96 | next_q = torch.min( 97 | self.critic1_old(next_obss, next_actions), self.critic2_old(next_obss, next_actions) 98 | ) - self._alpha * next_log_probs 99 | target_q = rewards + self._gamma * (1 - terminals) * next_q 100 | 101 | critic1_loss = ((q1 - target_q).pow(2)).mean() 102 | self.critic1_optim.zero_grad() 103 | critic1_loss.backward() 104 | self.critic1_optim.step() 105 | 106 | critic2_loss = ((q2 - target_q).pow(2)).mean() 107 | self.critic2_optim.zero_grad() 108 | critic2_loss.backward() 109 | self.critic2_optim.step() 110 | 111 | # update actor 112 | a, log_probs = self.actforward(obss) 113 | q1a, q2a = self.critic1(obss, a), self.critic2(obss, a) 114 | 115 | actor_loss = - torch.min(q1a, q2a).mean() + self._alpha * log_probs.mean() 116 | self.actor_optim.zero_grad() 117 | actor_loss.backward() 118 | self.actor_optim.step() 119 | 120 | if self._is_auto_alpha: 121 | log_probs = log_probs.detach() + self._target_entropy 122 | alpha_loss = -(self._log_alpha * log_probs).mean() 123 | self.alpha_optim.zero_grad() 124 | alpha_loss.backward() 125 | self.alpha_optim.step() 126 | self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0) 127 | 128 | self._sync_weight() 129 | 130 | result = { 131 | "loss/actor": actor_loss.item(), 132 | "loss/critic1": critic1_loss.item(), 133 | "loss/critic2": critic2_loss.item(), 134 | } 135 | 136 | if self._is_auto_alpha: 137 | result["loss/alpha"] = alpha_loss.item() 138 | result["alpha"] = self._alpha.item() 139 | 140 | return result 141 | 142 | -------------------------------------------------------------------------------- /run_example/run_td3bc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gym 5 | import d4rl 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | from offlinerlkit.nets import MLP 12 | from offlinerlkit.modules import Actor, Critic 13 | from offlinerlkit.utils.noise import GaussianNoise 14 | from offlinerlkit.utils.scaler import StandardScaler 15 | from offlinerlkit.buffer import ReplayBuffer 16 | from offlinerlkit.utils.logger import Logger, make_log_dirs 17 | from offlinerlkit.policy_trainer import MFPolicyTrainer 18 | from offlinerlkit.policy import TD3BCPolicy 19 | 20 | 21 | """ 22 | suggested hypers 23 | alpha=2.5 for all D4RL-Gym tasks 24 | """ 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--algo-name", type=str, default="td3bc") 30 | parser.add_argument("--task", type=str, default="hopper-medium-v2") 31 | parser.add_argument("--seed", type=int, default=0) 32 | parser.add_argument("--actor-lr", type=float, default=3e-4) 33 | parser.add_argument("--critic-lr", type=float, default=3e-4) 34 | parser.add_argument("--gamma", type=float, default=0.99) 35 | 
parser.add_argument("--tau", type=float, default=0.005) 36 | parser.add_argument("--exploration-noise", type=float, default=0.1) 37 | parser.add_argument("--policy-noise", type=float, default=0.2) 38 | parser.add_argument("--noise-clip", type=float, default=0.5) 39 | parser.add_argument("--update-actor-freq", type=int, default=2) 40 | parser.add_argument("--alpha", type=float, default=2.5) 41 | parser.add_argument("--epoch", type=int, default=1000) 42 | parser.add_argument("--step-per-epoch", type=int, default=1000) 43 | parser.add_argument("--eval_episodes", type=int, default=10) 44 | parser.add_argument("--batch-size", type=int, default=256) 45 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 46 | 47 | return parser.parse_args() 48 | 49 | 50 | def train(args=get_args()): 51 | # create env and dataset 52 | env = gym.make(args.task) 53 | dataset = d4rl.qlearning_dataset(env) 54 | if 'antmaze' in args.task: 55 | dataset["rewards"] -= 1.0 56 | args.obs_shape = env.observation_space.shape 57 | args.action_dim = np.prod(env.action_space.shape) 58 | args.max_action = env.action_space.high[0] 59 | 60 | # create buffer 61 | buffer = ReplayBuffer( 62 | buffer_size=len(dataset["observations"]), 63 | obs_shape=args.obs_shape, 64 | obs_dtype=np.float32, 65 | action_dim=args.action_dim, 66 | action_dtype=np.float32, 67 | device=args.device 68 | ) 69 | buffer.load_dataset(dataset) 70 | obs_mean, obs_std = buffer.normalize_obs() 71 | 72 | # seed 73 | random.seed(args.seed) 74 | np.random.seed(args.seed) 75 | torch.manual_seed(args.seed) 76 | torch.cuda.manual_seed_all(args.seed) 77 | torch.backends.cudnn.deterministic = True 78 | env.seed(args.seed) 79 | 80 | # create policy model 81 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=[256, 256]) 82 | critic1_backbone = MLP(input_dim=np.prod(args.obs_shape)+args.action_dim, hidden_dims=[256, 256]) 83 | critic2_backbone = MLP(input_dim=np.prod(args.obs_shape)+args.action_dim, hidden_dims=[256, 256]) 84 | actor = Actor(actor_backbone, args.action_dim, max_action=args.max_action, device=args.device) 85 | 86 | critic1 = Critic(critic1_backbone, args.device) 87 | critic2 = Critic(critic2_backbone, args.device) 88 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 89 | critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) 90 | critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) 91 | 92 | # scaler for normalizing observations 93 | scaler = StandardScaler(mu=obs_mean, std=obs_std) 94 | 95 | # create policy 96 | policy = TD3BCPolicy( 97 | actor, 98 | critic1, 99 | critic2, 100 | actor_optim, 101 | critic1_optim, 102 | critic2_optim, 103 | tau=args.tau, 104 | gamma=args.gamma, 105 | max_action=args.max_action, 106 | exploration_noise=GaussianNoise(sigma=args.exploration_noise), 107 | policy_noise=args.policy_noise, 108 | noise_clip=args.noise_clip, 109 | update_actor_freq=args.update_actor_freq, 110 | alpha=args.alpha, 111 | scaler=scaler 112 | ) 113 | 114 | # log 115 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 116 | # key: output file name, value: output handler type 117 | output_config = { 118 | "consoleout_backup": "stdout", 119 | "policy_training_progress": "csv", 120 | "tb": "tensorboard" 121 | } 122 | logger = Logger(log_dirs, output_config) 123 | logger.log_hyperparameters(vars(args)) 124 | 125 | # create policy trainer 126 | policy_trainer = MFPolicyTrainer( 127 | policy=policy, 128 | 
eval_env=env, 129 | buffer=buffer, 130 | logger=logger, 131 | epoch=args.epoch, 132 | step_per_epoch=args.step_per_epoch, 133 | batch_size=args.batch_size, 134 | eval_episodes=args.eval_episodes 135 | ) 136 | 137 | # train 138 | policy_trainer.train() 139 | 140 | 141 | if __name__ == "__main__": 142 | train() -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/iql.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gym 5 | 6 | from copy import deepcopy 7 | from typing import Dict, Union, Tuple 8 | from offlinerlkit.policy import BasePolicy 9 | 10 | 11 | class IQLPolicy(BasePolicy): 12 | """ 13 | Implicit Q-Learning 14 | """ 15 | 16 | def __init__( 17 | self, 18 | actor: nn.Module, 19 | critic_q1: nn.Module, 20 | critic_q2: nn.Module, 21 | critic_v: nn.Module, 22 | actor_optim: torch.optim.Optimizer, 23 | critic_q1_optim: torch.optim.Optimizer, 24 | critic_q2_optim: torch.optim.Optimizer, 25 | critic_v_optim: torch.optim.Optimizer, 26 | action_space: gym.spaces.Space, 27 | tau: float = 0.005, 28 | gamma: float = 0.99, 29 | expectile: float = 0.8, 30 | temperature: float = 0.1 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.actor = actor 35 | self.critic_q1, self.critic_q1_old = critic_q1, deepcopy(critic_q1) 36 | self.critic_q1_old.eval() 37 | self.critic_q2, self.critic_q2_old = critic_q2, deepcopy(critic_q2) 38 | self.critic_q2_old.eval() 39 | self.critic_v = critic_v 40 | 41 | self.actor_optim = actor_optim 42 | self.critic_q1_optim = critic_q1_optim 43 | self.critic_q2_optim = critic_q2_optim 44 | self.critic_v_optim = critic_v_optim 45 | 46 | self.action_space = action_space 47 | self._tau = tau 48 | self._gamma = gamma 49 | self._expectile = expectile 50 | self._temperature = temperature 51 | 52 | def train(self) -> None: 53 | self.actor.train() 54 | self.critic_q1.train() 55 | self.critic_q2.train() 56 | self.critic_v.train() 57 | 58 | def eval(self) -> None: 59 | self.actor.eval() 60 | self.critic_q1.eval() 61 | self.critic_q2.eval() 62 | self.critic_v.eval() 63 | 64 | def _sync_weight(self) -> None: 65 | for o, n in zip(self.critic_q1_old.parameters(), self.critic_q1.parameters()): 66 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 67 | for o, n in zip(self.critic_q2_old.parameters(), self.critic_q2.parameters()): 68 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 69 | 70 | def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: 71 | if len(obs.shape) == 1: 72 | obs = obs.reshape(1, -1) 73 | with torch.no_grad(): 74 | dist = self.actor(obs) 75 | if deterministic: 76 | action = dist.mode().cpu().numpy() 77 | else: 78 | action = dist.sample().cpu().numpy() 79 | action = np.clip(action, self.action_space.low[0], self.action_space.high[0]) 80 | return action 81 | 82 | def _expectile_regression(self, diff: torch.Tensor) -> torch.Tensor: 83 | weight = torch.where(diff > 0, self._expectile, (1 - self._expectile)) 84 | return weight * (diff**2) 85 | 86 | def learn(self, batch: Dict) -> Dict[str, float]: 87 | obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \ 88 | batch["next_observations"], batch["rewards"], batch["terminals"] 89 | 90 | # update value net 91 | with torch.no_grad(): 92 | q1, q2 = self.critic_q1_old(obss, actions), self.critic_q2_old(obss, actions) 93 | q = torch.min(q1, q2) 94 | v = 
self.critic_v(obss) 95 | critic_v_loss = self._expectile_regression(q-v).mean() 96 | self.critic_v_optim.zero_grad() 97 | critic_v_loss.backward() 98 | self.critic_v_optim.step() 99 | 100 | # update critic 101 | q1, q2 = self.critic_q1(obss, actions), self.critic_q2(obss, actions) 102 | with torch.no_grad(): 103 | next_v = self.critic_v(next_obss) 104 | target_q = rewards + self._gamma * (1 - terminals) * next_v 105 | 106 | critic_q1_loss = ((q1 - target_q).pow(2)).mean() 107 | critic_q2_loss = ((q2 - target_q).pow(2)).mean() 108 | 109 | self.critic_q1_optim.zero_grad() 110 | critic_q1_loss.backward() 111 | self.critic_q1_optim.step() 112 | 113 | self.critic_q2_optim.zero_grad() 114 | critic_q2_loss.backward() 115 | self.critic_q2_optim.step() 116 | 117 | # update actor 118 | with torch.no_grad(): 119 | q1, q2 = self.critic_q1_old(obss, actions), self.critic_q2_old(obss, actions) 120 | q = torch.min(q1, q2) 121 | v = self.critic_v(obss) 122 | exp_a = torch.exp((q - v) * self._temperature) 123 | exp_a = torch.clip(exp_a, None, 100.0) 124 | dist = self.actor(obss) 125 | log_probs = dist.log_prob(actions) 126 | actor_loss = -(exp_a * log_probs).mean() 127 | 128 | self.actor_optim.zero_grad() 129 | actor_loss.backward() 130 | self.actor_optim.step() 131 | 132 | self._sync_weight() 133 | 134 | return { 135 | "loss/actor": actor_loss.item(), 136 | "loss/q1": critic_q1_loss.item(), 137 | "loss/q2": critic_q2_loss.item(), 138 | "loss/v": critic_v_loss.item() 139 | } -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/mcq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn import functional as F 6 | from typing import Dict, Union, Tuple 7 | from offlinerlkit.policy import SACPolicy 8 | 9 | 10 | class MCQPolicy(SACPolicy): 11 | """ 12 | Mildly Conservative Q-Learning 13 | """ 14 | 15 | def __init__( 16 | self, 17 | actor: nn.Module, 18 | critic1: nn.Module, 19 | critic2: nn.Module, 20 | behavior_policy: nn.Module, 21 | actor_optim: torch.optim.Optimizer, 22 | critic1_optim: torch.optim.Optimizer, 23 | critic2_optim: torch.optim.Optimizer, 24 | behavior_policy_optim: torch.optim.Optimizer, 25 | tau: float = 0.005, 26 | gamma: float = 0.99, 27 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, 28 | lmbda: float = 0.7, 29 | num_sampled_actions: int = 10 30 | ) -> None: 31 | super().__init__( 32 | actor, 33 | critic1, 34 | critic2, 35 | actor_optim, 36 | critic1_optim, 37 | critic2_optim, 38 | tau=tau, 39 | gamma=gamma, 40 | alpha=alpha 41 | ) 42 | 43 | self.behavior_policy = behavior_policy 44 | self.behavior_policy_optim = behavior_policy_optim 45 | self._lmbda = lmbda 46 | self._num_sampled_actions = num_sampled_actions 47 | 48 | def learn(self, batch: Dict) -> Dict[str, float]: 49 | obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \ 50 | batch["next_observations"], batch["rewards"], batch["terminals"] 51 | 52 | # update behavior policy 53 | recon, mean, std = self.behavior_policy(obss, actions) 54 | recon_loss = F.mse_loss(recon, actions) 55 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() 56 | vae_loss = recon_loss + KL_loss 57 | 58 | self.behavior_policy_optim.zero_grad() 59 | vae_loss.backward() 60 | self.behavior_policy_optim.step() 61 | 62 | # update critic 63 | with torch.no_grad(): 64 | next_actions, next_log_probs 
= self.actforward(next_obss) 65 | next_q = torch.min( 66 | self.critic1_old(next_obss, next_actions), self.critic2_old(next_obss, next_actions) 67 | ) - self._alpha * next_log_probs 68 | target_q_for_in_actions = rewards + self._gamma * (1 - terminals) * next_q 69 | q1_in, q2_in = self.critic1(obss, actions), self.critic2(obss, actions) 70 | critic1_loss_for_in_actions = ((q1_in - target_q_for_in_actions).pow(2)).mean() 71 | critic2_loss_for_in_actions = ((q2_in - target_q_for_in_actions).pow(2)).mean() 72 | 73 | s_in = torch.cat([obss, next_obss], dim=0) 74 | with torch.no_grad(): 75 | s_in_repeat = torch.repeat_interleave(s_in, self._num_sampled_actions, 0) 76 | sampled_actions = self.behavior_policy.decode(s_in_repeat) 77 | target_q1_for_ood_actions = self.critic1_old(s_in_repeat, sampled_actions).reshape(s_in.shape[0], -1).max(1)[0].reshape(-1, 1) 78 | target_q2_for_ood_actions = self.critic2_old(s_in_repeat, sampled_actions).reshape(s_in.shape[0], -1).max(1)[0].reshape(-1, 1) 79 | target_q_for_ood_actions = torch.min(target_q1_for_ood_actions, target_q2_for_ood_actions) 80 | ood_actions, _ = self.actforward(s_in) 81 | 82 | q1_ood, q2_ood = self.critic1(s_in, ood_actions), self.critic2(s_in, ood_actions) 83 | critic1_loss_for_ood_actions = ((q1_ood - target_q_for_ood_actions).pow(2)).mean() 84 | critic2_loss_for_ood_actions = ((q2_ood - target_q_for_ood_actions).pow(2)).mean() 85 | 86 | critic1_loss = self._lmbda * critic1_loss_for_in_actions + (1 - self._lmbda) * critic1_loss_for_ood_actions 87 | self.critic1_optim.zero_grad() 88 | critic1_loss.backward() 89 | self.critic1_optim.step() 90 | 91 | critic2_loss = self._lmbda * critic2_loss_for_in_actions + (1 - self._lmbda) * critic2_loss_for_ood_actions 92 | self.critic2_optim.zero_grad() 93 | critic2_loss.backward() 94 | self.critic2_optim.step() 95 | 96 | # update actor 97 | a, log_probs = self.actforward(obss) 98 | q1a, q2a = self.critic1(obss, a), self.critic2(obss, a) 99 | 100 | actor_loss = - torch.min(q1a, q2a).mean() + self._alpha * log_probs.mean() 101 | self.actor_optim.zero_grad() 102 | actor_loss.backward() 103 | self.actor_optim.step() 104 | 105 | if self._is_auto_alpha: 106 | log_probs = log_probs.detach() + self._target_entropy 107 | alpha_loss = -(self._log_alpha * log_probs).mean() 108 | self.alpha_optim.zero_grad() 109 | alpha_loss.backward() 110 | self.alpha_optim.step() 111 | self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0) 112 | 113 | self._sync_weight() 114 | 115 | result = { 116 | "loss/actor": actor_loss.item(), 117 | "loss/critic1": critic1_loss.item(), 118 | "loss/critic2": critic2_loss.item(), 119 | "loss/behavior_policy": vae_loss.item() 120 | } 121 | 122 | if self._is_auto_alpha: 123 | result["loss/alpha"] = alpha_loss.item() 124 | result["alpha"] = self._alpha.item() 125 | 126 | return result 127 | 128 | -------------------------------------------------------------------------------- /run_example/run_mcq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gym 5 | import d4rl 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | from offlinerlkit.nets import MLP, VAE 12 | from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian 13 | from offlinerlkit.buffer import ReplayBuffer 14 | from offlinerlkit.utils.logger import Logger, make_log_dirs 15 | from offlinerlkit.policy_trainer import MFPolicyTrainer 16 | from offlinerlkit.policy import MCQPolicy 17 | 18 | 19 | def get_args(): 20 | parser 
= argparse.ArgumentParser() 21 | parser.add_argument("--algo-name", type=str, default="mcq") 22 | parser.add_argument("--task", type=str, default="hopper-medium-replay-v2") 23 | parser.add_argument("--seed", type=int, default=0) 24 | parser.add_argument("--actor-lr", type=float, default=3e-4) 25 | parser.add_argument("--critic-lr", type=float, default=3e-4) 26 | parser.add_argument("--hidden-dims", type=int, nargs='*', default=[400, 400]) 27 | parser.add_argument("--gamma", type=float, default=0.99) 28 | parser.add_argument("--tau", type=float, default=0.005) 29 | parser.add_argument("--alpha", type=float, default=0.2) 30 | parser.add_argument("--auto-alpha", default=True) 31 | parser.add_argument("--target-entropy", type=int, default=None) 32 | parser.add_argument("--alpha-lr", type=float, default=3e-4) 33 | parser.add_argument("--lmbda", type=float, default=0.9) 34 | parser.add_argument("--num-sampled-actions", type=int, default=10) 35 | parser.add_argument("--behavior-policy-lr", type=float, default=1e-3) 36 | parser.add_argument("--epoch", type=int, default=1000) 37 | parser.add_argument("--step-per-epoch", type=int, default=1000) 38 | parser.add_argument("--eval_episodes", type=int, default=10) 39 | parser.add_argument("--batch-size", type=int, default=256) 40 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | return parser.parse_args() 43 | 44 | 45 | def train(args=get_args()): 46 | # create env and dataset 47 | env = gym.make(args.task) 48 | dataset = d4rl.qlearning_dataset(env) 49 | if 'antmaze' in args.task: 50 | dataset["rewards"] -= 1.0 51 | args.obs_shape = env.observation_space.shape 52 | args.action_dim = np.prod(env.action_space.shape) 53 | args.max_action = env.action_space.high[0] 54 | 55 | # seed 56 | random.seed(args.seed) 57 | np.random.seed(args.seed) 58 | torch.manual_seed(args.seed) 59 | torch.cuda.manual_seed_all(args.seed) 60 | torch.backends.cudnn.deterministic = True 61 | env.seed(args.seed) 62 | 63 | # create policy model 64 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims, dropout_rate=0.1) 65 | critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 66 | critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 67 | dist = TanhDiagGaussian( 68 | latent_dim=getattr(actor_backbone, "output_dim"), 69 | output_dim=args.action_dim, 70 | unbounded=True, 71 | conditioned_sigma=True, 72 | max_mu=args.max_action 73 | ) 74 | actor = ActorProb(actor_backbone, dist, args.device) 75 | critic1 = Critic(critic1_backbone, args.device) 76 | critic2 = Critic(critic2_backbone, args.device) 77 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 78 | critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) 79 | critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) 80 | 81 | if args.auto_alpha: 82 | target_entropy = args.target_entropy if args.target_entropy \ 83 | else -np.prod(env.action_space.shape) 84 | 85 | args.target_entropy = target_entropy 86 | 87 | log_alpha = torch.zeros(1, requires_grad=True, device=args.device) 88 | alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) 89 | alpha = (target_entropy, log_alpha, alpha_optim) 90 | else: 91 | alpha = args.alpha 92 | 93 | behavior_policy = VAE( 94 | input_dim=np.prod(args.obs_shape), 95 | output_dim=args.action_dim, 96 | hidden_dim=750, 97 | latent_dim=args.action_dim*2, 
98 | max_action=args.max_action, 99 | device=args.device 100 | ) 101 | behavior_policy_optim = torch.optim.Adam(behavior_policy.parameters(), lr=args.behavior_policy_lr) 102 | 103 | # create policy 104 | policy = MCQPolicy( 105 | actor, 106 | critic1, 107 | critic2, 108 | behavior_policy, 109 | actor_optim, 110 | critic1_optim, 111 | critic2_optim, 112 | behavior_policy_optim, 113 | tau=args.tau, 114 | gamma=args.gamma, 115 | alpha=alpha, 116 | lmbda=args.lmbda, 117 | num_sampled_actions=args.num_sampled_actions 118 | ) 119 | 120 | # create buffer 121 | buffer = ReplayBuffer( 122 | buffer_size=len(dataset["observations"]), 123 | obs_shape=args.obs_shape, 124 | obs_dtype=np.float32, 125 | action_dim=args.action_dim, 126 | action_dtype=np.float32, 127 | device=args.device 128 | ) 129 | buffer.load_dataset(dataset) 130 | 131 | # log 132 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 133 | # key: output file name, value: output handler type 134 | output_config = { 135 | "consoleout_backup": "stdout", 136 | "policy_training_progress": "csv", 137 | "tb": "tensorboard" 138 | } 139 | logger = Logger(log_dirs, output_config) 140 | logger.log_hyperparameters(vars(args)) 141 | 142 | # create policy trainer 143 | policy_trainer = MFPolicyTrainer( 144 | policy=policy, 145 | eval_env=env, 146 | buffer=buffer, 147 | logger=logger, 148 | epoch=args.epoch, 149 | step_per_epoch=args.step_per_epoch, 150 | batch_size=args.batch_size, 151 | eval_episodes=args.eval_episodes 152 | ) 153 | 154 | # train 155 | policy_trainer.train() 156 | 157 | 158 | if __name__ == "__main__": 159 | train() -------------------------------------------------------------------------------- /run_example/plotter.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | import sys 11 | sys.path.append('.') 12 | 13 | 14 | COLORS = ( 15 | [ 16 | '#318DE9', # blue 17 | '#FF7D00', # orange 18 | '#E52B50', # red 19 | '#8D6AB8', # purple 20 | '#00CD66', # green 21 | '#FFD700', # yellow 22 | ] 23 | ) 24 | 25 | 26 | def merge_csv(root_dir, query_file, query_x, query_y): 27 | """Merge result in csv_files into a single csv file.""" 28 | csv_files = [] 29 | for dirname, _, files in os.walk(root_dir): 30 | for f in files: 31 | if f == query_file: 32 | csv_files.append(os.path.join(dirname, f)) 33 | results = {} 34 | for csv_file in csv_files: 35 | content = [[query_x, query_y]] 36 | df = pd.read_csv(csv_file) 37 | values = df[[query_x, query_y]].values 38 | for line in values: 39 | if np.isnan(line[1]): continue 40 | content.append(line) 41 | results[csv_file] = content 42 | assert len(results) > 0 43 | sorted_keys = sorted(results.keys()) 44 | sorted_values = [results[k][1:] for k in sorted_keys] 45 | content = [ 46 | [query_x, query_y+'_mean', query_y+'_std'] 47 | ] 48 | for rows in zip(*sorted_values): 49 | array = np.array(rows) 50 | assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) 51 | line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)] 52 | content.append(line) 53 | output_path = os.path.join(root_dir, query_y.replace('/', '_')+".csv") 54 | print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.") 55 | csv.writer(open(output_path, "w")).writerows(content) 56 | return output_path 57 | 58 | 59 | def csv2numpy(file_path): 60 | df = pd.read_csv(file_path) 61 | step = 
df.iloc[:,0].to_numpy() 62 | mean = df.iloc[:,1].to_numpy() 63 | std = df.iloc[:,2].to_numpy() 64 | return step, mean, std 65 | 66 | 67 | def smooth(y, radius=0): 68 | convkernel = np.ones(2 * radius + 1) 69 | out = np.convolve(y, convkernel, mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same') 70 | return out 71 | 72 | 73 | def plot_figure( 74 | results, 75 | x_label, 76 | y_label, 77 | xlim=None, 78 | ylim=None, 79 | title=None, 80 | smooth_radius=10, 81 | figsize=None, 82 | dpi=None, 83 | color_list=None, 84 | legend_outside=False 85 | ): 86 | fig, ax = plt.subplots(figsize=figsize, dpi=dpi) 87 | if color_list == None: 88 | color_list = [COLORS[i] for i in range(len(results))] 89 | else: 90 | assert len(color_list) == len(results) 91 | for i, (algo_name, csv_file) in enumerate(results.items()): 92 | x, y, shaded = csv2numpy(csv_file) 93 | y = smooth(y, smooth_radius) 94 | shaded = smooth(shaded, smooth_radius) 95 | ax.plot(x, y, color=color_list[i], label=algo_name) 96 | ax.fill_between(x, y-shaded, y+shaded, color=color_list[i], alpha=0.2) 97 | ax.set_title(title, fontdict={'size': 10}) 98 | ax.set_xlabel(x_label, fontdict={'size': 10}) 99 | ax.set_ylabel(y_label, fontdict={'size': 10}) 100 | if xlim is not None: 101 | ax.set_xlim(*xlim) 102 | if ylim is not None: 103 | ax.set_ylim(*ylim) 104 | if legend_outside: 105 | ax.legend(loc=2, bbox_to_anchor=(1,1), prop={'size': 10}) 106 | else: 107 | ax.legend(prop={'size': 10}) 108 | 109 | 110 | def plot_func( 111 | root_dir, 112 | task, 113 | algos, 114 | query_file, 115 | query_x, 116 | query_y, 117 | xlabel, 118 | ylabel, 119 | xlim=None, 120 | ylim=None, 121 | title=None, 122 | smooth_radius=10, 123 | figsize=None, 124 | dpi=None, 125 | colors=None, 126 | legend_outside=False 127 | ): 128 | results = {} 129 | for algo in algos: 130 | path = os.path.join(root_dir, task, algo) 131 | csv_file = merge_csv(path, query_file, query_x, query_y) 132 | results[algo] = csv_file 133 | 134 | plt.style.use('seaborn') 135 | plot_figure( 136 | results=results, 137 | x_label=xlabel, 138 | y_label=ylabel, 139 | xlim=xlim, 140 | ylim=ylim, 141 | title=title, 142 | smooth_radius=smooth_radius, 143 | figsize=figsize, 144 | dpi=dpi, 145 | color_list=colors, 146 | legend_outside=legend_outside 147 | ) 148 | plt.show() 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser(description="plotter") 153 | parser.add_argument("--root-dir", default="log") 154 | parser.add_argument("--task", default="walker2d-medium-v2") 155 | parser.add_argument("--algos", type=str, nargs='*', default=["mopo&penalty_coef=0.5&rollout_length=5"]) 156 | parser.add_argument("--query-file", default="policy_training_progress.csv") 157 | parser.add_argument("--query-x", default="timestep") 158 | parser.add_argument("--query-y", default="eval/normalized_episode_reward") 159 | parser.add_argument("--title", default=None) 160 | parser.add_argument("--xlabel", default="Timesteps") 161 | parser.add_argument("--ylabel", default=None) 162 | parser.add_argument("--smooth", type=int, default=10) 163 | parser.add_argument("--colors", type=str, nargs='*', default=None) 164 | parser.add_argument("--show", action='store_true') 165 | parser.add_argument("--output-path", default="./1.png") 166 | parser.add_argument("--figsize", type=float, nargs=2, default=(8, 6)) 167 | parser.add_argument("--dpi", type=int, default=500) 168 | args = parser.parse_args() 169 | 170 | results = {} 171 | for algo in args.algos: 172 | path = os.path.join(args.root_dir, args.task, 
algo) 173 | csv_file = merge_csv(path, args.query_file, args.query_x, args.query_y) 174 | results[algo] = csv_file 175 | 176 | plt.style.use('seaborn') 177 | plot_figure( 178 | results=results, 179 | x_label=args.xlabel, 180 | y_label=args.ylabel, 181 | title=args.title, 182 | smooth_radius=args.smooth, 183 | figsize=args.figsize, 184 | dpi=args.dpi, 185 | color_list=args.colors 186 | ) 187 | if args.output_path: 188 | plt.savefig(args.output_path) 189 | if args.show: 190 | plt.show() -------------------------------------------------------------------------------- /run_example/run_edac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gym 5 | import d4rl 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | from offlinerlkit.nets import MLP 12 | from offlinerlkit.modules import ActorProb, EnsembleCritic, TanhDiagGaussian 13 | from offlinerlkit.buffer import ReplayBuffer 14 | from offlinerlkit.utils.logger import Logger, make_log_dirs 15 | from offlinerlkit.policy_trainer import MFPolicyTrainer 16 | from offlinerlkit.policy import EDACPolicy 17 | 18 | 19 | """ 20 | suggested hypers 21 | 22 | halfcheetah-medium-v2: num-critics=10, eta=1.0 23 | hopper-medium-v2: num-critics=50, eta=1.0 24 | walker2d-medium-v2: num-critics=10, eta=1.0 25 | halfcheetah-medium-replay-v2: num-critics=10, eta=1.0 26 | hopper-medium-replay-v2: num-critics=50, eta=1.0 27 | walker2d-medium-replay-v2: num-critics=10, eta=1.0 28 | halfcheetah-medium-expert-v2: num-critics=10, eta=5.0 29 | hopper-medium-expert-v2: num-critics=50, eta=1.0 30 | walker2d-medium-expert-v2: num-critics=10, eta=5.0 31 | """ 32 | 33 | 34 | def get_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--algo-name", type=str, default="edac") 37 | parser.add_argument("--task", type=str, default="hopper-medium-v2") 38 | parser.add_argument("--seed", type=int, default=1) 39 | parser.add_argument("--actor-lr", type=float, default=1e-4) 40 | parser.add_argument("--critic-lr", type=float, default=3e-4) 41 | parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256]) 42 | parser.add_argument("--gamma", type=float, default=0.99) 43 | parser.add_argument("--tau", type=float, default=0.005) 44 | parser.add_argument("--alpha", type=float, default=0.2) 45 | parser.add_argument("--auto-alpha", type=bool, default=True) 46 | parser.add_argument("--target-entropy", type=int, default=None) 47 | parser.add_argument("--alpha-lr", type=float, default=1e-4) 48 | parser.add_argument("--num-critics", type=int, default=50) 49 | parser.add_argument("--max-q-backup", type=bool, default=False) 50 | parser.add_argument("--deterministic-backup", type=bool, default=False) 51 | parser.add_argument("--eta", type=float, default=1.0) 52 | parser.add_argument("--normalize-reward", type=bool, default=False) 53 | 54 | parser.add_argument("--epoch", type=int, default=3000) 55 | parser.add_argument("--step-per-epoch", type=int, default=1000) 56 | parser.add_argument("--eval_episodes", type=int, default=10) 57 | parser.add_argument("--batch-size", type=int, default=256) 58 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 59 | 60 | return parser.parse_args() 61 | 62 | 63 | def train(args=get_args()): 64 | # create env and dataset 65 | env = gym.make(args.task) 66 | dataset = d4rl.qlearning_dataset(env) 67 | if args.normalize_reward: 68 | mu, std = dataset["rewards"].mean(), dataset["rewards"].std() 69 | 
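# Optional reward standardization (z-score); the small epsilon added to the std below
# guards against a division blow-up when the dataset's rewards are nearly constant.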
dataset["rewards"] = (dataset["rewards"] - mu) / (std + 1e-3) 70 | 71 | args.obs_shape = env.observation_space.shape 72 | args.action_dim = np.prod(env.action_space.shape) 73 | args.max_action = env.action_space.high[0] 74 | 75 | # seed 76 | random.seed(args.seed) 77 | np.random.seed(args.seed) 78 | torch.manual_seed(args.seed) 79 | torch.cuda.manual_seed_all(args.seed) 80 | torch.backends.cudnn.deterministic = True 81 | env.seed(args.seed) 82 | 83 | # create policy model 84 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims) 85 | dist = TanhDiagGaussian( 86 | latent_dim=getattr(actor_backbone, "output_dim"), 87 | output_dim=args.action_dim, 88 | unbounded=True, 89 | conditioned_sigma=True, 90 | max_mu=args.max_action 91 | ) 92 | actor = ActorProb(actor_backbone, dist, args.device) 93 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 94 | critics = EnsembleCritic( 95 | np.prod(args.obs_shape), args.action_dim, \ 96 | args.hidden_dims, num_ensemble=args.num_critics, \ 97 | device=args.device 98 | ) 99 | # init as in the EDAC paper 100 | for layer in critics.model[::2]: 101 | torch.nn.init.constant_(layer.bias, 0.1) 102 | torch.nn.init.uniform_(critics.model[-1].weight, -3e-3, 3e-3) 103 | torch.nn.init.uniform_(critics.model[-1].bias, -3e-3, 3e-3) 104 | critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr) 105 | 106 | if args.auto_alpha: 107 | target_entropy = args.target_entropy if args.target_entropy \ 108 | else -np.prod(env.action_space.shape) 109 | args.target_entropy = target_entropy 110 | log_alpha = torch.zeros(1, requires_grad=True, device=args.device) 111 | alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) 112 | alpha = (target_entropy, log_alpha, alpha_optim) 113 | else: 114 | alpha = args.alpha 115 | 116 | # create policy 117 | policy = EDACPolicy( 118 | actor, 119 | critics, 120 | actor_optim, 121 | critics_optim, 122 | tau=args.tau, 123 | gamma=args.gamma, 124 | alpha=alpha, 125 | max_q_backup=args.max_q_backup, 126 | deterministic_backup=args.deterministic_backup, 127 | eta=args.eta 128 | ) 129 | 130 | # create buffer 131 | buffer = ReplayBuffer( 132 | buffer_size=len(dataset["observations"]), 133 | obs_shape=args.obs_shape, 134 | obs_dtype=np.float32, 135 | action_dim=args.action_dim, 136 | action_dtype=np.float32, 137 | device=args.device 138 | ) 139 | buffer.load_dataset(dataset) 140 | 141 | # log 142 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args), record_params=["num_critics", "eta"]) 143 | # key: output file name, value: output handler type 144 | output_config = { 145 | "consoleout_backup": "stdout", 146 | "policy_training_progress": "csv", 147 | "dynamics_training_progress": "csv", 148 | "tb": "tensorboard" 149 | } 150 | logger = Logger(log_dirs, output_config) 151 | logger.log_hyperparameters(vars(args)) 152 | 153 | # create policy trainer 154 | policy_trainer = MFPolicyTrainer( 155 | policy=policy, 156 | eval_env=env, 157 | buffer=buffer, 158 | logger=logger, 159 | epoch=args.epoch, 160 | step_per_epoch=args.step_per_epoch, 161 | batch_size=args.batch_size, 162 | eval_episodes=args.eval_episodes 163 | ) 164 | 165 | policy_trainer.train() 166 | 167 | 168 | if __name__ == "__main__": 169 | train() -------------------------------------------------------------------------------- /run_example/run_cql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gym 5 | import d4rl 6 
| 7 | import numpy as np 8 | import torch 9 | 10 | 11 | from offlinerlkit.nets import MLP 12 | from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian 13 | from offlinerlkit.buffer import ReplayBuffer 14 | from offlinerlkit.utils.logger import Logger, make_log_dirs 15 | from offlinerlkit.policy_trainer import MFPolicyTrainer 16 | from offlinerlkit.policy import CQLPolicy 17 | 18 | 19 | """ 20 | suggested hypers 21 | cql-weight=5.0, temperature=1.0 for all D4RL-Gym tasks 22 | """ 23 | 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--algo-name", type=str, default="cql") 28 | parser.add_argument("--task", type=str, default="hopper-medium-v2") 29 | parser.add_argument("--seed", type=int, default=0) 30 | parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256]) 31 | parser.add_argument("--actor-lr", type=float, default=1e-4) 32 | parser.add_argument("--critic-lr", type=float, default=3e-4) 33 | parser.add_argument("--gamma", type=float, default=0.99) 34 | parser.add_argument("--tau", type=float, default=0.005) 35 | parser.add_argument("--alpha", type=float, default=0.2) 36 | parser.add_argument("--target-entropy", type=int, default=None) 37 | parser.add_argument("--auto-alpha", default=True) 38 | parser.add_argument("--alpha-lr", type=float, default=1e-4) 39 | 40 | parser.add_argument("--cql-weight", type=float, default=5.0) 41 | parser.add_argument("--temperature", type=float, default=1.0) 42 | parser.add_argument("--max-q-backup", type=bool, default=False) 43 | parser.add_argument("--deterministic-backup", type=bool, default=True) 44 | parser.add_argument("--with-lagrange", type=bool, default=False) 45 | parser.add_argument("--lagrange-threshold", type=float, default=10.0) 46 | parser.add_argument("--cql-alpha-lr", type=float, default=3e-4) 47 | parser.add_argument("--num-repeat-actions", type=int, default=10) 48 | 49 | parser.add_argument("--epoch", type=int, default=1000) 50 | parser.add_argument("--step-per-epoch", type=int, default=1000) 51 | parser.add_argument("--eval_episodes", type=int, default=10) 52 | parser.add_argument("--batch-size", type=int, default=256) 53 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 54 | 55 | return parser.parse_args() 56 | 57 | 58 | def train(args=get_args()): 59 | # create env and dataset 60 | env = gym.make(args.task) 61 | dataset = d4rl.qlearning_dataset(env) 62 | # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 63 | if 'antmaze' in args.task: 64 | dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0 65 | args.obs_shape = env.observation_space.shape 66 | args.action_dim = np.prod(env.action_space.shape) 67 | args.max_action = env.action_space.high[0] 68 | 69 | # seed 70 | random.seed(args.seed) 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | torch.cuda.manual_seed_all(args.seed) 74 | torch.backends.cudnn.deterministic = True 75 | env.seed(args.seed) 76 | 77 | # create policy model 78 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims) 79 | critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 80 | critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 81 | dist = TanhDiagGaussian( 82 | latent_dim=getattr(actor_backbone, "output_dim"), 83 | output_dim=args.action_dim, 84 | unbounded=True, 85 | conditioned_sigma=True, 
86 | max_mu=args.max_action 87 | ) 88 | actor = ActorProb(actor_backbone, dist, args.device) 89 | critic1 = Critic(critic1_backbone, args.device) 90 | critic2 = Critic(critic2_backbone, args.device) 91 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 92 | critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) 93 | critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) 94 | 95 | if args.auto_alpha: 96 | target_entropy = args.target_entropy if args.target_entropy \ 97 | else -np.prod(env.action_space.shape) 98 | 99 | args.target_entropy = target_entropy 100 | 101 | log_alpha = torch.zeros(1, requires_grad=True, device=args.device) 102 | alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) 103 | alpha = (target_entropy, log_alpha, alpha_optim) 104 | else: 105 | alpha = args.alpha 106 | 107 | # create policy 108 | policy = CQLPolicy( 109 | actor, 110 | critic1, 111 | critic2, 112 | actor_optim, 113 | critic1_optim, 114 | critic2_optim, 115 | action_space=env.action_space, 116 | tau=args.tau, 117 | gamma=args.gamma, 118 | alpha=alpha, 119 | cql_weight=args.cql_weight, 120 | temperature=args.temperature, 121 | max_q_backup=args.max_q_backup, 122 | deterministic_backup=args.deterministic_backup, 123 | with_lagrange=args.with_lagrange, 124 | lagrange_threshold=args.lagrange_threshold, 125 | cql_alpha_lr=args.cql_alpha_lr, 126 | num_repeart_actions=args.num_repeat_actions 127 | ) 128 | 129 | # create buffer 130 | buffer = ReplayBuffer( 131 | buffer_size=len(dataset["observations"]), 132 | obs_shape=args.obs_shape, 133 | obs_dtype=np.float32, 134 | action_dim=args.action_dim, 135 | action_dtype=np.float32, 136 | device=args.device 137 | ) 138 | buffer.load_dataset(dataset) 139 | 140 | # log 141 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 142 | # key: output file name, value: output handler type 143 | output_config = { 144 | "consoleout_backup": "stdout", 145 | "policy_training_progress": "csv", 146 | "tb": "tensorboard" 147 | } 148 | logger = Logger(log_dirs, output_config) 149 | logger.log_hyperparameters(vars(args)) 150 | 151 | # create policy trainer 152 | policy_trainer = MFPolicyTrainer( 153 | policy=policy, 154 | eval_env=env, 155 | buffer=buffer, 156 | logger=logger, 157 | epoch=args.epoch, 158 | step_per_epoch=args.step_per_epoch, 159 | batch_size=args.batch_size, 160 | eval_episodes=args.eval_episodes 161 | ) 162 | 163 | # train 164 | policy_trainer.train() 165 | 166 | 167 | if __name__ == "__main__": 168 | train() -------------------------------------------------------------------------------- /offlinerlkit/policy_trainer/mb_policy_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import gym 7 | 8 | from typing import Optional, Dict, List, Tuple 9 | from tqdm import tqdm 10 | from collections import deque 11 | from offlinerlkit.buffer import ReplayBuffer 12 | from offlinerlkit.utils.logger import Logger 13 | from offlinerlkit.policy import BasePolicy 14 | 15 | 16 | # model-based policy trainer 17 | class MBPolicyTrainer: 18 | def __init__( 19 | self, 20 | policy: BasePolicy, 21 | eval_env: gym.Env, 22 | real_buffer: ReplayBuffer, 23 | fake_buffer: ReplayBuffer, 24 | logger: Logger, 25 | rollout_setting: Tuple[int, int, int], 26 | epoch: int = 1000, 27 | step_per_epoch: int = 1000, 28 | batch_size: int = 256, 29 | real_ratio: float = 0.05, 30 | eval_episodes: int = 10, 31 
| lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, 32 | dynamics_update_freq: int = 0 33 | ) -> None: 34 | self.policy = policy 35 | self.eval_env = eval_env 36 | self.real_buffer = real_buffer 37 | self.fake_buffer = fake_buffer 38 | self.logger = logger 39 | 40 | self._rollout_freq, self._rollout_batch_size, \ 41 | self._rollout_length = rollout_setting 42 | self._dynamics_update_freq = dynamics_update_freq 43 | 44 | self._epoch = epoch 45 | self._step_per_epoch = step_per_epoch 46 | self._batch_size = batch_size 47 | self._real_ratio = real_ratio 48 | self._eval_episodes = eval_episodes 49 | self.lr_scheduler = lr_scheduler 50 | 51 | def train(self) -> Dict[str, float]: 52 | start_time = time.time() 53 | 54 | num_timesteps = 0 55 | last_10_performance = deque(maxlen=10) 56 | # train loop 57 | for e in range(1, self._epoch + 1): 58 | 59 | self.policy.train() 60 | 61 | pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}") 62 | for it in pbar: 63 | if num_timesteps % self._rollout_freq == 0: 64 | init_obss = self.real_buffer.sample(self._rollout_batch_size)["observations"].cpu().numpy() 65 | rollout_transitions, rollout_info = self.policy.rollout(init_obss, self._rollout_length) 66 | self.fake_buffer.add_batch(**rollout_transitions) 67 | self.logger.log( 68 | "num rollout transitions: {}, reward mean: {:.4f}".\ 69 | format(rollout_info["num_transitions"], rollout_info["reward_mean"]) 70 | ) 71 | for _key, _value in rollout_info.items(): 72 | self.logger.logkv_mean("rollout_info/"+_key, _value) 73 | 74 | real_sample_size = int(self._batch_size * self._real_ratio) 75 | fake_sample_size = self._batch_size - real_sample_size 76 | real_batch = self.real_buffer.sample(batch_size=real_sample_size) 77 | fake_batch = self.fake_buffer.sample(batch_size=fake_sample_size) 78 | batch = {"real": real_batch, "fake": fake_batch} 79 | loss = self.policy.learn(batch) 80 | pbar.set_postfix(**loss) 81 | 82 | for k, v in loss.items(): 83 | self.logger.logkv_mean(k, v) 84 | 85 | # update the dynamics if necessary 86 | if 0 < self._dynamics_update_freq and (num_timesteps+1)%self._dynamics_update_freq == 0: 87 | dynamics_update_info = self.policy.update_dynamics(self.real_buffer) 88 | for k, v in dynamics_update_info.items(): 89 | self.logger.logkv_mean(k, v) 90 | 91 | num_timesteps += 1 92 | 93 | if self.lr_scheduler is not None: 94 | self.lr_scheduler.step() 95 | 96 | # evaluate current policy 97 | eval_info = self._evaluate() 98 | ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"]) 99 | ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"]) 100 | norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100 101 | norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100 102 | last_10_performance.append(norm_ep_rew_mean) 103 | self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean) 104 | self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std) 105 | self.logger.logkv("eval/episode_length", ep_length_mean) 106 | self.logger.logkv("eval/episode_length_std", ep_length_std) 107 | self.logger.set_timestep(num_timesteps) 108 | self.logger.dumpkvs(exclude=["dynamics_training_progress"]) 109 | 110 | # save checkpoint 111 | torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth")) 112 | 113 | self.logger.log("total time: {:.2f}s".format(time.time() - start_time)) 
114 | torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth")) 115 | self.policy.dynamics.save(self.logger.model_dir) 116 | self.logger.close() 117 | 118 | return {"last_10_performance": np.mean(last_10_performance)} 119 | 120 | def _evaluate(self) -> Dict[str, List[float]]: 121 | self.policy.eval() 122 | obs = self.eval_env.reset() 123 | eval_ep_info_buffer = [] 124 | num_episodes = 0 125 | episode_reward, episode_length = 0, 0 126 | 127 | while num_episodes < self._eval_episodes: 128 | action = self.policy.select_action(obs.reshape(1, -1), deterministic=True) 129 | next_obs, reward, terminal, _ = self.eval_env.step(action.flatten()) 130 | episode_reward += reward 131 | episode_length += 1 132 | 133 | obs = next_obs 134 | 135 | if terminal: 136 | eval_ep_info_buffer.append( 137 | {"episode_reward": episode_reward, "episode_length": episode_length} 138 | ) 139 | num_episodes +=1 140 | episode_reward, episode_length = 0, 0 141 | obs = self.eval_env.reset() 142 | 143 | return { 144 | "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer], 145 | "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer] 146 | } -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/edac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from typing import Dict, Union, Tuple 6 | from copy import deepcopy 7 | from offlinerlkit.policy import BasePolicy 8 | 9 | 10 | class EDACPolicy(BasePolicy): 11 | """ 12 | Ensemble-Diversified Actor Critic 13 | """ 14 | 15 | def __init__( 16 | self, 17 | actor: nn.Module, 18 | critics: nn.ModuleList, 19 | actor_optim: torch.optim.Optimizer, 20 | critics_optim: torch.optim.Optimizer, 21 | tau: float = 0.005, 22 | gamma: float = 0.99, 23 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, 24 | max_q_backup: bool = False, 25 | deterministic_backup: bool = True, 26 | eta: float = 1.0 27 | ) -> None: 28 | 29 | super().__init__() 30 | self.actor = actor 31 | self.critics = critics 32 | self.critics_old = deepcopy(critics) 33 | self.critics_old.eval() 34 | 35 | self.actor_optim = actor_optim 36 | self.critics_optim = critics_optim 37 | 38 | self._tau = tau 39 | self._gamma = gamma 40 | 41 | self._is_auto_alpha = False 42 | if isinstance(alpha, tuple): 43 | self._is_auto_alpha = True 44 | self._target_entropy, self._log_alpha, self.alpha_optim = alpha 45 | self._alpha = self._log_alpha.detach().exp() 46 | else: 47 | self._alpha = alpha 48 | 49 | self._max_q_backup = max_q_backup 50 | self._deterministic_backup = deterministic_backup 51 | self._eta = eta 52 | self._num_critics = self.critics._num_ensemble 53 | 54 | def train(self) -> None: 55 | self.actor.train() 56 | self.critics.train() 57 | 58 | def eval(self) -> None: 59 | self.actor.eval() 60 | self.critics.eval() 61 | 62 | def _sync_weight(self) -> None: 63 | for o, n in zip(self.critics_old.parameters(), self.critics.parameters()): 64 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 65 | 66 | def actforward( 67 | self, 68 | obs: torch.Tensor, 69 | deterministic: bool = False 70 | ) -> Tuple[torch.Tensor, torch.Tensor]: 71 | dist = self.actor(obs) 72 | if deterministic: 73 | squashed_action, raw_action = dist.mode() 74 | else: 75 | squashed_action, raw_action = dist.rsample() 76 | log_prob = dist.log_prob(squashed_action, raw_action) 77 
| return squashed_action, log_prob 78 | 79 | def select_action( 80 | self, 81 | obs: np.ndarray, 82 | deterministic: bool = False 83 | ) -> np.ndarray: 84 | with torch.no_grad(): 85 | action, _ = self.actforward(obs, deterministic) 86 | return action.cpu().numpy() 87 | 88 | def learn(self, batch: Dict) -> Dict: 89 | obss, actions, next_obss, rewards, terminals = \ 90 | batch["observations"], batch["actions"], batch["next_observations"], batch["rewards"], batch["terminals"] 91 | 92 | if self._eta > 0: 93 | actions.requires_grad_(True) 94 | 95 | # update actor 96 | a, log_probs = self.actforward(obss) 97 | # qas: [num_critics, batch_size, 1] 98 | qas = self.critics(obss, a) 99 | actor_loss = -torch.min(qas, 0)[0].mean() + self._alpha * log_probs.mean() 100 | self.actor_optim.zero_grad() 101 | actor_loss.backward() 102 | self.actor_optim.step() 103 | 104 | if self._is_auto_alpha: 105 | log_probs = log_probs.detach() + self._target_entropy 106 | alpha_loss = -(self._log_alpha * log_probs).mean() 107 | self.alpha_optim.zero_grad() 108 | alpha_loss.backward() 109 | self.alpha_optim.step() 110 | self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0) 111 | 112 | # update critic 113 | if self._max_q_backup: 114 | with torch.no_grad(): 115 | batch_size = obss.shape[0] 116 | tmp_next_obss = next_obss.unsqueeze(1).repeat(1, 10, 1) \ 117 | .view(batch_size * 10, next_obss.shape[-1]) 118 | tmp_next_actions, _ = self.actforward(tmp_next_obss) 119 | tmp_next_qs = self.critics_old(tmp_next_obss, tmp_next_actions) \ 120 | .view(self._num_critics, batch_size, 10, 1).max(2)[0] \ 121 | .view(self._num_critics, batch_size, 1) 122 | next_q = tmp_next_qs.min(0)[0] 123 | else: 124 | with torch.no_grad(): 125 | next_actions, next_log_probs = self.actforward(next_obss) 126 | next_q = self.critics_old(next_obss, next_actions).min(0)[0] 127 | if not self._deterministic_backup: 128 | next_q -= self._alpha * next_log_probs 129 | 130 | # target_q: [batch_size, 1] 131 | target_q = rewards + self._gamma * (1 - terminals) * next_q 132 | # qs: [num_critics, batch_size, 1] 133 | qs = self.critics(obss, actions) 134 | critics_loss = ((qs - target_q.unsqueeze(0)).pow(2)).mean(dim=(1, 2)).sum() 135 | 136 | if self._eta > 0: 137 | obss_tile = obss.unsqueeze(0).repeat(self._num_critics, 1, 1) 138 | actions_tile = actions.unsqueeze(0).repeat(self._num_critics, 1, 1).requires_grad_(True) 139 | qs_preds_tile = self.critics(obss_tile, actions_tile) 140 | qs_pred_grads, = torch.autograd.grad(qs_preds_tile.sum(), actions_tile, retain_graph=True, create_graph=True) 141 | qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10) 142 | qs_pred_grads = qs_pred_grads.transpose(0, 1) 143 | 144 | qs_pred_grads = torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads) 145 | masks = torch.eye(self._num_critics, device=obss.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(0), 1, 1) 146 | qs_pred_grads = (1 - masks) * qs_pred_grads 147 | grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(1, 2))) / (self._num_critics - 1) 148 | 149 | critics_loss += self._eta * grad_loss 150 | 151 | self.critics_optim.zero_grad() 152 | critics_loss.backward() 153 | self.critics_optim.step() 154 | 155 | self._sync_weight() 156 | 157 | result = { 158 | "loss/actor": actor_loss.item(), 159 | "loss/critics": critics_loss.item() 160 | } 161 | 162 | if self._is_auto_alpha: 163 | result["loss/alpha"] = alpha_loss.item() 164 | result["alpha"] = self._alpha.item() 165 | 166 | return result 167 | 168 | 169 | 170 | 
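A minimal usage sketch for EDACPolicy (illustrative only; the repository's run_edac.py is not shown in this excerpt, so every object name below, namely env, actor, critics, actor_optim, critics_optim and device, is an assumption, built the same way as in the run scripts that appear later in this document). It shows the two constructor conventions the class supports: alpha as a fixed float, or as a (target_entropy, log_alpha, alpha_optim) tuple for automatic entropy tuning.

# Illustrative sketch, not repository code. `critics` is assumed to be an ensemble critic
# module that exposes `_num_ensemble` and returns Q-values of shape (num_critics, batch_size, 1),
# which is what EDACPolicy.learn() above relies on.
import numpy as np
import torch

target_entropy = -np.prod(env.action_space.shape)              # SAC-style default target entropy
log_alpha = torch.zeros(1, requires_grad=True, device=device)  # learnable log(alpha)
alpha_optim = torch.optim.Adam([log_alpha], lr=1e-4)

policy = EDACPolicy(
    actor,
    critics,
    actor_optim,
    critics_optim,
    tau=0.005,
    gamma=0.99,
    alpha=(target_entropy, log_alpha, alpha_optim),  # tuple enables automatic entropy tuning
    eta=1.0                                          # weight of the gradient-diversity term
)

Passing a plain float for alpha keeps the entropy coefficient fixed, and setting eta=0 skips the gradient-diversity penalty in learn(), reducing the update to an essentially SAC-style ensemble update.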
-------------------------------------------------------------------------------- /offlinerlkit/utils/load_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import collections 4 | 5 | 6 | def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs): 7 | """ 8 | Returns datasets formatted for use by standard Q-learning algorithms, 9 | with observations, actions, next_observations, rewards, and a terminal 10 | flag. 11 | 12 | Args: 13 | env: An OfflineEnv object. 14 | dataset: An optional dataset to pass in for processing. If None, 15 | the dataset will default to env.get_dataset() 16 | terminate_on_end (bool): Set done=True on the last timestep 17 | in a trajectory. Default is False, and will discard the 18 | last timestep in each trajectory. 19 | **kwargs: Arguments to pass to env.get_dataset(). 20 | 21 | Returns: 22 | A dictionary containing keys: 23 | observations: An N x dim_obs array of observations. 24 | actions: An N x dim_action array of actions. 25 | next_observations: An N x dim_obs array of next observations. 26 | rewards: An N-dim float array of rewards. 27 | terminals: An N-dim boolean array of "done" or episode termination flags. 28 | """ 29 | if dataset is None: 30 | dataset = env.get_dataset(**kwargs) 31 | 32 | has_next_obs = True if 'next_observations' in dataset.keys() else False 33 | 34 | N = dataset['rewards'].shape[0] 35 | obs_ = [] 36 | next_obs_ = [] 37 | action_ = [] 38 | reward_ = [] 39 | done_ = [] 40 | 41 | # The newer version of the dataset adds an explicit 42 | # timeouts field. Keep old method for backwards compatability. 43 | use_timeouts = False 44 | if 'timeouts' in dataset: 45 | use_timeouts = True 46 | 47 | episode_step = 0 48 | for i in range(N-1): 49 | obs = dataset['observations'][i].astype(np.float32) 50 | if has_next_obs: 51 | new_obs = dataset['next_observations'][i].astype(np.float32) 52 | else: 53 | new_obs = dataset['observations'][i+1].astype(np.float32) 54 | action = dataset['actions'][i].astype(np.float32) 55 | reward = dataset['rewards'][i].astype(np.float32) 56 | done_bool = bool(dataset['terminals'][i]) 57 | 58 | if use_timeouts: 59 | final_timestep = dataset['timeouts'][i] 60 | else: 61 | final_timestep = (episode_step == env._max_episode_steps - 1) 62 | if (not terminate_on_end) and final_timestep: 63 | # Skip this transition and don't apply terminals on the last step of an episode 64 | episode_step = 0 65 | continue 66 | if done_bool or final_timestep: 67 | episode_step = 0 68 | if not has_next_obs: 69 | continue 70 | 71 | obs_.append(obs) 72 | next_obs_.append(new_obs) 73 | action_.append(action) 74 | reward_.append(reward) 75 | done_.append(done_bool) 76 | episode_step += 1 77 | 78 | return { 79 | 'observations': np.array(obs_), 80 | 'actions': np.array(action_), 81 | 'next_observations': np.array(next_obs_), 82 | 'rewards': np.array(reward_), 83 | 'terminals': np.array(done_), 84 | } 85 | 86 | 87 | class SequenceDataset(torch.utils.data.Dataset): 88 | def __init__(self, dataset, max_len, max_ep_len=1000, device="cpu"): 89 | super().__init__() 90 | 91 | self.obs_dim = dataset["observations"].shape[-1] 92 | self.action_dim = dataset["actions"].shape[-1] 93 | self.max_len = max_len 94 | self.max_ep_len = max_ep_len 95 | self.device = torch.device(device) 96 | self.input_mean = np.concatenate([dataset["observations"], dataset["actions"]], axis=1).mean(0) 97 | self.input_std = np.concatenate([dataset["observations"], dataset["actions"]], 
axis=1).std(0) + 1e-6 98 | 99 | data_ = collections.defaultdict(list) 100 | 101 | use_timeouts = False 102 | if 'timeouts' in dataset: 103 | use_timeouts = True 104 | 105 | episode_step = 0 106 | self.trajs = [] 107 | for i in range(dataset["rewards"].shape[0]): 108 | done_bool = bool(dataset['terminals'][i]) 109 | if use_timeouts: 110 | final_timestep = dataset['timeouts'][i] 111 | else: 112 | final_timestep = (episode_step == 1000-1) 113 | for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']: 114 | data_[k].append(dataset[k][i]) 115 | if done_bool or final_timestep: 116 | episode_step = 0 117 | episode_data = {} 118 | for k in data_: 119 | episode_data[k] = np.array(data_[k]) 120 | self.trajs.append(episode_data) 121 | data_ = collections.defaultdict(list) 122 | episode_step += 1 123 | 124 | indices = [] 125 | for traj_ind, traj in enumerate(self.trajs): 126 | end = len(traj["rewards"]) 127 | for i in range(end): 128 | indices.append((traj_ind, i, i+self.max_len)) 129 | 130 | self.indices = np.array(indices) 131 | 132 | 133 | returns = np.array([np.sum(t['rewards']) for t in self.trajs]) 134 | num_samples = np.sum([t['rewards'].shape[0] for t in self.trajs]) 135 | print(f'Number of samples collected: {num_samples}') 136 | print(f'Num trajectories: {len(self.trajs)}') 137 | print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}') 138 | 139 | def __len__(self): 140 | return len(self.indices) 141 | 142 | def __getitem__(self, idx): 143 | traj_ind, start_ind, end_ind = self.indices[idx] 144 | traj = self.trajs[traj_ind].copy() 145 | obss = traj['observations'][start_ind:end_ind] 146 | actions = traj['actions'][start_ind:end_ind] 147 | next_obss = traj['next_observations'][start_ind:end_ind] 148 | rewards = traj['rewards'][start_ind:end_ind].reshape(-1, 1) 149 | delta_obss = next_obss - obss 150 | 151 | # padding 152 | tlen = obss.shape[0] 153 | inputs = np.concatenate([obss, actions], axis=1) 154 | inputs = (inputs - self.input_mean) / self.input_std 155 | inputs = np.concatenate([inputs, np.zeros((self.max_len - tlen, self.obs_dim+self.action_dim))], axis=0) 156 | targets = np.concatenate([delta_obss, rewards], axis=1) 157 | targets = np.concatenate([targets, np.zeros((self.max_len - tlen, self.obs_dim+1))], axis=0) 158 | masks = np.concatenate([np.ones(tlen), np.zeros(self.max_len - tlen)], axis=0) 159 | 160 | inputs = torch.from_numpy(inputs).to(dtype=torch.float32, device=self.device) 161 | targets = torch.from_numpy(targets).to(dtype=torch.float32, device=self.device) 162 | masks = torch.from_numpy(masks).to(dtype=torch.float32, device=self.device) 163 | 164 | return inputs, targets, masks -------------------------------------------------------------------------------- /run_example/run_iql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gym 5 | import d4rl 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | from offlinerlkit.nets import MLP 12 | from offlinerlkit.modules import ActorProb, Critic, DiagGaussian 13 | from offlinerlkit.buffer import ReplayBuffer 14 | from offlinerlkit.utils.logger import Logger, make_log_dirs 15 | from offlinerlkit.policy_trainer import MFPolicyTrainer 16 | from offlinerlkit.policy import IQLPolicy 17 | 18 | """ 19 | suggested hypers 20 | expectile=0.7, temperature=3.0 for all D4RL-Gym tasks 21 | """ 22 | 23 | 24 | def get_args(): 25 | parser = 
argparse.ArgumentParser() 26 | parser.add_argument("--algo-name", type=str, default="iql") 27 | parser.add_argument("--task", type=str, default="hopper-medium-replay-v2") 28 | parser.add_argument("--seed", type=int, default=0) 29 | parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256]) 30 | parser.add_argument("--actor-lr", type=float, default=3e-4) 31 | parser.add_argument("--critic-q-lr", type=float, default=3e-4) 32 | parser.add_argument("--critic-v-lr", type=float, default=3e-4) 33 | parser.add_argument("--dropout_rate", type=float, default=None) 34 | parser.add_argument("--lr-decay", type=bool, default=True) 35 | parser.add_argument("--gamma", type=float, default=0.99) 36 | parser.add_argument("--tau", type=float, default=0.005) 37 | parser.add_argument("--expectile", type=float, default=0.7) 38 | parser.add_argument("--temperature", type=float, default=3.0) 39 | parser.add_argument("--epoch", type=int, default=1000) 40 | parser.add_argument("--step-per-epoch", type=int, default=1000) 41 | parser.add_argument("--eval_episodes", type=int, default=10) 42 | parser.add_argument("--batch-size", type=int, default=256) 43 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 44 | 45 | return parser.parse_args() 46 | 47 | 48 | def normalize_rewards(dataset): 49 | terminals_float = np.zeros_like(dataset["rewards"]) 50 | for i in range(len(terminals_float) - 1): 51 | if np.linalg.norm(dataset["observations"][i + 1] - 52 | dataset["next_observations"][i] 53 | ) > 1e-6 or dataset["terminals"][i] == 1.0: 54 | terminals_float[i] = 1 55 | else: 56 | terminals_float[i] = 0 57 | 58 | terminals_float[-1] = 1 59 | 60 | # split_into_trajectories 61 | trajs = [[]] 62 | for i in range(len(dataset["observations"])): 63 | trajs[-1].append((dataset["observations"][i], dataset["actions"][i], dataset["rewards"][i], 1.0-dataset["terminals"][i], 64 | terminals_float[i], dataset["next_observations"][i])) 65 | if terminals_float[i] == 1.0 and i + 1 < len(dataset["observations"]): 66 | trajs.append([]) 67 | 68 | def compute_returns(traj): 69 | episode_return = 0 70 | for _, _, rew, _, _, _ in traj: 71 | episode_return += rew 72 | 73 | return episode_return 74 | 75 | trajs.sort(key=compute_returns) 76 | 77 | # normalize rewards 78 | dataset["rewards"] /= compute_returns(trajs[-1]) - compute_returns(trajs[0]) 79 | dataset["rewards"] *= 1000.0 80 | 81 | return dataset 82 | 83 | 84 | def train(args=get_args()): 85 | # create env and dataset 86 | env = gym.make(args.task) 87 | dataset = d4rl.qlearning_dataset(env) 88 | if 'antmaze' in args.task: 89 | dataset["rewards"] -= 1.0 90 | if ("halfcheetah" in args.task or "walker2d" in args.task or "hopper" in args.task): 91 | dataset = normalize_rewards(dataset) 92 | args.obs_shape = env.observation_space.shape 93 | args.action_dim = np.prod(env.action_space.shape) 94 | args.max_action = env.action_space.high[0] 95 | 96 | # seed 97 | random.seed(args.seed) 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.cuda.manual_seed_all(args.seed) 101 | torch.backends.cudnn.deterministic = True 102 | env.seed(args.seed) 103 | 104 | # create policy model 105 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims, dropout_rate=args.dropout_rate) 106 | critic_q1_backbone = MLP(input_dim=np.prod(args.obs_shape)+args.action_dim, hidden_dims=args.hidden_dims) 107 | critic_q2_backbone = MLP(input_dim=np.prod(args.obs_shape)+args.action_dim, hidden_dims=args.hidden_dims) 
108 | critic_v_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims) 109 | dist = DiagGaussian( 110 | latent_dim=getattr(actor_backbone, "output_dim"), 111 | output_dim=args.action_dim, 112 | unbounded=False, 113 | conditioned_sigma=False, 114 | max_mu=args.max_action 115 | ) 116 | actor = ActorProb(actor_backbone, dist, args.device) 117 | critic_q1 = Critic(critic_q1_backbone, args.device) 118 | critic_q2 = Critic(critic_q2_backbone, args.device) 119 | critic_v = Critic(critic_v_backbone, args.device) 120 | 121 | for m in list(actor.modules()) + list(critic_q1.modules()) + list(critic_q2.modules()) + list(critic_v.modules()): 122 | if isinstance(m, torch.nn.Linear): 123 | # orthogonal initialization 124 | torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) 125 | torch.nn.init.zeros_(m.bias) 126 | 127 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 128 | critic_q1_optim = torch.optim.Adam(critic_q1.parameters(), lr=args.critic_q_lr) 129 | critic_q2_optim = torch.optim.Adam(critic_q2.parameters(), lr=args.critic_q_lr) 130 | critic_v_optim = torch.optim.Adam(critic_v.parameters(), lr=args.critic_v_lr) 131 | 132 | if args.lr_decay: 133 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch) 134 | else: 135 | lr_scheduler = None 136 | 137 | # create IQL policy 138 | policy = IQLPolicy( 139 | actor, 140 | critic_q1, 141 | critic_q2, 142 | critic_v, 143 | actor_optim, 144 | critic_q1_optim, 145 | critic_q2_optim, 146 | critic_v_optim, 147 | action_space=env.action_space, 148 | tau=args.tau, 149 | gamma=args.gamma, 150 | expectile=args.expectile, 151 | temperature=args.temperature 152 | ) 153 | 154 | # create buffer 155 | buffer = ReplayBuffer( 156 | buffer_size=len(dataset["observations"]), 157 | obs_shape=args.obs_shape, 158 | obs_dtype=np.float32, 159 | action_dim=args.action_dim, 160 | action_dtype=np.float32, 161 | device=args.device 162 | ) 163 | buffer.load_dataset(dataset) 164 | 165 | # log 166 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 167 | # key: output file name, value: output handler type 168 | output_config = { 169 | "consoleout_backup": "stdout", 170 | "policy_training_progress": "csv", 171 | "tb": "tensorboard" 172 | } 173 | logger = Logger(log_dirs, output_config) 174 | logger.log_hyperparameters(vars(args)) 175 | 176 | # create policy trainer 177 | policy_trainer = MFPolicyTrainer( 178 | policy=policy, 179 | eval_env=env, 180 | buffer=buffer, 181 | logger=logger, 182 | epoch=args.epoch, 183 | step_per_epoch=args.step_per_epoch, 184 | batch_size=args.batch_size, 185 | eval_episodes=args.eval_episodes, 186 | lr_scheduler=lr_scheduler 187 | ) 188 | 189 | # train 190 | policy_trainer.train() 191 | 192 | 193 | if __name__ == "__main__": 194 | train() -------------------------------------------------------------------------------- /offlinerlkit/policy/model_based/mobile.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gym 5 | 6 | from torch.nn import functional as F 7 | from typing import Dict, Union, Tuple 8 | from copy import deepcopy 9 | from collections import defaultdict 10 | from offlinerlkit.policy import BasePolicy 11 | from offlinerlkit.dynamics import BaseDynamics 12 | 13 | 14 | class MOBILEPolicy(BasePolicy): 15 | """ 16 | Model-Bellman Inconsistancy Penalized Offline Reinforcement Learning 17 | """ 18 | 19 | def __init__( 20 | self, 21 | dynamics: 
BaseDynamics, 22 | actor: nn.Module, 23 | critics: nn.ModuleList, 24 | actor_optim: torch.optim.Optimizer, 25 | critics_optim: torch.optim.Optimizer, 26 | tau: float = 0.005, 27 | gamma: float = 0.99, 28 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, 29 | penalty_coef: float = 1.0, 30 | num_samples: int = 10, 31 | deterministic_backup: bool = False, 32 | max_q_backup: bool = False 33 | ) -> None: 34 | 35 | super().__init__() 36 | self.dynamics = dynamics 37 | self.actor = actor 38 | self.critics = critics 39 | self.critics_old = deepcopy(critics) 40 | self.critics_old.eval() 41 | 42 | self.actor_optim = actor_optim 43 | self.critics_optim = critics_optim 44 | 45 | self._tau = tau 46 | self._gamma = gamma 47 | 48 | self._is_auto_alpha = False 49 | if isinstance(alpha, tuple): 50 | self._is_auto_alpha = True 51 | self._target_entropy, self._log_alpha, self.alpha_optim = alpha 52 | self._alpha = self._log_alpha.detach().exp() 53 | else: 54 | self._alpha = alpha 55 | 56 | self._penalty_coef = penalty_coef 57 | self._num_samples = num_samples 58 | self._deteterministic_backup = deterministic_backup 59 | self._max_q_backup = max_q_backup 60 | 61 | def train(self) -> None: 62 | self.actor.train() 63 | self.critics.train() 64 | 65 | def eval(self) -> None: 66 | self.actor.eval() 67 | self.critics.eval() 68 | 69 | def _sync_weight(self) -> None: 70 | for o, n in zip(self.critics_old.parameters(), self.critics.parameters()): 71 | o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) 72 | 73 | def actforward( 74 | self, 75 | obs: torch.Tensor, 76 | deterministic: bool = False 77 | ) -> Tuple[torch.Tensor, torch.Tensor]: 78 | dist = self.actor(obs) 79 | if deterministic: 80 | squashed_action, raw_action = dist.mode() 81 | else: 82 | squashed_action, raw_action = dist.rsample() 83 | log_prob = dist.log_prob(squashed_action, raw_action) 84 | return squashed_action, log_prob 85 | 86 | def select_action( 87 | self, 88 | obs: np.ndarray, 89 | deterministic: bool = False 90 | ) -> np.ndarray: 91 | with torch.no_grad(): 92 | action, _ = self.actforward(obs, deterministic) 93 | return action.cpu().numpy() 94 | 95 | def rollout( 96 | self, 97 | init_obss: np.ndarray, 98 | rollout_length: int 99 | ) -> Tuple[Dict[str, np.ndarray], Dict]: 100 | 101 | num_transitions = 0 102 | rewards_arr = np.array([]) 103 | rollout_transitions = defaultdict(list) 104 | 105 | # rollout 106 | observations = init_obss 107 | for _ in range(rollout_length): 108 | actions = self.select_action(observations) 109 | next_observations, rewards, terminals, info = self.dynamics.step(observations, actions) 110 | 111 | rollout_transitions["obss"].append(observations) 112 | rollout_transitions["next_obss"].append(next_observations) 113 | rollout_transitions["actions"].append(actions) 114 | rollout_transitions["rewards"].append(rewards) 115 | rollout_transitions["terminals"].append(terminals) 116 | 117 | num_transitions += len(observations) 118 | rewards_arr = np.append(rewards_arr, rewards.flatten()) 119 | 120 | nonterm_mask = (~terminals).flatten() 121 | if nonterm_mask.sum() == 0: 122 | break 123 | 124 | observations = next_observations[nonterm_mask] 125 | 126 | for k, v in rollout_transitions.items(): 127 | rollout_transitions[k] = np.concatenate(v, axis=0) 128 | 129 | return rollout_transitions, \ 130 | {"num_transitions": num_transitions, "reward_mean": rewards_arr.mean()} 131 | 132 | @ torch.no_grad() 133 | def compute_lcb(self, obss: torch.Tensor, actions: torch.Tensor): 134 | # compute next q 
std 135 | pred_next_obss = self.dynamics.sample_next_obss(obss, actions, self._num_samples) 136 | num_samples, num_ensembles, batch_size, obs_dim = pred_next_obss.shape 137 | pred_next_obss = pred_next_obss.reshape(-1, obs_dim) 138 | pred_next_actions, _ = self.actforward(pred_next_obss) 139 | 140 | pred_next_qs = torch.cat([critic_old(pred_next_obss, pred_next_actions) for critic_old in self.critics_old], 1) 141 | pred_next_qs = torch.min(pred_next_qs, 1)[0].reshape(num_samples, num_ensembles, batch_size, 1) 142 | penalty = pred_next_qs.mean(0).std(0) 143 | 144 | return penalty 145 | 146 | def learn(self, batch: Dict) -> Dict[str, float]: 147 | real_batch, fake_batch = batch["real"], batch["fake"] 148 | mix_batch = {k: torch.cat([real_batch[k], fake_batch[k]], 0) for k in real_batch.keys()} 149 | 150 | obss, actions, next_obss, rewards, terminals = mix_batch["observations"], mix_batch["actions"], mix_batch["next_observations"], mix_batch["rewards"], mix_batch["terminals"] 151 | batch_size = obss.shape[0] 152 | 153 | # update critic 154 | qs = torch.stack([critic(obss, actions) for critic in self.critics], 0) 155 | with torch.no_grad(): 156 | penalty = self.compute_lcb(obss, actions) 157 | penalty[:len(real_batch["rewards"])] = 0.0 158 | 159 | if self._max_q_backup: 160 | tmp_next_obss = next_obss.unsqueeze(1) \ 161 | .repeat(1, 10, 1) \ 162 | .view(batch_size * 10, next_obss.shape[-1]) 163 | tmp_next_actions, _ = self.actforward(tmp_next_obss) 164 | tmp_next_qs = torch.cat([critic_old(tmp_next_obss, tmp_next_actions) for critic_old in self.critics_old], 1) 165 | tmp_next_qs = tmp_next_qs.view(batch_size, 10, len(self.critics_old)).max(1)[0].view(-1, len(self.critics_old)) 166 | next_q = torch.min(tmp_next_qs, 1)[0].reshape(-1, 1) 167 | else: 168 | next_actions, next_log_probs = self.actforward(next_obss) 169 | next_qs = torch.cat([critic_old(next_obss, next_actions) for critic_old in self.critics_old], 1) 170 | next_q = torch.min(next_qs, 1)[0].reshape(-1, 1) 171 | if not self._deteterministic_backup: 172 | next_q -= self._alpha * next_log_probs 173 | target_q = (rewards - self._penalty_coef * penalty) + self._gamma * (1 - terminals) * next_q 174 | target_q = torch.clamp(target_q, 0, None) 175 | 176 | critic_loss = ((qs - target_q) ** 2).mean() 177 | self.critics_optim.zero_grad() 178 | critic_loss.backward() 179 | self.critics_optim.step() 180 | 181 | # update actor 182 | a, log_probs = self.actforward(obss) 183 | qas = torch.cat([critic(obss, a) for critic in self.critics], 1) 184 | actor_loss = -torch.min(qas, 1)[0].mean() + self._alpha * log_probs.mean() 185 | self.actor_optim.zero_grad() 186 | actor_loss.backward() 187 | self.actor_optim.step() 188 | 189 | if self._is_auto_alpha: 190 | log_probs = log_probs.detach() + self._target_entropy 191 | alpha_loss = -(self._log_alpha * log_probs).mean() 192 | self.alpha_optim.zero_grad() 193 | alpha_loss.backward() 194 | self.alpha_optim.step() 195 | self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0) 196 | 197 | self._sync_weight() 198 | 199 | result = { 200 | "loss/actor": actor_loss.item(), 201 | "loss/critic": critic_loss.item() 202 | } 203 | 204 | if self._is_auto_alpha: 205 | result["loss/alpha"] = alpha_loss.item() 206 | result["alpha"] = self._alpha.item() 207 | 208 | return result -------------------------------------------------------------------------------- /offlinerlkit/policy/model_free/cql.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 
3 | import torch.nn as nn 4 | import gym 5 | 6 | from torch.nn import functional as F 7 | from typing import Dict, Union, Tuple 8 | from offlinerlkit.policy import SACPolicy 9 | 10 | 11 | class CQLPolicy(SACPolicy): 12 | """ 13 | Conservative Q-Learning 14 | """ 15 | 16 | def __init__( 17 | self, 18 | actor: nn.Module, 19 | critic1: nn.Module, 20 | critic2: nn.Module, 21 | actor_optim: torch.optim.Optimizer, 22 | critic1_optim: torch.optim.Optimizer, 23 | critic2_optim: torch.optim.Optimizer, 24 | action_space: gym.spaces.Space, 25 | tau: float = 0.005, 26 | gamma: float = 0.99, 27 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, 28 | cql_weight: float = 1.0, 29 | temperature: float = 1.0, 30 | max_q_backup: bool = False, 31 | deterministic_backup: bool = True, 32 | with_lagrange: bool = True, 33 | lagrange_threshold: float = 10.0, 34 | cql_alpha_lr: float = 1e-4, 35 | num_repeart_actions:int = 10, 36 | ) -> None: 37 | super().__init__( 38 | actor, 39 | critic1, 40 | critic2, 41 | actor_optim, 42 | critic1_optim, 43 | critic2_optim, 44 | tau=tau, 45 | gamma=gamma, 46 | alpha=alpha 47 | ) 48 | 49 | self.action_space = action_space 50 | self._cql_weight = cql_weight 51 | self._temperature = temperature 52 | self._max_q_backup = max_q_backup 53 | self._deterministic_backup = deterministic_backup 54 | self._with_lagrange = with_lagrange 55 | self._lagrange_threshold = lagrange_threshold 56 | 57 | self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.actor.device) 58 | self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr) 59 | 60 | self._num_repeat_actions = num_repeart_actions 61 | 62 | def calc_pi_values( 63 | self, 64 | obs_pi: torch.Tensor, 65 | obs_to_pred: torch.Tensor 66 | ) -> Tuple[torch.Tensor, torch.Tensor]: 67 | act, log_prob = self.actforward(obs_pi) 68 | 69 | q1 = self.critic1(obs_to_pred, act) 70 | q2 = self.critic2(obs_to_pred, act) 71 | 72 | return q1 - log_prob.detach(), q2 - log_prob.detach() 73 | 74 | def calc_random_values( 75 | self, 76 | obs: torch.Tensor, 77 | random_act: torch.Tensor 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | q1 = self.critic1(obs, random_act) 80 | q2 = self.critic2(obs, random_act) 81 | 82 | log_prob1 = np.log(0.5**random_act.shape[-1]) 83 | log_prob2 = np.log(0.5**random_act.shape[-1]) 84 | 85 | return q1 - log_prob1, q2 - log_prob2 86 | 87 | def learn(self, batch: Dict) -> Dict[str, float]: 88 | obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \ 89 | batch["next_observations"], batch["rewards"], batch["terminals"] 90 | batch_size = obss.shape[0] 91 | 92 | # update actor 93 | a, log_probs = self.actforward(obss) 94 | q1a, q2a = self.critic1(obss, a), self.critic2(obss, a) 95 | actor_loss = (self._alpha * log_probs - torch.min(q1a, q2a)).mean() 96 | self.actor_optim.zero_grad() 97 | actor_loss.backward() 98 | self.actor_optim.step() 99 | 100 | if self._is_auto_alpha: 101 | log_probs = log_probs.detach() + self._target_entropy 102 | alpha_loss = -(self._log_alpha * log_probs).mean() 103 | self.alpha_optim.zero_grad() 104 | alpha_loss.backward() 105 | self.alpha_optim.step() 106 | self._alpha = self._log_alpha.detach().exp() 107 | 108 | # compute td error 109 | if self._max_q_backup: 110 | with torch.no_grad(): 111 | tmp_next_obss = next_obss.unsqueeze(1) \ 112 | .repeat(1, self._num_repeat_actions, 1) \ 113 | .view(batch_size * self._num_repeat_actions, next_obss.shape[-1]) 114 | tmp_next_actions, _ = self.actforward(tmp_next_obss) 115 | 
tmp_next_q1 = self.critic1_old(tmp_next_obss, tmp_next_actions) \ 116 | .view(batch_size, self._num_repeat_actions, 1) \ 117 | .max(1)[0].view(-1, 1) 118 | tmp_next_q2 = self.critic2_old(tmp_next_obss, tmp_next_actions) \ 119 | .view(batch_size, self._num_repeat_actions, 1) \ 120 | .max(1)[0].view(-1, 1) 121 | next_q = torch.min(tmp_next_q1, tmp_next_q2) 122 | else: 123 | with torch.no_grad(): 124 | next_actions, next_log_probs = self.actforward(next_obss) 125 | next_q = torch.min( 126 | self.critic1_old(next_obss, next_actions), 127 | self.critic2_old(next_obss, next_actions) 128 | ) 129 | if not self._deterministic_backup: 130 | next_q -= self._alpha * next_log_probs 131 | 132 | target_q = rewards + self._gamma * (1 - terminals) * next_q 133 | q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions) 134 | critic1_loss = ((q1 - target_q).pow(2)).mean() 135 | critic2_loss = ((q2 - target_q).pow(2)).mean() 136 | 137 | # compute conservative loss 138 | random_actions = torch.FloatTensor( 139 | batch_size * self._num_repeat_actions, actions.shape[-1] 140 | ).uniform_(self.action_space.low[0], self.action_space.high[0]).to(self.actor.device) 141 | # tmp_obss & tmp_next_obss: (batch_size * num_repeat, obs_dim) 142 | tmp_obss = obss.unsqueeze(1) \ 143 | .repeat(1, self._num_repeat_actions, 1) \ 144 | .view(batch_size * self._num_repeat_actions, obss.shape[-1]) 145 | tmp_next_obss = next_obss.unsqueeze(1) \ 146 | .repeat(1, self._num_repeat_actions, 1) \ 147 | .view(batch_size * self._num_repeat_actions, obss.shape[-1]) 148 | 149 | obs_pi_value1, obs_pi_value2 = self.calc_pi_values(tmp_obss, tmp_obss) 150 | next_obs_pi_value1, next_obs_pi_value2 = self.calc_pi_values(tmp_next_obss, tmp_obss) 151 | random_value1, random_value2 = self.calc_random_values(tmp_obss, random_actions) 152 | 153 | # reshape to (batch_size, num_repeat, 1); reshape is not in-place, so reassign the results 154 | obs_pi_value1, obs_pi_value2, next_obs_pi_value1, next_obs_pi_value2, random_value1, random_value2 = [ 155 | value.reshape(batch_size, self._num_repeat_actions, 1) 156 | for value in (obs_pi_value1, obs_pi_value2, next_obs_pi_value1, next_obs_pi_value2, random_value1, random_value2) 157 | ] 158 | 159 | # cat_q shape: (batch_size, 3 * num_repeat, 1) 160 | cat_q1 = torch.cat([obs_pi_value1, next_obs_pi_value1, random_value1], 1) 161 | cat_q2 = torch.cat([obs_pi_value2, next_obs_pi_value2, random_value2], 1) 162 | 163 | conservative_loss1 = \ 164 | torch.logsumexp(cat_q1 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \ 165 | q1.mean() * self._cql_weight 166 | conservative_loss2 = \ 167 | torch.logsumexp(cat_q2 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \ 168 | q2.mean() * self._cql_weight 169 | 170 | if self._with_lagrange: 171 | cql_alpha = torch.clamp(self.cql_log_alpha.exp(), 0.0, 1e6) 172 | conservative_loss1 = cql_alpha * (conservative_loss1 - self._lagrange_threshold) 173 | conservative_loss2 = cql_alpha * (conservative_loss2 - self._lagrange_threshold) 174 | 175 | self.cql_alpha_optim.zero_grad() 176 | cql_alpha_loss = -(conservative_loss1 + conservative_loss2) * 0.5 177 | cql_alpha_loss.backward(retain_graph=True) 178 | self.cql_alpha_optim.step() 179 | 180 | critic1_loss = critic1_loss + conservative_loss1 181 | critic2_loss = critic2_loss + conservative_loss2 182 | 183 | # update critic 184 | self.critic1_optim.zero_grad() 185 | critic1_loss.backward(retain_graph=True) 186 | self.critic1_optim.step() 187 | 188 | self.critic2_optim.zero_grad() 189 | critic2_loss.backward() 190 | self.critic2_optim.step() 191 | 192 | self._sync_weight() 193 | 194 | result = { 195 | "loss/actor": actor_loss.item(), 196 | "loss/critic1":
critic1_loss.item(), 197 | "loss/critic2": critic2_loss.item() 198 | } 199 | 200 | if self._is_auto_alpha: 201 | result["loss/alpha"] = alpha_loss.item() 202 | result["alpha"] = self._alpha.item() 203 | if self._with_lagrange: 204 | result["loss/cql_alpha"] = cql_alpha_loss.item() 205 | result["cql_alpha"] = cql_alpha.item() 206 | 207 | return result 208 | 209 | -------------------------------------------------------------------------------- /tune_example/tune_mopo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | 6 | import gym 7 | 8 | import numpy as np 9 | import torch 10 | import ray 11 | from ray import tune 12 | 13 | 14 | from offlinerlkit.nets import MLP 15 | from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian, EnsembleDynamicsModel 16 | from offlinerlkit.dynamics import EnsembleDynamics 17 | from offlinerlkit.utils.scaler import StandardScaler 18 | from offlinerlkit.utils.termination_fns import get_termination_fn 19 | from offlinerlkit.utils.load_dataset import qlearning_dataset 20 | from offlinerlkit.buffer import ReplayBuffer 21 | from offlinerlkit.utils.logger import Logger, make_log_dirs 22 | from offlinerlkit.policy_trainer import MBPolicyTrainer 23 | from offlinerlkit.policy import MOPOPolicy 24 | 25 | 26 | def get_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--algo-name", type=str, default="mopo") 29 | parser.add_argument("--task", type=str, default="hopper-medium-replay-v2") 30 | parser.add_argument("--seed", type=int, default=0) 31 | parser.add_argument("--actor-lr", type=float, default=3e-4) 32 | parser.add_argument("--critic-lr", type=float, default=3e-4) 33 | parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256]) 34 | parser.add_argument("--gamma", type=float, default=0.99) 35 | parser.add_argument("--tau", type=float, default=0.005) 36 | parser.add_argument("--alpha", type=float, default=0.2) 37 | parser.add_argument("--auto-alpha", default=True) 38 | parser.add_argument("--target-entropy", type=int, default=-3) 39 | parser.add_argument("--alpha-lr", type=float, default=3e-4) 40 | 41 | parser.add_argument("--dynamics-lr", type=float, default=1e-3) 42 | parser.add_argument("--dynamics-hidden-dims", type=int, nargs='*', default=[200, 200, 200, 200]) 43 | parser.add_argument("--dynamics-weight-decay", type=float, nargs='*', default=[2.5e-5, 5e-5, 7.5e-5, 7.5e-5, 1e-4]) 44 | parser.add_argument("--n-ensemble", type=int, default=7) 45 | parser.add_argument("--n-elites", type=int, default=5) 46 | parser.add_argument("--rollout-freq", type=int, default=1000) 47 | parser.add_argument("--rollout-batch-size", type=int, default=50000) 48 | parser.add_argument("--rollout-length", type=int, default=5) 49 | parser.add_argument("--penalty-coef", type=float, default=0.0) 50 | parser.add_argument("--model-retain-epochs", type=int, default=5) 51 | parser.add_argument("--real-ratio", type=float, default=0.05) 52 | parser.add_argument("--load-dynamics-path", type=str, default=None) 53 | 54 | parser.add_argument("--epoch", type=int, default=1000) 55 | parser.add_argument("--step-per-epoch", type=int, default=1000) 56 | parser.add_argument("--eval_episodes", type=int, default=10) 57 | parser.add_argument("--batch-size", type=int, default=256) 58 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 59 | 60 | return parser.parse_args() 61 | 62 | 63 | def run_exp(config): 64 
| import d4rl 65 | # set config 66 | global args 67 | args_for_exp = vars(args) 68 | for k, v in config.items(): 69 | args_for_exp[k] = v 70 | args_for_exp = argparse.Namespace(**args_for_exp) 71 | print(args_for_exp.task) 72 | 73 | # create env and dataset 74 | env = gym.make(args_for_exp.task) 75 | dataset = qlearning_dataset(env) 76 | args_for_exp.obs_shape = env.observation_space.shape 77 | args_for_exp.action_dim = np.prod(env.action_space.shape) 78 | args_for_exp.max_action = env.action_space.high[0] 79 | 80 | # seed 81 | random.seed(args_for_exp.seed) 82 | np.random.seed(args_for_exp.seed) 83 | torch.manual_seed(args_for_exp.seed) 84 | torch.cuda.manual_seed_all(args_for_exp.seed) 85 | env.seed(args_for_exp.seed) 86 | 87 | # create policy model 88 | actor_backbone = MLP(input_dim=np.prod(args_for_exp.obs_shape), hidden_dims=args_for_exp.hidden_dims) 89 | critic1_backbone = MLP(input_dim=np.prod(args_for_exp.obs_shape) + args_for_exp.action_dim, hidden_dims=args_for_exp.hidden_dims) 90 | critic2_backbone = MLP(input_dim=np.prod(args_for_exp.obs_shape) + args_for_exp.action_dim, hidden_dims=args_for_exp.hidden_dims) 91 | dist = TanhDiagGaussian( 92 | latent_dim=getattr(actor_backbone, "output_dim"), 93 | output_dim=args_for_exp.action_dim, 94 | unbounded=True, 95 | conditioned_sigma=True 96 | ) 97 | actor = ActorProb(actor_backbone, dist, args_for_exp.device) 98 | critic1 = Critic(critic1_backbone, args_for_exp.device) 99 | critic2 = Critic(critic2_backbone, args_for_exp.device) 100 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args_for_exp.actor_lr) 101 | critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args_for_exp.critic_lr) 102 | critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args_for_exp.critic_lr) 103 | 104 | if args_for_exp.auto_alpha: 105 | target_entropy = args_for_exp.target_entropy if args_for_exp.target_entropy \ 106 | else -np.prod(env.action_space.shape) 107 | 108 | args_for_exp.target_entropy = target_entropy 109 | 110 | log_alpha = torch.zeros(1, requires_grad=True, device=args_for_exp.device) 111 | alpha_optim = torch.optim.Adam([log_alpha], lr=args_for_exp.alpha_lr) 112 | alpha = (target_entropy, log_alpha, alpha_optim) 113 | else: 114 | alpha = args_for_exp.alpha 115 | 116 | # create dynamics 117 | load_dynamics_model = True if args_for_exp.load_dynamics_path else False 118 | dynamics_model = EnsembleDynamicsModel( 119 | obs_dim=np.prod(args_for_exp.obs_shape), 120 | action_dim=args_for_exp.action_dim, 121 | hidden_dims=args_for_exp.dynamics_hidden_dims, 122 | num_ensemble=args_for_exp.n_ensemble, 123 | num_elites=args_for_exp.n_elites, 124 | weight_decays=args_for_exp.dynamics_weight_decay, 125 | device=args_for_exp.device 126 | ) 127 | dynamics_optim = torch.optim.Adam( 128 | dynamics_model.parameters(), 129 | lr=args_for_exp.dynamics_lr 130 | ) 131 | scaler = StandardScaler() 132 | termination_fn = get_termination_fn(task=args_for_exp.task) 133 | dynamics = EnsembleDynamics( 134 | dynamics_model, 135 | dynamics_optim, 136 | scaler, 137 | termination_fn, 138 | penalty_coef=args_for_exp.penalty_coef 139 | ) 140 | 141 | if args_for_exp.load_dynamics_path: 142 | dynamics.load(args_for_exp.load_dynamics_path) 143 | 144 | # create policy 145 | policy = MOPOPolicy( 146 | dynamics, 147 | actor, 148 | critic1, 149 | critic2, 150 | actor_optim, 151 | critic1_optim, 152 | critic2_optim, 153 | tau=args_for_exp.tau, 154 | gamma=args_for_exp.gamma, 155 | alpha=alpha 156 | ) 157 | 158 | # create buffer 159 | real_buffer = ReplayBuffer( 160 | 
buffer_size=len(dataset["observations"]), 161 | obs_shape=args_for_exp.obs_shape, 162 | obs_dtype=np.float32, 163 | action_dim=args_for_exp.action_dim, 164 | action_dtype=np.float32, 165 | device=args_for_exp.device 166 | ) 167 | real_buffer.load_dataset(dataset) 168 | fake_buffer = ReplayBuffer( 169 | buffer_size=args_for_exp.rollout_batch_size*args_for_exp.rollout_length*args_for_exp.model_retain_epochs, 170 | obs_shape=args_for_exp.obs_shape, 171 | obs_dtype=np.float32, 172 | action_dim=args_for_exp.action_dim, 173 | action_dtype=np.float32, 174 | device=args_for_exp.device 175 | ) 176 | 177 | # log 178 | record_params = list(config.keys()) 179 | if "seed" in record_params: 180 | record_params.remove("seed") 181 | log_dirs = make_log_dirs( 182 | args_for_exp.task, 183 | args_for_exp.algo_name, 184 | args_for_exp.seed, 185 | vars(args_for_exp), 186 | record_params=record_params 187 | ) 188 | # key: output file name, value: output handler type 189 | output_config = { 190 | "consoleout_backup": "stdout", 191 | "policy_training_progress": "csv", 192 | "dynamics_training_progress": "csv", 193 | "tb": "tensorboard" 194 | } 195 | logger = Logger(log_dirs, output_config) 196 | logger.log_hyperparameters(vars(args_for_exp)) 197 | 198 | # create policy trainer 199 | policy_trainer = MBPolicyTrainer( 200 | policy=policy, 201 | eval_env=env, 202 | real_buffer=real_buffer, 203 | fake_buffer=fake_buffer, 204 | logger=logger, 205 | rollout_setting=(args_for_exp.rollout_freq, args_for_exp.rollout_batch_size, args_for_exp.rollout_length), 206 | epoch=args_for_exp.epoch, 207 | step_per_epoch=args_for_exp.step_per_epoch, 208 | batch_size=args_for_exp.batch_size, 209 | real_ratio=args_for_exp.real_ratio, 210 | eval_episodes=args_for_exp.eval_episodes 211 | ) 212 | 213 | # train 214 | if not load_dynamics_model: 215 | dynamics.train(real_buffer.sample_all(), logger) 216 | 217 | result = policy_trainer.train() 218 | tune.report(**result) 219 | 220 | 221 | if __name__ == "__main__": 222 | ray.init() 223 | # load default args 224 | args = get_args() 225 | 226 | config = {} 227 | real_ratios = [0.05, 0.5] 228 | seeds = list(range(2)) 229 | config["real_ratio"] = tune.grid_search(real_ratios) 230 | config["seed"] = tune.grid_search(seeds) 231 | 232 | analysis = tune.run( 233 | run_exp, 234 | name="tune_mopo", 235 | config=config, 236 | resources_per_trial={ 237 | "gpu": 0.5 238 | } 239 | ) -------------------------------------------------------------------------------- /run_example/run_mopo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | 6 | import gym 7 | import d4rl 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | from offlinerlkit.nets import MLP 14 | from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian, EnsembleDynamicsModel 15 | from offlinerlkit.dynamics import EnsembleDynamics 16 | from offlinerlkit.utils.scaler import StandardScaler 17 | from offlinerlkit.utils.termination_fns import get_termination_fn 18 | from offlinerlkit.utils.load_dataset import qlearning_dataset 19 | from offlinerlkit.buffer import ReplayBuffer 20 | from offlinerlkit.utils.logger import Logger, make_log_dirs 21 | from offlinerlkit.policy_trainer import MBPolicyTrainer 22 | from offlinerlkit.policy import MOPOPolicy 23 | 24 | 25 | """ 26 | suggested hypers 27 | 28 | halfcheetah-medium-v2: rollout-length=5, penalty-coef=0.5 29 | hopper-medium-v2: rollout-length=5, penalty-coef=5.0 30 | 
walker2d-medium-v2: rollout-length=5, penalty-coef=0.5 31 | halfcheetah-medium-replay-v2: rollout-length=5, penalty-coef=0.5 32 | hopper-medium-replay-v2: rollout-length=5, penalty-coef=2.5 33 | walker2d-medium-replay-v2: rollout-length=1, penalty-coef=2.5 34 | halfcheetah-medium-expert-v2: rollout-length=5, penalty-coef=2.5 35 | hopper-medium-expert-v2: rollout-length=5, penalty-coef=5.0 36 | walker2d-medium-expert-v2: rollout-length=1, penalty-coef=2.5 37 | """ 38 | 39 | 40 | def get_args(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--algo-name", type=str, default="mopo") 43 | parser.add_argument("--task", type=str, default="walker2d-medium-expert-v2") 44 | parser.add_argument("--seed", type=int, default=1) 45 | parser.add_argument("--actor-lr", type=float, default=1e-4) 46 | parser.add_argument("--critic-lr", type=float, default=3e-4) 47 | parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256]) 48 | parser.add_argument("--gamma", type=float, default=0.99) 49 | parser.add_argument("--tau", type=float, default=0.005) 50 | parser.add_argument("--alpha", type=float, default=0.2) 51 | parser.add_argument("--auto-alpha", default=True) 52 | parser.add_argument("--target-entropy", type=int, default=None) 53 | parser.add_argument("--alpha-lr", type=float, default=1e-4) 54 | 55 | parser.add_argument("--dynamics-lr", type=float, default=1e-3) 56 | parser.add_argument("--dynamics-hidden-dims", type=int, nargs='*', default=[200, 200, 200, 200]) 57 | parser.add_argument("--dynamics-weight-decay", type=float, nargs='*', default=[2.5e-5, 5e-5, 7.5e-5, 7.5e-5, 1e-4]) 58 | parser.add_argument("--n-ensemble", type=int, default=7) 59 | parser.add_argument("--n-elites", type=int, default=5) 60 | parser.add_argument("--rollout-freq", type=int, default=1000) 61 | parser.add_argument("--rollout-batch-size", type=int, default=50000) 62 | parser.add_argument("--rollout-length", type=int, default=1) 63 | parser.add_argument("--penalty-coef", type=float, default=2.5) 64 | parser.add_argument("--model-retain-epochs", type=int, default=5) 65 | parser.add_argument("--real-ratio", type=float, default=0.05) 66 | parser.add_argument("--load-dynamics-path", type=str, default=None) 67 | 68 | parser.add_argument("--epoch", type=int, default=3000) 69 | parser.add_argument("--step-per-epoch", type=int, default=1000) 70 | parser.add_argument("--eval_episodes", type=int, default=10) 71 | parser.add_argument("--batch-size", type=int, default=256) 72 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") 73 | 74 | return parser.parse_args() 75 | 76 | 77 | def train(args=get_args()): 78 | # create env and dataset 79 | env = gym.make(args.task) 80 | """ 81 | Here we use our own implementation of qlearning_dataset for the model-based algorithms. 82 | d4rl.qlearning_dataset takes obs[i+1] as the next observation, which is harmless for Q-learning 83 | but introduces a bug when the data are used to learn a dynamics model. 84 | We have only verified this implementation on the MuJoCo locomotion tasks; it is untested on other tasks such as Antmaze. 85 | For those tasks, we suggest using the original d4rl implementation.
86 | """ 87 | if 'hopper' in args.task or 'halfcheetah' in args.task or 'walker2d' in args.task: 88 | dataset = qlearning_dataset(env) 89 | else: 90 | dataset = d4rl.qlearning_dataset(env) 91 | args.obs_shape = env.observation_space.shape 92 | args.action_dim = np.prod(env.action_space.shape) 93 | args.max_action = env.action_space.high[0] 94 | 95 | # seed 96 | random.seed(args.seed) 97 | np.random.seed(args.seed) 98 | torch.manual_seed(args.seed) 99 | torch.cuda.manual_seed_all(args.seed) 100 | torch.backends.cudnn.deterministic = True 101 | env.seed(args.seed) 102 | 103 | # create policy model 104 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims) 105 | critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 106 | critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 107 | dist = TanhDiagGaussian( 108 | latent_dim=getattr(actor_backbone, "output_dim"), 109 | output_dim=args.action_dim, 110 | unbounded=True, 111 | conditioned_sigma=True, 112 | max_mu=args.max_action 113 | ) 114 | actor = ActorProb(actor_backbone, dist, args.device) 115 | critic1 = Critic(critic1_backbone, args.device) 116 | critic2 = Critic(critic2_backbone, args.device) 117 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 118 | critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) 119 | critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) 120 | 121 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch) 122 | 123 | if args.auto_alpha: 124 | target_entropy = args.target_entropy if args.target_entropy \ 125 | else -np.prod(env.action_space.shape) 126 | 127 | args.target_entropy = target_entropy 128 | 129 | log_alpha = torch.zeros(1, requires_grad=True, device=args.device) 130 | alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) 131 | alpha = (target_entropy, log_alpha, alpha_optim) 132 | else: 133 | alpha = args.alpha 134 | 135 | # create dynamics 136 | load_dynamics_model = True if args.load_dynamics_path else False 137 | dynamics_model = EnsembleDynamicsModel( 138 | obs_dim=np.prod(args.obs_shape), 139 | action_dim=args.action_dim, 140 | hidden_dims=args.dynamics_hidden_dims, 141 | num_ensemble=args.n_ensemble, 142 | num_elites=args.n_elites, 143 | weight_decays=args.dynamics_weight_decay, 144 | device=args.device 145 | ) 146 | dynamics_optim = torch.optim.Adam( 147 | dynamics_model.parameters(), 148 | lr=args.dynamics_lr 149 | ) 150 | scaler = StandardScaler() 151 | termination_fn = get_termination_fn(task=args.task) 152 | dynamics = EnsembleDynamics( 153 | dynamics_model, 154 | dynamics_optim, 155 | scaler, 156 | termination_fn, 157 | penalty_coef=args.penalty_coef, 158 | ) 159 | 160 | if args.load_dynamics_path: 161 | dynamics.load(args.load_dynamics_path) 162 | 163 | # create policy 164 | policy = MOPOPolicy( 165 | dynamics, 166 | actor, 167 | critic1, 168 | critic2, 169 | actor_optim, 170 | critic1_optim, 171 | critic2_optim, 172 | tau=args.tau, 173 | gamma=args.gamma, 174 | alpha=alpha 175 | ) 176 | 177 | # create buffer 178 | real_buffer = ReplayBuffer( 179 | buffer_size=len(dataset["observations"]), 180 | obs_shape=args.obs_shape, 181 | obs_dtype=np.float32, 182 | action_dim=args.action_dim, 183 | action_dtype=np.float32, 184 | device=args.device 185 | ) 186 | real_buffer.load_dataset(dataset) 187 | fake_buffer = ReplayBuffer( 188 | 
buffer_size=args.rollout_batch_size*args.rollout_length*args.model_retain_epochs, 189 | obs_shape=args.obs_shape, 190 | obs_dtype=np.float32, 191 | action_dim=args.action_dim, 192 | action_dtype=np.float32, 193 | device=args.device 194 | ) 195 | 196 | # log 197 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args), record_params=["penalty_coef", "rollout_length"]) 198 | # key: output file name, value: output handler type 199 | output_config = { 200 | "consoleout_backup": "stdout", 201 | "policy_training_progress": "csv", 202 | "dynamics_training_progress": "csv", 203 | "tb": "tensorboard" 204 | } 205 | logger = Logger(log_dirs, output_config) 206 | logger.log_hyperparameters(vars(args)) 207 | 208 | # create policy trainer 209 | policy_trainer = MBPolicyTrainer( 210 | policy=policy, 211 | eval_env=env, 212 | real_buffer=real_buffer, 213 | fake_buffer=fake_buffer, 214 | logger=logger, 215 | rollout_setting=(args.rollout_freq, args.rollout_batch_size, args.rollout_length), 216 | epoch=args.epoch, 217 | step_per_epoch=args.step_per_epoch, 218 | batch_size=args.batch_size, 219 | real_ratio=args.real_ratio, 220 | eval_episodes=args.eval_episodes, 221 | lr_scheduler=lr_scheduler 222 | ) 223 | 224 | # train 225 | if not load_dynamics_model: 226 | dynamics.train(real_buffer.sample_all(), logger, max_epochs_since_update=5) 227 | 228 | policy_trainer.train() 229 | 230 | 231 | if __name__ == "__main__": 232 | train() -------------------------------------------------------------------------------- /offlinerlkit/dynamics/ensemble_dynamics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from typing import Callable, List, Tuple, Dict, Optional 7 | from offlinerlkit.dynamics import BaseDynamics 8 | from offlinerlkit.utils.scaler import StandardScaler 9 | from offlinerlkit.utils.logger import Logger 10 | 11 | 12 | class EnsembleDynamics(BaseDynamics): 13 | def __init__( 14 | self, 15 | model: nn.Module, 16 | optim: torch.optim.Optimizer, 17 | scaler: StandardScaler, 18 | terminal_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray], 19 | penalty_coef: float = 0.0, 20 | uncertainty_mode: str = "aleatoric" 21 | ) -> None: 22 | super().__init__(model, optim) 23 | self.scaler = scaler 24 | self.terminal_fn = terminal_fn 25 | self._penalty_coef = penalty_coef 26 | self._uncertainty_mode = uncertainty_mode 27 | 28 | @ torch.no_grad() 29 | def step( 30 | self, 31 | obs: np.ndarray, 32 | action: np.ndarray 33 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict]: 34 | "imagine single forward step" 35 | obs_act = np.concatenate([obs, action], axis=-1) 36 | obs_act = self.scaler.transform(obs_act) 37 | mean, logvar = self.model(obs_act) 38 | mean = mean.cpu().numpy() 39 | logvar = logvar.cpu().numpy() 40 | mean[..., :-1] += obs 41 | std = np.sqrt(np.exp(logvar)) 42 | 43 | ensemble_samples = (mean + np.random.normal(size=mean.shape) * std).astype(np.float32) 44 | 45 | # choose one model from ensemble 46 | num_models, batch_size, _ = ensemble_samples.shape 47 | model_idxs = self.model.random_elite_idxs(batch_size) 48 | samples = ensemble_samples[model_idxs, np.arange(batch_size)] 49 | 50 | next_obs = samples[..., :-1] 51 | reward = samples[..., -1:] 52 | terminal = self.terminal_fn(obs, action, next_obs) 53 | info = {} 54 | info["raw_reward"] = reward 55 | 56 | if self._penalty_coef: 57 | if self._uncertainty_mode == "aleatoric": 58 | penalty = 
np.amax(np.linalg.norm(std, axis=2), axis=0) 59 | elif self._uncertainty_mode == "pairwise-diff": 60 | next_obses_mean = mean[..., :-1] 61 | next_obs_mean = np.mean(next_obses_mean, axis=0) 62 | diff = next_obses_mean - next_obs_mean 63 | penalty = np.amax(np.linalg.norm(diff, axis=2), axis=0) 64 | elif self._uncertainty_mode == "ensemble_std": 65 | next_obses_mean = mean[..., :-1] 66 | penalty = np.sqrt(next_obses_mean.var(0).mean(1)) 67 | else: 68 | raise ValueError 69 | penalty = np.expand_dims(penalty, 1).astype(np.float32) 70 | assert penalty.shape == reward.shape 71 | reward = reward - self._penalty_coef * penalty 72 | info["penalty"] = penalty 73 | 74 | return next_obs, reward, terminal, info 75 | 76 | @ torch.no_grad() 77 | def sample_next_obss( 78 | self, 79 | obs: torch.Tensor, 80 | action: torch.Tensor, 81 | num_samples: int 82 | ) -> torch.Tensor: 83 | obs_act = torch.cat([obs, action], dim=-1) 84 | obs_act = self.scaler.transform_tensor(obs_act) 85 | mean, logvar = self.model(obs_act) 86 | mean[..., :-1] += obs 87 | std = torch.sqrt(torch.exp(logvar)) 88 | 89 | mean = mean[self.model.elites.data.cpu().numpy()] 90 | std = std[self.model.elites.data.cpu().numpy()] 91 | 92 | samples = torch.stack([mean + torch.randn_like(std) * std for i in range(num_samples)], 0) 93 | next_obss = samples[..., :-1] 94 | return next_obss 95 | 96 | def format_samples_for_training(self, data: Dict) -> Tuple[np.ndarray, np.ndarray]: 97 | obss = data["observations"] 98 | actions = data["actions"] 99 | next_obss = data["next_observations"] 100 | rewards = data["rewards"] 101 | delta_obss = next_obss - obss 102 | inputs = np.concatenate((obss, actions), axis=-1) 103 | targets = np.concatenate((delta_obss, rewards), axis=-1) 104 | return inputs, targets 105 | 106 | def train( 107 | self, 108 | data: Dict, 109 | logger: Logger, 110 | max_epochs: Optional[float] = None, 111 | max_epochs_since_update: int = 5, 112 | batch_size: int = 256, 113 | holdout_ratio: float = 0.2, 114 | logvar_loss_coef: float = 0.01 115 | ) -> None: 116 | inputs, targets = self.format_samples_for_training(data) 117 | data_size = inputs.shape[0] 118 | holdout_size = min(int(data_size * holdout_ratio), 1000) 119 | train_size = data_size - holdout_size 120 | train_splits, holdout_splits = torch.utils.data.random_split(range(data_size), (train_size, holdout_size)) 121 | train_inputs, train_targets = inputs[train_splits.indices], targets[train_splits.indices] 122 | holdout_inputs, holdout_targets = inputs[holdout_splits.indices], targets[holdout_splits.indices] 123 | 124 | self.scaler.fit(train_inputs) 125 | train_inputs = self.scaler.transform(train_inputs) 126 | holdout_inputs = self.scaler.transform(holdout_inputs) 127 | holdout_losses = [1e10 for i in range(self.model.num_ensemble)] 128 | 129 | data_idxes = np.random.randint(train_size, size=[self.model.num_ensemble, train_size]) 130 | def shuffle_rows(arr): 131 | idxes = np.argsort(np.random.uniform(size=arr.shape), axis=-1) 132 | return arr[np.arange(arr.shape[0])[:, None], idxes] 133 | 134 | epoch = 0 135 | cnt = 0 136 | logger.log("Training dynamics:") 137 | while True: 138 | epoch += 1 139 | train_loss = self.learn(train_inputs[data_idxes], train_targets[data_idxes], batch_size, logvar_loss_coef) 140 | new_holdout_losses = self.validate(holdout_inputs, holdout_targets) 141 | holdout_loss = (np.sort(new_holdout_losses)[:self.model.num_elites]).mean() 142 | logger.logkv("loss/dynamics_train_loss", train_loss) 143 | logger.logkv("loss/dynamics_holdout_loss", holdout_loss) 144 | 
logger.set_timestep(epoch) 145 | logger.dumpkvs(exclude=["policy_training_progress"]) 146 | 147 | # shuffle data for each base learner 148 | data_idxes = shuffle_rows(data_idxes) 149 | 150 | indexes = [] 151 | for i, new_loss, old_loss in zip(range(len(holdout_losses)), new_holdout_losses, holdout_losses): 152 | improvement = (old_loss - new_loss) / old_loss 153 | if improvement > 0.01: 154 | indexes.append(i) 155 | holdout_losses[i] = new_loss 156 | 157 | if len(indexes) > 0: 158 | self.model.update_save(indexes) 159 | cnt = 0 160 | else: 161 | cnt += 1 162 | 163 | if (cnt >= max_epochs_since_update) or (max_epochs and (epoch >= max_epochs)): 164 | break 165 | 166 | indexes = self.select_elites(holdout_losses) 167 | self.model.set_elites(indexes) 168 | self.model.load_save() 169 | self.save(logger.model_dir) 170 | self.model.eval() 171 | logger.log("elites:{} , holdout loss: {}".format(indexes, (np.sort(holdout_losses)[:self.model.num_elites]).mean())) 172 | 173 | def learn( 174 | self, 175 | inputs: np.ndarray, 176 | targets: np.ndarray, 177 | batch_size: int = 256, 178 | logvar_loss_coef: float = 0.01 179 | ) -> float: 180 | self.model.train() 181 | train_size = inputs.shape[1] 182 | losses = [] 183 | 184 | for batch_num in range(int(np.ceil(train_size / batch_size))): 185 | inputs_batch = inputs[:, batch_num * batch_size:(batch_num + 1) * batch_size] 186 | targets_batch = targets[:, batch_num * batch_size:(batch_num + 1) * batch_size] 187 | targets_batch = torch.as_tensor(targets_batch).to(self.model.device) 188 | 189 | mean, logvar = self.model(inputs_batch) 190 | inv_var = torch.exp(-logvar) 191 | # Average over batch and dim, sum over ensembles. 192 | mse_loss_inv = (torch.pow(mean - targets_batch, 2) * inv_var).mean(dim=(1, 2)) 193 | var_loss = logvar.mean(dim=(1, 2)) 194 | loss = mse_loss_inv.sum() + var_loss.sum() 195 | loss = loss + self.model.get_decay_loss() 196 | loss = loss + logvar_loss_coef * self.model.max_logvar.sum() - logvar_loss_coef * self.model.min_logvar.sum() 197 | 198 | self.optim.zero_grad() 199 | loss.backward() 200 | self.optim.step() 201 | 202 | losses.append(loss.item()) 203 | return np.mean(losses) 204 | 205 | @ torch.no_grad() 206 | def validate(self, inputs: np.ndarray, targets: np.ndarray) -> List[float]: 207 | self.model.eval() 208 | targets = torch.as_tensor(targets).to(self.model.device) 209 | mean, _ = self.model(inputs) 210 | loss = ((mean - targets) ** 2).mean(dim=(1, 2)) 211 | val_loss = list(loss.cpu().numpy()) 212 | return val_loss 213 | 214 | def select_elites(self, metrics: List) -> List[int]: 215 | pairs = [(metric, index) for metric, index in zip(metrics, range(len(metrics)))] 216 | pairs = sorted(pairs, key=lambda x: x[0]) 217 | elites = [pairs[i][1] for i in range(self.model.num_elites)] 218 | return elites 219 | 220 | def save(self, save_path: str) -> None: 221 | torch.save(self.model.state_dict(), os.path.join(save_path, "dynamics.pth")) 222 | self.scaler.save_scaler(save_path) 223 | 224 | def load(self, load_path: str) -> None: 225 | self.model.load_state_dict(torch.load(os.path.join(load_path, "dynamics.pth"), map_location=self.model.device)) 226 | self.scaler.load_scaler(load_path) 227 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | --- 6 | 7 | # OfflineRL-Kit: An elegant PyTorch offline reinforcement learning library. 8 | 9 | ![MIT](https://img.shields.io/badge/license-MIT-blue) 10 | 11 | OfflineRL-Kit is an offline reinforcement learning library based on pure PyTorch. It provides several features that are friendly and convenient for researchers, including: 12 | 13 | - An elegant framework with a clear, easy-to-use code structure 14 | - State-of-the-art offline RL algorithms, including model-free and model-based approaches 15 | - High scalability: you can build a new algorithm with only a few lines of code based on the components in our library 16 | - Support for parallel tuning, which is very convenient for researchers 17 | - A clear and powerful logging system that makes experiments easy to manage 18 | 19 | ## Supported algorithms 20 | - Model-free 21 | - [Conservative Q-Learning (CQL)](https://arxiv.org/abs/2006.04779) 22 | - [TD3+BC](https://arxiv.org/abs/2106.06860) 23 | - [Implicit Q-Learning (IQL)](https://arxiv.org/abs/2110.06169) 24 | - [Ensemble-Diversified Actor Critic (EDAC)](https://arxiv.org/abs/2110.01548) 25 | - [Mildly Conservative Q-Learning (MCQ)](https://arxiv.org/abs/2206.04745) 26 | - Model-based 27 | - [Model-based Offline Policy Optimization (MOPO)](https://arxiv.org/abs/2005.13239) 28 | - [Conservative Offline Model-Based Policy Optimization (COMBO)](https://arxiv.org/abs/2102.08363) 29 | - [Robust Adversarial Model-Based Offline Reinforcement Learning (RAMBO)](https://arxiv.org/abs/2204.12581) 30 | - [Model-Bellman Inconsistency Penalized Offline Reinforcement Learning (MOBILE)](https://proceedings.mlr.press/v202/sun23q.html) 31 | 32 | ## Benchmark Results (4 seeds) (Ongoing) 33 | 34 | | | CQL | TD3+BC | EDAC | IQL | MOPO | RAMBO | COMBO | MOBILE | 35 | | ---------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | 36 | | halfcheetah-medium-v2 | 49.4±0.2 | 48.2±0.5 | 66.4±1.1 | 47.4±0.5 | 72.4±4.2 | 78.7±1.1 | 71.9±8.5 | 75.8±0.8 | 37 | | hopper-medium-v2 | 59.1±4.1 | 60.8±3.4 | 101.8±0.2 | 65.7±8.1 | 62.8±38.1 | 82.1±38.0 | 84.7±9.3 | 103.6±1.0 | 38 | | walker2d-medium-v2 | 83.6±0.5 | 84.4±2.1 | 93.3±0.8 | 81.1±2.6 | 84.1±3.2 | 86.1±1.0 | 83.9±2.0 | 88.3±2.5 | 39 | | halfcheetah-medium-replay-v2 | 47.0±0.3 | 45.0±0.5 | 62.3±1.4 | 44.2±0.6 | 72.1±3.8 | 68.5±3.6 | 66.5±6.5 | 71.9±3.2 | 40 | | hopper-medium-replay-v2 | 98.6±1.5 | 67.3±13.2 | 101.5±0.1 | 94.8±6.7 | 92.7±20.7 | 93.4±11.4 | 90.1±25.2 | 105.1±1.3 | 41 | | walker2d-medium-replay-v2 | 71.3±17.9 | 83.4±7.0 | 86.2±1.2 | 77.3±11.0 | 85.9±5.3 | 73.7±6.5 | 89.4±6.4 | 90.5±1.7 | 42 | | halfcheetah-medium-expert-v2 | 93.0±2.2 | 90.7±2.7 | 101.8±8.4 | 88.0±2.8 | 83.6±12.5 | 98.8±4.3 | 98.2±0.2 | 100.9±1.5 | 43 | | hopper-medium-expert-v2 | 111.4±0.5 | 91.4±11.3 | 110.5±0.3 | 106.2±5.6 | 74.6±44.2 | 85.0±30.7 | 108.8±2.6 | 112.5±0.2 | 44 | | walker2d-medium-expert-v2 | 109.8±0.5 | 110.2±0.3 | 113.6±0.3 | 108.3±2.6 | 108.2±4.3 | 78.4±45.4 | 110.0±0.2 | 114.5±2.2 | 45 | 46 | Detailed logs can be viewed in . 47 | 48 | ## Installation 49 | First, install the MuJoCo engine, which can be downloaded from [here](https://mujoco.org/download), and install `mujoco-py` (its version depends on the version of the MuJoCo engine you have installed). 50 | 51 | Second, install D4RL: 52 | ```shell 53 | git clone https://github.com/Farama-Foundation/d4rl.git 54 | cd d4rl 55 | pip install -e . 56 | ``` 57 | 58 | Finally, install our OfflineRL-Kit! 
59 | ```shell 60 | git clone https://github.com/yihaosun1124/OfflineRL-Kit.git 61 | cd OfflineRL-Kit 62 | python setup.py install 63 | ``` 64 | 65 | ## Quick Start 66 | ### Train 67 | This is an example of CQL. You can also run the full script at [run_example/run_cql.py](https://github.com/yihaosun1124/OfflineRL-Kit/blob/main/run_example/run_cql.py). 68 | 69 | First, make an environment and get the offline dataset: 70 | 71 | ```python 72 | env = gym.make(args.task) 73 | dataset = qlearning_dataset(env) 74 | buffer = ReplayBuffer( 75 | buffer_size=len(dataset["observations"]), 76 | obs_shape=args.obs_shape, 77 | obs_dtype=np.float32, 78 | action_dim=args.action_dim, 79 | action_dtype=np.float32, 80 | device=args.device 81 | ) 82 | buffer.load_dataset(dataset) 83 | ``` 84 | 85 | Define the models and optimizers: 86 | 87 | ```python 88 | actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims) 89 | critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 90 | critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) 91 | dist = TanhDiagGaussian( 92 | latent_dim=getattr(actor_backbone, "output_dim"), 93 | output_dim=args.action_dim, 94 | unbounded=True, 95 | conditioned_sigma=True 96 | ) 97 | actor = ActorProb(actor_backbone, dist, args.device) 98 | critic1 = Critic(critic1_backbone, args.device) 99 | critic2 = Critic(critic2_backbone, args.device) 100 | actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 101 | critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) 102 | critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) 103 | ``` 104 | 105 | Setup policy: 106 | 107 | ```python 108 | policy = CQLPolicy( 109 | actor, 110 | critic1, 111 | critic2, 112 | actor_optim, 113 | critic1_optim, 114 | critic2_optim, 115 | action_space=env.action_space, 116 | tau=args.tau, 117 | gamma=args.gamma, 118 | alpha=alpha, 119 | cql_weight=args.cql_weight, 120 | temperature=args.temperature, 121 | max_q_backup=args.max_q_backup, 122 | deterministic_backup=args.deterministic_backup, 123 | with_lagrange=args.with_lagrange, 124 | lagrange_threshold=args.lagrange_threshold, 125 | cql_alpha_lr=args.cql_alpha_lr, 126 | num_repeart_actions=args.num_repeat_actions 127 | ) 128 | ``` 129 | 130 | Define logger: 131 | ```python 132 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 133 | output_config = { 134 | "consoleout_backup": "stdout", 135 | "policy_training_progress": "csv", 136 | "tb": "tensorboard" 137 | } 138 | logger = Logger(log_dirs, output_config) 139 | logger.log_hyperparameters(vars(args)) 140 | ``` 141 | 142 | Load all components into the trainer and train it: 143 | ```python 144 | policy_trainer = MFPolicyTrainer( 145 | policy=policy, 146 | eval_env=env, 147 | buffer=buffer, 148 | logger=logger, 149 | epoch=args.epoch, 150 | step_per_epoch=args.step_per_epoch, 151 | batch_size=args.batch_size, 152 | eval_episodes=args.eval_episodes 153 | ) 154 | 155 | policy_trainer.train() 156 | ``` 157 | 158 | ### Tune 159 | You can easily tune your algorithm with the help of [Ray](https://github.com/ray-project/ray): 160 | ```python 161 | ray.init() 162 | # load default args 163 | args = get_args() 164 | 165 | config = {} 166 | real_ratios = [0.05, 0.5] 167 | seeds = list(range(2)) 168 | config["real_ratio"] = tune.grid_search(real_ratios) 169 | config["seed"] = tune.grid_search(seeds) 170 | 171 | analysis = 
tune.run( 172 | run_exp, 173 | name="tune_mopo", 174 | config=config, 175 | resources_per_trial={ 176 | "gpu": 0.5 177 | } 178 | ) 179 | ``` 180 | You can see the full script at [tune_example/tune_mopo.py](https://github.com/yihaosun1124/OfflineRL-Kit/blob/main/tune_example/tune_mopo.py). 181 | 182 | ### Log 183 | Our logger supports a variety of record file types, including .txt (a backup of stdout), .csv (loss, performance, and other metrics recorded during training), .tfevents (TensorBoard event files for visualizing training curves), and .json (a backup of the hyperparameters). 184 | Our logger also has a clear log structure: 185 | ``` 186 | └─log(root dir) 187 | └─task 188 | └─algo_0 189 | | └─seed_0&timestamp_xxx 190 | | | ├─checkpoint 191 | | | ├─model 192 | | | ├─record 193 | | | │ ├─tb 194 | | | │ ├─consoleout_backup.txt 195 | | | │ ├─policy_training_progress.csv 196 | | | │ ├─hyper_param.json 197 | | | ├─result 198 | | └─seed_1&timestamp_xxx 199 | └─algo_1 200 | ``` 201 | This is an example of using the logger; you can see the full script at [offlinerlkit/policy_trainer/mb_policy_trainer.py](https://github.com/yihaosun1124/OfflineRL-Kit/blob/main/offlinerlkit/policy_trainer/mb_policy_trainer.py). 202 | 203 | First, import some relevant packages: 204 | ```python 205 | from offlinerlkit.utils.logger import Logger, make_log_dirs 206 | ``` 207 | Then initialize the logger: 208 | ```python 209 | log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) 210 | # key: output file name, value: output handler type 211 | output_config = { 212 | "consoleout_backup": "stdout", 213 | "policy_training_progress": "csv", 214 | "dynamics_training_progress": "csv", 215 | "tb": "tensorboard" 216 | } 217 | logger = Logger(log_dirs, output_config) 218 | logger.log_hyperparameters(vars(args)) 219 | ``` 220 | 221 | Let's log some metrics: 222 | ```python 223 | # log 224 | logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean) 225 | logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std) 226 | logger.logkv("eval/episode_length", ep_length_mean) 227 | logger.logkv("eval/episode_length_std", ep_length_std) 228 | # set timestep 229 | logger.set_timestep(num_timesteps) 230 | # dump results to the record files 231 | logger.dumpkvs() 232 | ``` 233 | 234 | ### Plot 235 | ```shell 236 | python run_example/plotter.py --algos "mopo" "cql" --task "hopper-medium-replay-v2" 237 | ``` 238 | 239 | ## Citing OfflineRL-Kit 240 | If you use OfflineRL-Kit in your work, please cite it with the following BibTeX entry: 241 | ```tex 242 | @misc{offinerlkit, 243 | author = {Yihao Sun}, 244 | title = {OfflineRL-Kit: An Elegant PyTorch Offline Reinforcement Learning Library}, 245 | year = {2023}, 246 | publisher = {GitHub}, 247 | journal = {GitHub repository}, 248 | howpublished = {\url{https://github.com/yihaosun1124/OfflineRL-Kit}}, 249 | } 250 | ``` -------------------------------------------------------------------------------- /offlinerlkit/policy/model_based/combo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gym 5 | 6 | from torch.nn import functional as F 7 | from typing import Dict, Union, Tuple 8 | from collections import defaultdict 9 | from offlinerlkit.policy import CQLPolicy 10 | from offlinerlkit.dynamics import BaseDynamics 11 | 12 | 13 | class COMBOPolicy(CQLPolicy): 14 | """ 15 | Conservative Offline Model-Based Policy Optimization 16 | """ 17 | 18 | def __init__( 19 | self, 20 | dynamics: BaseDynamics, 21 | actor: 
nn.Module, 22 | critic1: nn.Module, 23 | critic2: nn.Module, 24 | actor_optim: torch.optim.Optimizer, 25 | critic1_optim: torch.optim.Optimizer, 26 | critic2_optim: torch.optim.Optimizer, 27 | action_space: gym.spaces.Space, 28 | tau: float = 0.005, 29 | gamma: float = 0.99, 30 | alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, 31 | cql_weight: float = 1.0, 32 | temperature: float = 1.0, 33 | max_q_backup: bool = False, 34 | deterministic_backup: bool = True, 35 | with_lagrange: bool = True, 36 | lagrange_threshold: float = 10.0, 37 | cql_alpha_lr: float = 1e-4, 38 | num_repeart_actions:int = 10, 39 | uniform_rollout: bool = False, 40 | rho_s: str = "mix" 41 | ) -> None: 42 | super().__init__( 43 | actor, 44 | critic1, 45 | critic2, 46 | actor_optim, 47 | critic1_optim, 48 | critic2_optim, 49 | action_space, 50 | tau=tau, 51 | gamma=gamma, 52 | alpha=alpha, 53 | cql_weight=cql_weight, 54 | temperature=temperature, 55 | max_q_backup=max_q_backup, 56 | deterministic_backup=deterministic_backup, 57 | with_lagrange=with_lagrange, 58 | lagrange_threshold=lagrange_threshold, 59 | cql_alpha_lr=cql_alpha_lr, 60 | num_repeart_actions=num_repeart_actions 61 | ) 62 | 63 | self.dynamics = dynamics 64 | self._uniform_rollout = uniform_rollout 65 | self._rho_s = rho_s 66 | 67 | def rollout( 68 | self, 69 | init_obss: np.ndarray, 70 | rollout_length: int 71 | ) -> Tuple[Dict[str, np.ndarray], Dict]: 72 | 73 | num_transitions = 0 74 | rewards_arr = np.array([]) 75 | rollout_transitions = defaultdict(list) 76 | 77 | # rollout 78 | observations = init_obss 79 | for _ in range(rollout_length): 80 | if self._uniform_rollout: 81 | actions = np.random.uniform( 82 | self.action_space.low[0], 83 | self.action_space.high[0], 84 | size=(len(observations), self.action_space.shape[0]) 85 | ) 86 | else: 87 | actions = self.select_action(observations) 88 | next_observations, rewards, terminals, info = self.dynamics.step(observations, actions) 89 | rollout_transitions["obss"].append(observations) 90 | rollout_transitions["next_obss"].append(next_observations) 91 | rollout_transitions["actions"].append(actions) 92 | rollout_transitions["rewards"].append(rewards) 93 | rollout_transitions["terminals"].append(terminals) 94 | 95 | num_transitions += len(observations) 96 | rewards_arr = np.append(rewards_arr, rewards.flatten()) 97 | 98 | nonterm_mask = (~terminals).flatten() 99 | if nonterm_mask.sum() == 0: 100 | break 101 | 102 | observations = next_observations[nonterm_mask] 103 | 104 | for k, v in rollout_transitions.items(): 105 | rollout_transitions[k] = np.concatenate(v, axis=0) 106 | 107 | return rollout_transitions, \ 108 | {"num_transitions": num_transitions, "reward_mean": rewards_arr.mean()} 109 | 110 | def learn(self, batch: Dict) -> Dict[str, float]: 111 | real_batch, fake_batch = batch["real"], batch["fake"] 112 | mix_batch = {k: torch.cat([real_batch[k], fake_batch[k]], 0) for k in real_batch.keys()} 113 | 114 | obss, actions, next_obss, rewards, terminals = mix_batch["observations"], mix_batch["actions"], \ 115 | mix_batch["next_observations"], mix_batch["rewards"], mix_batch["terminals"] 116 | batch_size = obss.shape[0] 117 | 118 | # update actor 119 | a, log_probs = self.actforward(obss) 120 | q1a, q2a = self.critic1(obss, a), self.critic2(obss, a) 121 | actor_loss = (self._alpha * log_probs - torch.min(q1a, q2a)).mean() 122 | self.actor_optim.zero_grad() 123 | actor_loss.backward() 124 | self.actor_optim.step() 125 | 126 | if self._is_auto_alpha: 127 | log_probs = 
log_probs.detach() + self._target_entropy 128 | alpha_loss = -(self._log_alpha * log_probs).mean() 129 | self.alpha_optim.zero_grad() 130 | alpha_loss.backward() 131 | self.alpha_optim.step() 132 | self._alpha = self._log_alpha.detach().exp() 133 | 134 | # compute td error 135 | if self._max_q_backup: 136 | with torch.no_grad(): 137 | tmp_next_obss = next_obss.unsqueeze(1) \ 138 | .repeat(1, self._num_repeat_actions, 1) \ 139 | .view(batch_size * self._num_repeat_actions, next_obss.shape[-1]) 140 | tmp_next_actions, _ = self.actforward(tmp_next_obss) 141 | tmp_next_q1 = self.critic1_old(tmp_next_obss, tmp_next_actions) \ 142 | .view(batch_size, self._num_repeat_actions, 1) \ 143 | .max(1)[0].view(-1, 1) 144 | tmp_next_q2 = self.critic2_old(tmp_next_obss, tmp_next_actions) \ 145 | .view(batch_size, self._num_repeat_actions, 1) \ 146 | .max(1)[0].view(-1, 1) 147 | next_q = torch.min(tmp_next_q1, tmp_next_q2) 148 | else: 149 | with torch.no_grad(): 150 | next_actions, next_log_probs = self.actforward(next_obss) 151 | next_q = torch.min( 152 | self.critic1_old(next_obss, next_actions), 153 | self.critic2_old(next_obss, next_actions) 154 | ) 155 | if not self._deterministic_backup: 156 | next_q -= self._alpha * next_log_probs 157 | 158 | target_q = rewards + self._gamma * (1 - terminals) * next_q 159 | q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions) 160 | critic1_loss = ((q1 - target_q).pow(2)).mean() 161 | critic2_loss = ((q2 - target_q).pow(2)).mean() 162 | 163 | # compute conservative loss 164 | if self._rho_s == "model": 165 | obss, actions, next_obss = fake_batch["observations"], \ 166 | fake_batch["actions"], fake_batch["next_observations"] 167 | 168 | batch_size = len(obss) 169 | random_actions = torch.FloatTensor( 170 | batch_size * self._num_repeat_actions, actions.shape[-1] 171 | ).uniform_(self.action_space.low[0], self.action_space.high[0]).to(self.actor.device) 172 | # tmp_obss & tmp_next_obss: (batch_size * num_repeat, obs_dim) 173 | tmp_obss = obss.unsqueeze(1) \ 174 | .repeat(1, self._num_repeat_actions, 1) \ 175 | .view(batch_size * self._num_repeat_actions, obss.shape[-1]) 176 | tmp_next_obss = next_obss.unsqueeze(1) \ 177 | .repeat(1, self._num_repeat_actions, 1) \ 178 | .view(batch_size * self._num_repeat_actions, obss.shape[-1]) 179 | 180 | obs_pi_value1, obs_pi_value2 = self.calc_pi_values(tmp_obss, tmp_obss) 181 | next_obs_pi_value1, next_obs_pi_value2 = self.calc_pi_values(tmp_next_obss, tmp_obss) 182 | random_value1, random_value2 = self.calc_random_values(tmp_obss, random_actions) 183 | 184 | for value in [ 185 | obs_pi_value1, obs_pi_value2, next_obs_pi_value1, next_obs_pi_value2, 186 | random_value1, random_value2 187 | ]: 188 | value.reshape(batch_size, self._num_repeat_actions, 1) 189 | 190 | # cat_q shape: (batch_size, 3 * num_repeat, 1) 191 | cat_q1 = torch.cat([obs_pi_value1, next_obs_pi_value1, random_value1], 1) 192 | cat_q2 = torch.cat([obs_pi_value2, next_obs_pi_value2, random_value2], 1) 193 | # Samples from the original dataset 194 | real_obss, real_actions = real_batch['observations'], real_batch['actions'] 195 | q1, q2 = self.critic1(real_obss, real_actions), self.critic2(real_obss, real_actions) 196 | 197 | conservative_loss1 = \ 198 | torch.logsumexp(cat_q1 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \ 199 | q1.mean() * self._cql_weight 200 | conservative_loss2 = \ 201 | torch.logsumexp(cat_q2 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \ 202 | q2.mean() * 
self._cql_weight 203 | 204 | if self._with_lagrange: 205 | cql_alpha = torch.clamp(self.cql_log_alpha.exp(), 0.0, 1e6) 206 | conservative_loss1 = cql_alpha * (conservative_loss1 - self._lagrange_threshold) 207 | conservative_loss2 = cql_alpha * (conservative_loss2 - self._lagrange_threshold) 208 | 209 | self.cql_alpha_optim.zero_grad() 210 | cql_alpha_loss = -(conservative_loss1 + conservative_loss2) * 0.5 211 | cql_alpha_loss.backward(retain_graph=True) 212 | self.cql_alpha_optim.step() 213 | 214 | critic1_loss = critic1_loss + conservative_loss1 215 | critic2_loss = critic2_loss + conservative_loss2 216 | 217 | # update critic 218 | self.critic1_optim.zero_grad() 219 | critic1_loss.backward(retain_graph=True) 220 | self.critic1_optim.step() 221 | 222 | self.critic2_optim.zero_grad() 223 | critic2_loss.backward() 224 | self.critic2_optim.step() 225 | 226 | self._sync_weight() 227 | 228 | result = { 229 | "loss/actor": actor_loss.item(), 230 | "loss/critic1": critic1_loss.item(), 231 | "loss/critic2": critic2_loss.item() 232 | } 233 | 234 | if self._is_auto_alpha: 235 | result["loss/alpha"] = alpha_loss.item() 236 | result["alpha"] = self._alpha.item() 237 | if self._with_lagrange: 238 | result["loss/cql_alpha"] = cql_alpha_loss.item() 239 | result["cql_alpha"] = cql_alpha.item() 240 | 241 | return result --------------------------------------------------------------------------------