├── utils ├── default_config.py ├── Replaybuffer.py └── config.py ├── Non_stationary_env ├── __init__.py ├── ant_random.py ├── humanoid_random.py ├── hopper_random.py ├── walker2d_random.py └── panda_env.py ├── README.md ├── model ├── utils.py ├── model_back.py ├── model.py ├── discriminator.py └── algo.py ├── main_stationary.py ├── main_transfer.py └── main_non_stationary.py /utils/default_config.py: -------------------------------------------------------------------------------- 1 | from utils.config import Config 2 | 3 | default_config = Config({ 4 | "seed": 0, 5 | "tag": "default", 6 | "start_steps": 5e3, 7 | "cuda": True, 8 | "num_steps": 300001, 9 | "save": True, 10 | 11 | "env_name": "HalfCheetah-v2", 12 | "eval": True, 13 | "eval_episodes": 10, 14 | "eval_times": 10, 15 | "replay_size": 1000000, 16 | "local_replay_size": 1000, # default: 1000 17 | 18 | "algo": "TOMAC", 19 | "policy": "Gaussian", # 'Policy Type: Gaussian | Deterministic (default: Gaussian)' 20 | "gamma": 0.99, 21 | "tau": 0.005, 22 | "lr": 0.0003, 23 | "alpha": 0.2, 24 | "automatic_entropy_tuning": True, 25 | "batch_size": 256, 26 | "updates_per_step": 3, 27 | "target_update_interval": 2, 28 | "hidden_size": 256, 29 | "gail_batch": 256, 30 | 31 | "exponent": 1.5, # default: 1.1 32 | "tomac_alpha": 0.001, # default: 0.001 33 | "reward_max": 1. 34 | }) 35 | -------------------------------------------------------------------------------- /utils/Replaybuffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import pickle 5 | import random 6 | 7 | class ReplayMemory: 8 | def __init__(self, capacity, seed): 9 | random.seed(seed) 10 | self.capacity = capacity 11 | self.buffer = [] 12 | self.position = 0 13 | 14 | def push(self, state, action, reward, next_state, done): 15 | if len(self.buffer) < self.capacity: 16 | self.buffer.append(None) 17 | self.buffer[self.position] = (state, action, reward, next_state, done) 18 | self.position = (self.position + 1) % self.capacity 19 | 20 | def sample(self, batch_size): 21 | batch = random.sample(self.buffer, batch_size) 22 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 23 | return state, action, reward, next_state, done 24 | 25 | def __len__(self): 26 | return len(self.buffer) 27 | 28 | def save_buffer(self, path): 29 | print('Saving buffer to {}'.format(path)) 30 | 31 | with open(path, 'wb') as f: 32 | pickle.dump(self.buffer, f) 33 | 34 | def load_buffer(self, save_path): 35 | print('Loading buffer from {}'.format(save_path)) 36 | 37 | with open(save_path, "rb") as f: 38 | self.buffer = pickle.load(f) 39 | self.position = len(self.buffer) % self.capacity -------------------------------------------------------------------------------- /Non_stationary_env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs import register 2 | 3 | register( 4 | id='HopperRandom-v0', 5 | entry_point='Non_stationary_env.hopper_random:HopperRandomEnv', 6 | max_episode_steps=1000, 7 | reward_threshold=3800.0, 8 | ) 9 | 10 | register( 11 | id='HopperTransfer-v0', 12 | entry_point='Non_stationary_env.hopper_random:HopperTransferEnv', 13 | max_episode_steps=1000, 14 | reward_threshold=3800.0, 15 | ) 16 | 17 | register( 18 | id='Walker2dRandom-v0', 19 | entry_point='Non_stationary_env.walker2d_random:Walker2dRandomEnv', 20 | max_episode_steps=1000 21 | ) 22 | 23 | register( 24 | id='Walker2dTransfer-v0', 25 | 
entry_point='Non_stationary_env.walker2d_random:Walker2dTransferEnv', 26 | max_episode_steps=1000 27 | ) 28 | 29 | register( 30 | id='AntTransfer-v0', 31 | entry_point='Non_stationary_env.ant_random:AntTransferEnv', 32 | max_episode_steps=1000, 33 | reward_threshold=6000.0, 34 | ) 35 | 36 | register( 37 | id='AntRandom-v0', 38 | entry_point='Non_stationary_env.ant_random:AntRandomEnv', 39 | max_episode_steps=1000, 40 | reward_threshold=6000.0, 41 | ) 42 | 43 | register( 44 | id='HumanoidTransfer-v0', 45 | entry_point='Non_stationary_env.humanoid_random:HumanoidTransferEnv', 46 | max_episode_steps=1000 47 | ) 48 | 49 | register( 50 | id='HumanoidRandom-v0', 51 | entry_point='Non_stationary_env.humanoid_random:HumanoidRandomEnv', 52 | max_episode_steps=1000 53 | ) -------------------------------------------------------------------------------- /Non_stationary_env/ant_random.py: -------------------------------------------------------------------------------- 1 | import mujoco_py 2 | import gym 3 | import numpy as np 4 | import random 5 | 6 | from gym.envs.mujoco.ant_v3 import AntEnv 7 | 8 | class AntTransferEnv(AntEnv): 9 | def __init__(self, gravity = 9.81, wind = 0, **kwargs): 10 | super().__init__(**kwargs) 11 | # self.reset(gravity = gravity, wind = wind) 12 | self.model.opt.viscosity = 0.00002 13 | self.model.opt.density = 1.2 14 | self.model.opt.gravity[:] = np.array([0., 0., -gravity]) 15 | self.model.opt.wind[:] = np.array([-wind, 0., 0.]) 16 | 17 | class AntRandomEnv(AntEnv): 18 | 19 | def __init__(self, **kwargs): 20 | super().__init__(**kwargs) 21 | self.model.opt.viscosity = 0.00002 22 | self.model.opt.density = 1.2 23 | 24 | 25 | def step_with_random(self, action, gravity = 9.81, wind = 0): 26 | # print("Step with new gravity = ", gravity, " wind = ", wind) 27 | 28 | self.model.opt.gravity[:] = np.array([0., 0., -gravity]) 29 | self.model.opt.wind[:] = np.array([-wind, 0., 0.]) 30 | 31 | return self.step(action) 32 | 33 | def reset(self, gravity = 9.81, wind = 0): 34 | '''Must called like env.reset(body_len = XXX)''' 35 | # print("Reset with new gravity = ", gravity, " wind = ", wind) 36 | 37 | self.model.opt.gravity[:] = np.array([0., 0., -gravity]) 38 | self.model.opt.wind[:] = np.array([-wind, 0., 0.]) 39 | 40 | return super().reset() 41 | -------------------------------------------------------------------------------- /Non_stationary_env/humanoid_random.py: -------------------------------------------------------------------------------- 1 | import mujoco_py 2 | import gym 3 | import numpy as np 4 | import random 5 | 6 | from gym.envs.mujoco.humanoid_v3 import HumanoidEnv 7 | 8 | class HumanoidTransferEnv(HumanoidEnv): 9 | def __init__(self, gravity = 9.81, wind = 0, **kwargs): 10 | super().__init__(**kwargs) 11 | # self.reset(gravity = gravity, wind = wind) 12 | self.model.opt.viscosity = 0.00002 13 | self.model.opt.density = 1.2 14 | self.model.opt.gravity[:] = np.array([0., 0., -gravity]) 15 | self.model.opt.wind[:] = np.array([-wind, 0., 0.]) 16 | 17 | class HumanoidRandomEnv(HumanoidEnv): 18 | 19 | def __init__(self, **kwargs): 20 | super().__init__(**kwargs) 21 | self.model.opt.viscosity = 0.00002 22 | self.model.opt.density = 1.2 23 | 24 | 25 | def step_with_random(self, action, gravity = 9.81, wind = 0): 26 | # print("Step with new gravity = ", gravity, " wind = ", wind) 27 | 28 | self.model.opt.gravity[:] = np.array([0., 0., -gravity]) 29 | self.model.opt.wind[:] = np.array([-wind, 0., 0.]) 30 | 31 | return self.step(action) 32 | 33 | def reset(self, gravity = 9.81, 
wind = 0): 34 | '''Must called like env.reset(body_len = XXX)''' 35 | # print("Reset with new gravity = ", gravity, " wind = ", wind) 36 | 37 | self.model.opt.gravity[:] = np.array([0., 0., -gravity]) 38 | self.model.opt.wind[:] = np.array([-wind, 0., 0.]) 39 | 40 | return super().reset() 41 | -------------------------------------------------------------------------------- /Non_stationary_env/hopper_random.py: -------------------------------------------------------------------------------- 1 | import mujoco_py 2 | import gym 3 | import numpy as np 4 | import random 5 | 6 | from gym.envs.mujoco.hopper_v3 import HopperEnv 7 | 8 | class HopperTransferEnv(HopperEnv): 9 | def __init__(self, torso_len: float = 0.2, foot_len: float = 0.195, **kwargs): 10 | super().__init__(**kwargs) 11 | self.model.body_pos[1][2] = 1.05 + torso_len 12 | self.model.body_pos[2][2] = -torso_len 13 | self.model.geom_size[1][1] = torso_len 14 | 15 | self.model.geom_size[4][1] = foot_len 16 | 17 | 18 | 19 | class HopperRandomEnv(HopperEnv): 20 | def __init__(self, **kwargs): 21 | super().__init__(**kwargs) 22 | 23 | def reset(self, torso_len: float = 0.2, foot_len: float = 0.195): 24 | '''Must called like env.reset(body_len = XXX)''' 25 | 26 | self.model.body_pos[1][2] = 1.05 + torso_len 27 | self.model.body_pos[2][2] = -torso_len 28 | self.model.geom_size[1][1] = torso_len 29 | 30 | self.model.geom_size[4][1] = foot_len 31 | 32 | return super().reset() 33 | 34 | if __name__ == "__main__": 35 | env = gym.make("HopperRandom-v0", torso_len = 0.2, foot_len = 0.195) 36 | env.reset(torso_len = 0.2, foot_len = 0.195) 37 | 38 | for _ in range(10): 39 | for i in range(50): 40 | env.step(np.random.rand(3)) 41 | env.render() 42 | 43 | l1 = 0.2 + 0.15 * random.random() 44 | l2 = 0.195 + 0.1 * random.random() 45 | env.reset(torso_len = l1, foot_len = l2) 46 | 47 | env.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# OMPO

2 | 3 | Official implementation of 4 | 5 | `OMPO: A Unified Framework for RL under Policy and Dynamics Shifts` by 6 | 7 | Yu Luo, Tianying Ji, Fuchun Sun, Jianwei Zhang, Huazhe Xu and Xianyuan Zhan 8 | 9 | ## Getting started 10 | 11 | We provide examples on how to train and evaluate **OMPO** agent. 12 | 13 | ### Training 14 | 15 | See below examples on how to train OBAC on a single task. 16 | 17 | ```python 18 | python main_stationary.py --env_name YOUR_TASK 19 | ``` 20 | 21 | We recommend using default hyperparameters. See `utils/default_config.py` for a full list of arguments. 22 | 23 | ## Citation 24 | 25 | If you find our work useful, please consider citing our paper as follows: 26 | 27 | ``` 28 | @inproceedings{Luo2024ompo, 29 | title={OMPO: A Unified Framework for RL under Policy and Dynamics Shifts}, 30 | author={Yu Luo and Tianjing Ji and Fuchun Sun and Jianwei Zhang and Huazhe Xu and Xianyuan Zhan}, 31 | booktitle={International Conference on Machine Learning}, 32 | year={2024} 33 | } 34 | ``` 35 | 36 | ---- 37 | 38 | ## Contributing 39 | 40 | Please feel free to participate in our project by opening issues or sending pull requests for any enhancements or bug reports you might have. We’re striving to develop a codebase that’s easily expandable to different settings and tasks, and your feedback on how it’s working is greatly appreciated! 41 | 42 | ---- 43 | 44 | ## License 45 | 46 | This project is licensed under the MIT License - see the `LICENSE` file for details. Note that the repository relies on third-party code, which is subject to their respective licenses. 47 | -------------------------------------------------------------------------------- /Non_stationary_env/walker2d_random.py: -------------------------------------------------------------------------------- 1 | import mujoco_py 2 | import gym 3 | import numpy as np 4 | import random 5 | 6 | from gym.envs.mujoco.walker2d_v3 import Walker2dEnv 7 | 8 | class Walker2dTransferEnv(Walker2dEnv): 9 | def __init__(self, torso_len: float = 0.2, foot_len: float = 0.1, **kwargs): 10 | super().__init__(**kwargs) 11 | self.model.body_pos[1][2] = 1.05 + torso_len 12 | self.model.body_pos[2][2] = -torso_len 13 | self.model.body_pos[5][2] = -torso_len 14 | self.model.geom_size[1][1] = torso_len 15 | 16 | self.model.geom_size[4][1] = foot_len 17 | self.model.geom_size[7][1] = foot_len 18 | 19 | 20 | class Walker2dRandomEnv(Walker2dEnv): 21 | def __init__(self, **kwargs): 22 | super().__init__(**kwargs) 23 | 24 | def reset(self, torso_len: float = 0.2, foot_len: float = 0.1): 25 | '''Must called like env.reset(body_len = XXX)''' 26 | 27 | self.model.body_pos[1][2] = 1.05 + torso_len 28 | self.model.body_pos[2][2] = -torso_len 29 | self.model.body_pos[5][2] = -torso_len 30 | self.model.geom_size[1][1] = torso_len 31 | 32 | self.model.geom_size[4][1] = foot_len 33 | self.model.geom_size[7][1] = foot_len 34 | 35 | return super().reset() 36 | 37 | if __name__ == "__main__": 38 | env = gym.make("Walker2dRandom-v0", torso_len = 0.2, foot_len = 0.1) 39 | env.reset(torso_len = 0.2, foot_len = 0.1) 40 | 41 | for _ in range(10): 42 | for i in range(50): 43 | env.step(np.random.rand(6)) 44 | env.render() 45 | 46 | l1 = 0.2 + 0.15 * random.random() 47 | l2 = 0.195 + 0.1 * random.random() 48 | env.reset(torso_len = l1, foot_len = l2) 49 | 50 | env.close() -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 
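# Overview of the helpers defined below (a brief summary added for orientation;
# only soft_update, hard_update and orthogonal_regularization are visibly imported
# elsewhere in this dump):
#   create_log_gaussian / logsumexp -- log-density helpers for diagonal Gaussians,
#       with the usual max-subtraction trick for numerical stability in logsumexp.
#   soft_update  -- Polyak averaging of target-network parameters,
#       theta_target <- tau * theta_source + (1 - tau) * theta_target.
#   hard_update  -- copies source parameters straight into the target network.
#   orthogonal_regularization -- penalizes the off-diagonal entries of W^T W for
#       every nn.Linear layer (the orthogonal regularizer of Eq. 3 in arXiv:1809.11096).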
import torch 3 | 4 | def create_log_gaussian(mean, log_std, t): 5 | quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2)) 6 | l = mean.shape 7 | log_z = log_std 8 | z = l[-1] * math.log(2 * math.pi) 9 | log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z 10 | return log_p 11 | 12 | def logsumexp(inputs, dim=None, keepdim=False): 13 | if dim is None: 14 | inputs = inputs.view(-1) 15 | dim = 0 16 | s, _ = torch.max(inputs, dim=dim, keepdim=True) 17 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() 18 | if not keepdim: 19 | outputs = outputs.squeeze(dim) 20 | return outputs 21 | 22 | def soft_update(target, source, tau): 23 | for target_param, param in zip(target.parameters(), source.parameters()): 24 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 25 | 26 | def hard_update(target, source): 27 | for target_param, param in zip(target.parameters(), source.parameters()): 28 | target_param.data.copy_(param.data) 29 | 30 | def orthogonal_regularization(model, reg_coef=1e-4): 31 | """Orthogonal regularization v2. 32 | 33 | See equation (3) in https://arxiv.org/abs/1809.11096. 34 | 35 | Args: 36 | model: A PyTorch model to apply regularization for. 37 | reg_coef: Orthogonal regularization coefficient. 38 | 39 | Returns: 40 | A regularization loss term. 41 | """ 42 | 43 | reg = 0.0 44 | for module in model.modules(): 45 | if isinstance(module, torch.nn.Linear): 46 | weight = module.weight 47 | prod = torch.matmul(weight.t(), weight) 48 | eye_matrix = torch.eye(prod.shape[0], device=weight.device) 49 | reg += torch.sum(torch.square(prod * (1 - eye_matrix))) 50 | 51 | return reg * reg_coef 52 | -------------------------------------------------------------------------------- /Non_stationary_env/panda_env.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import panda_gym 3 | from gym.wrappers import TimeLimit 4 | import numpy as np 5 | 6 | class PandaWrapper(gymnasium.Wrapper): 7 | def __init__(self, env): 8 | gymnasium.Wrapper.__init__(self, env) 9 | def reset(self): 10 | state, info = self.env.reset() 11 | return state 12 | def step(self, action): 13 | next_state, reward, done, flag, info = self.env.step(action) 14 | return next_state, reward, done, info 15 | 16 | class PandaNoiseWrapper(gymnasium.Wrapper): 17 | def __init__(self, env): 18 | gymnasium.Wrapper.__init__(self, env) 19 | def reset(self): 20 | state, info = self.env.reset() 21 | return state 22 | 23 | def step(self, action): 24 | size = self.env.action_space.shape[0] 25 | # action = 0.02 * np.ones(size) + np.random.normal(0, np.sqrt(0.01), size=(size,)) + action 26 | next_state, reward, done, flag, info = self.env.step(action) 27 | return next_state, reward, done, info 28 | 29 | 30 | def panda_make_env(env_name, control_type="joints", reward_type="sparse", render_mode="rgb_array"): 31 | env = gymnasium.make(env_name, control_type=control_type, reward_type=reward_type, render_mode=render_mode) 32 | env = PandaWrapper(gymnasium.wrappers.FlattenObservation(env)) 33 | if 'Stack' in env_name: 34 | env = TimeLimit(env, max_episode_steps=50) 35 | else: 36 | env = TimeLimit(env, max_episode_steps=100) 37 | return env 38 | 39 | def panda_make_noise_env(env_name, control_type="joints", reward_type="sparse", render_mode="rgb_array"): 40 | env = gymnasium.make(env_name, control_type=control_type, reward_type=reward_type, render_mode=render_mode) 41 | env = PandaNoiseWrapper(gymnasium.wrappers.FlattenObservation(env)) 42 | if 'Stack' 
in env_name: 43 | env = TimeLimit(env, max_episode_steps=50) 44 | else: 45 | env = TimeLimit(env, max_episode_steps=100) 46 | return env -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | import argparse 4 | 5 | 6 | class Config(dict): 7 | def __init__(self, seq=None, **kwargs): 8 | if seq is None: 9 | seq = {} 10 | super(Config, self).__init__(seq, **kwargs) 11 | 12 | def __setattr__(self, key, value): 13 | self[key] = value 14 | 15 | def __getattr__(self, item): 16 | return self[item] 17 | 18 | def __str__(self): 19 | disc = [] 20 | for k in self: 21 | if k.startswith("_"): 22 | continue 23 | disc.append(f"{k}: {repr(self[k])},\n") 24 | return "".join(disc) 25 | 26 | def copy(self): 27 | return Config(self) 28 | 29 | def load_saved(self, path): 30 | if not os.path.isfile(path): 31 | raise FileNotFoundError(f"Error: file {path} not exists") 32 | lines = open(path, 'r').readlines() 33 | dic = {} 34 | for l in lines: 35 | key, value = l.strip().split(':', 1) 36 | if value == "": 37 | break 38 | key = key.strip() 39 | value = value.strip().rstrip(',') 40 | dic[key] = ast.literal_eval(value) 41 | self.update(dic) 42 | return self 43 | 44 | 45 | class ARGConfig(Config): 46 | def __init__(self, seq=None, **kwargs): 47 | seq = {} if seq is None else seq 48 | super(ARGConfig, self).__init__(seq, **kwargs) 49 | self._arg_dict = dict(seq, **kwargs) 50 | self._arg_help = dict() 51 | 52 | def add_arg(self, key, value, help_str=""): 53 | self._arg_dict[key] = value 54 | self._arg_help[key] = f"{help_str} (default: {value})" 55 | self[key] = value 56 | 57 | def parser(self, desc=""): 58 | # compiling arg-parser 59 | parser = argparse.ArgumentParser(description=desc) 60 | for k in self._arg_dict: 61 | arg_name = k.replace(' ', '_').replace('-', '_') 62 | help_msg = self._arg_help[k] if k in self._arg_help else "" 63 | parser.add_argument(f"--{arg_name}", type=str, 64 | default=self._arg_dict[k] if isinstance(self._arg_dict[k], str) else repr(self._arg_dict[k]), 65 | help=help_msg) 66 | 67 | pared_args = parser.parse_args().__dict__ 68 | 69 | for k in self._arg_dict: 70 | arg_name = k.replace(' ', '_').replace('-', '_') 71 | self[k] = self._value_from_string(pared_args[arg_name], type(self[k])) 72 | 73 | @staticmethod 74 | def _value_from_string(string: str, typeinst: type): 75 | if typeinst == str: 76 | return string 77 | elif typeinst == int: 78 | return int(string) 79 | elif typeinst == float: 80 | return float(string) 81 | elif typeinst == bool: 82 | return string.lower() == "true" 83 | elif typeinst == tuple or typeinst == list: 84 | return typeinst(ast.literal_eval(string)) 85 | else: 86 | raise TypeError(f"unknown type (str, tuple, list, int, float, bool), but get {typeinst}") 87 | -------------------------------------------------------------------------------- /model/model_back.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal 5 | 6 | LOG_SIG_MAX = 2 7 | LOG_SIG_MIN = -20 8 | epsilon = 1e-6 9 | 10 | # Initialize Policy weights 11 | def weights_init_(m): 12 | if isinstance(m, nn.Linear): 13 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 14 | torch.nn.init.constant_(m.bias, 0) 15 | 16 | 17 | class ValueNetwork(nn.Module): 18 | def __init__(self, num_inputs, hidden_dim): 19 | 
super(ValueNetwork, self).__init__() 20 | 21 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 22 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 23 | self.linear3 = nn.Linear(hidden_dim, 1) 24 | 25 | self.apply(weights_init_) 26 | 27 | def forward(self, state): 28 | x = F.relu(self.linear1(state)) 29 | x = F.relu(self.linear2(x)) 30 | x = self.linear3(x) 31 | return x 32 | 33 | 34 | class QNetwork(nn.Module): 35 | def __init__(self, num_inputs, num_actions, hidden_dim): 36 | super(QNetwork, self).__init__() 37 | 38 | # Q1 architecture 39 | self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) 40 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 41 | self.linear3 = nn.Linear(hidden_dim, 1) 42 | 43 | # Q2 architecture 44 | self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim) 45 | self.linear5 = nn.Linear(hidden_dim, hidden_dim) 46 | self.linear6 = nn.Linear(hidden_dim, 1) 47 | 48 | self.apply(weights_init_) 49 | 50 | def forward(self, state, action): 51 | xu = torch.cat([state, action], 1) 52 | 53 | x1 = F.relu(self.linear1(xu)) 54 | x1 = F.relu(self.linear2(x1)) 55 | x1 = self.linear3(x1) 56 | 57 | x2 = F.relu(self.linear4(xu)) 58 | x2 = F.relu(self.linear5(x2)) 59 | x2 = self.linear6(x2) 60 | 61 | return x1, x2 62 | 63 | 64 | class GaussianPolicy(nn.Module): 65 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 66 | super(GaussianPolicy, self).__init__() 67 | 68 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 69 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 70 | 71 | self.mean_linear = nn.Linear(hidden_dim, num_actions) 72 | self.log_std_linear = nn.Linear(hidden_dim, num_actions) 73 | 74 | self.apply(weights_init_) 75 | 76 | # action rescaling 77 | if action_space is None: 78 | self.action_scale = torch.tensor(1.) 79 | self.action_bias = torch.tensor(0.) 80 | else: 81 | self.action_scale = torch.FloatTensor( 82 | (action_space.high - action_space.low) / 2.) 83 | self.action_bias = torch.FloatTensor( 84 | (action_space.high + action_space.low) / 2.) 
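    # Action squashing used by sample() below: a raw Gaussian sample u ~ N(mean, std)
    # is mapped into the action box via a = tanh(u) * action_scale + action_bias, with
    # action_scale = (high - low) / 2 and action_bias = (high + low) / 2. The reported
    # log-probability applies the matching change-of-variables correction
    #     log pi(a|s) = sum_i [ log N(u_i; mean_i, std_i)
    #                           - log(action_scale_i * (1 - tanh(u_i)^2) + epsilon) ],
    # i.e. the standard squashed-Gaussian (SAC-style) policy parameterization.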
85 | 86 | def forward(self, state): 87 | x = F.relu(self.linear1(state)) 88 | x = F.relu(self.linear2(x)) 89 | mean = self.mean_linear(x) 90 | log_std = self.log_std_linear(x) 91 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 92 | return mean, log_std 93 | 94 | def sample(self, state): 95 | mean, log_std = self.forward(state) 96 | std = log_std.exp() 97 | normal = Normal(mean, std) 98 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 99 | y_t = torch.tanh(x_t) 100 | action = y_t * self.action_scale + self.action_bias 101 | log_prob = normal.log_prob(x_t) 102 | # Enforcing Action Bound 103 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) 104 | log_prob = log_prob.sum(1, keepdim=True) 105 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 106 | return action, log_prob, mean 107 | 108 | def to(self, device): 109 | self.action_scale = self.action_scale.to(device) 110 | self.action_bias = self.action_bias.to(device) 111 | return super(GaussianPolicy, self).to(device) 112 | 113 | 114 | class DeterministicPolicy(nn.Module): 115 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 116 | super(DeterministicPolicy, self).__init__() 117 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 118 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 119 | 120 | self.mean = nn.Linear(hidden_dim, num_actions) 121 | self.noise = torch.Tensor(num_actions) 122 | 123 | self.apply(weights_init_) 124 | 125 | # action rescaling 126 | if action_space is None: 127 | self.action_scale = 1. 128 | self.action_bias = 0. 129 | else: 130 | self.action_scale = torch.FloatTensor( 131 | (action_space.high - action_space.low) / 2.) 132 | self.action_bias = torch.FloatTensor( 133 | (action_space.high + action_space.low) / 2.) 
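    # sample() below perturbs the deterministic action from forward() with Gaussian
    # noise of std 0.1, clamped to [-0.25, 0.25] (clipped exploration noise in the
    # DDPG/TD3 style), and returns a constant zero log-probability so the
    # (action, log_prob, mean) interface matches GaussianPolicy.sample().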
134 | 135 | def forward(self, state): 136 | x = F.relu(self.linear1(state)) 137 | x = F.relu(self.linear2(x)) 138 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias 139 | return mean 140 | 141 | def sample(self, state): 142 | mean = self.forward(state) 143 | noise = self.noise.normal_(0., std=0.1) 144 | noise = noise.clamp(-0.25, 0.25) 145 | action = mean + noise 146 | return action, torch.tensor(0.), mean 147 | 148 | def to(self, device): 149 | self.action_scale = self.action_scale.to(device) 150 | self.action_bias = self.action_bias.to(device) 151 | self.noise = self.noise.to(device) 152 | return super(DeterministicPolicy, self).to(device) 153 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal, TransformedDistribution, Distribution, Categorical, Bernoulli 5 | from torch.distributions.transforms import AffineTransform, SigmoidTransform 6 | 7 | LOG_SIG_MAX = 2 8 | LOG_SIG_MIN = -20 9 | epsilon = 1e-6 10 | 11 | # Initialize Policy weights 12 | def weights_init_(m): 13 | if isinstance(m, nn.Linear): 14 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 15 | torch.nn.init.constant_(m.bias, 0) 16 | 17 | 18 | class ValueNetwork(nn.Module): 19 | def __init__(self, num_inputs, hidden_dim): 20 | super(ValueNetwork, self).__init__() 21 | 22 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 23 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 24 | self.LayerNorm = nn.LayerNorm(hidden_dim) 25 | self.linear3 = nn.Linear(hidden_dim, 1) 26 | 27 | self.apply(weights_init_) 28 | 29 | def forward(self, state): 30 | x = torch.tanh(self.LayerNorm((self.linear1(state)))) 31 | x = F.elu(self.linear2(x)) 32 | x = self.linear3(x) 33 | return x 34 | 35 | 36 | class QNetwork(nn.Module): 37 | def __init__(self, num_inputs, num_actions, hidden_dim): 38 | super(QNetwork, self).__init__() 39 | 40 | # Q1 architecture 41 | self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) 42 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 43 | self.LayerNorm1 = nn.LayerNorm(hidden_dim) 44 | self.linear3 = nn.Linear(hidden_dim, 1) 45 | 46 | # Q2 architecture 47 | self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim) 48 | self.linear5 = nn.Linear(hidden_dim, hidden_dim) 49 | self.LayerNorm2 = nn.LayerNorm(hidden_dim) 50 | self.linear6 = nn.Linear(hidden_dim, 1) 51 | 52 | self.apply(weights_init_) 53 | 54 | def forward(self, state, action): 55 | xu = torch.cat([state, action], 1) 56 | 57 | x1 = torch.tanh(self.LayerNorm1(self.linear1(xu))) 58 | x1 = F.elu(self.linear2(x1)) 59 | x1 = self.linear3(x1) 60 | 61 | x2 = torch.tanh(self.LayerNorm2(self.linear4(xu))) 62 | x2 = F.elu(self.linear5(x2)) 63 | x2 = self.linear6(x2) 64 | 65 | return x1, x2 66 | 67 | 68 | class GaussianPolicy(nn.Module): 69 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 70 | super(GaussianPolicy, self).__init__() 71 | 72 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 73 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 74 | self.LayerNorm = nn.LayerNorm(hidden_dim) 75 | 76 | self.mean_linear = nn.Linear(hidden_dim, num_actions) 77 | self.log_std_linear = nn.Linear(hidden_dim, num_actions) 78 | 79 | self.apply(weights_init_) 80 | 81 | # action rescaling 82 | if action_space is None: 83 | self.action_scale = torch.tensor(1.) 
84 | self.action_bias = torch.tensor(0.) 85 | else: 86 | self.action_scale = torch.FloatTensor( 87 | (action_space.high - action_space.low) / 2.) 88 | self.action_bias = torch.FloatTensor( 89 | (action_space.high + action_space.low) / 2.) 90 | 91 | def forward(self, state): 92 | x = torch.tanh(self.LayerNorm((self.linear1(state)))) 93 | x = F.elu(self.linear2(x)) 94 | mean = self.mean_linear(x) 95 | log_std = self.log_std_linear(x) 96 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 97 | return mean, log_std 98 | 99 | def sample(self, state): 100 | mean, log_std = self.forward(state) 101 | std = log_std.exp() 102 | normal = Normal(mean, std) 103 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 104 | y_t = torch.tanh(x_t) 105 | action = y_t * self.action_scale + self.action_bias 106 | log_prob = normal.log_prob(x_t) 107 | # Enforcing Action Bound 108 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) 109 | log_prob = log_prob.sum(1, keepdim=True) 110 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 111 | return action, log_prob, mean 112 | 113 | 114 | def to(self, device): 115 | self.action_scale = self.action_scale.to(device) 116 | self.action_bias = self.action_bias.to(device) 117 | return super(GaussianPolicy, self).to(device) 118 | 119 | 120 | class DeterministicPolicy(nn.Module): 121 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 122 | super(DeterministicPolicy, self).__init__() 123 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 124 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 125 | 126 | self.mean = nn.Linear(hidden_dim, num_actions) 127 | self.noise = torch.Tensor(num_actions) 128 | 129 | self.apply(weights_init_) 130 | 131 | # action rescaling 132 | if action_space is None: 133 | self.action_scale = 1. 134 | self.action_bias = 0. 135 | else: 136 | self.action_scale = torch.FloatTensor( 137 | (action_space.high - action_space.low) / 2.) 138 | self.action_bias = torch.FloatTensor( 139 | (action_space.high + action_space.low) / 2.) 
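    # Note: the ValueNetwork, QNetwork and GaussianPolicy defined above in this file
    # use a LayerNorm + tanh first layer followed by an ELU layer, whereas this
    # DeterministicPolicy keeps the plain ReLU trunk used throughout model_back.py.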
140 | 141 | def forward(self, state): 142 | x = F.relu(self.linear1(state)) 143 | x = F.relu(self.linear2(x)) 144 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias 145 | return mean 146 | 147 | def sample(self, state): 148 | mean = self.forward(state) 149 | noise = self.noise.normal_(0., std=0.1) 150 | noise = noise.clamp(-0.25, 0.25) 151 | action = mean + noise 152 | return action, torch.tensor(0.), mean 153 | 154 | def to(self, device): 155 | self.action_scale = self.action_scale.to(device) 156 | self.action_bias = self.action_bias.to(device) 157 | self.noise = self.noise.to(device) 158 | return super(DeterministicPolicy, self).to(device) 159 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.optim import Adam 6 | from torch import autograd 7 | from torch.optim.lr_scheduler import MultiStepLR 8 | 9 | class Discriminator_SAS(nn.Module): 10 | def __init__(self, state_dim, action_dim, args): 11 | super(Discriminator_SAS, self).__init__() 12 | 13 | self.device = torch.device("cuda:{}".format(str(args.device)) if args.cuda else "cpu") 14 | 15 | self.hidden_size = args.hidden_size 16 | 17 | self.trunk = nn.Sequential( 18 | nn.Linear(state_dim + action_dim + state_dim, self.hidden_size), nn.Tanh(), 19 | nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh(), 20 | nn.Linear(self.hidden_size, 1)).to(self.device) 21 | 22 | self.trunk.train() 23 | 24 | self.optimizer = Adam(self.trunk.parameters(), lr = args.lr) 25 | self.scedular = MultiStepLR(self.optimizer, milestones=[20000, 40000], gamma=0.3) 26 | 27 | 28 | def compute_grad_pen(self, 29 | expert_state, 30 | expert_action, 31 | expert_next_state, 32 | policy_state, 33 | policy_action, 34 | policy_next_state, 35 | lambda_=10): 36 | # alpha = torch.rand(expert_state.size(0), 1) 37 | expert_data = torch.cat([expert_state, expert_action, expert_next_state], dim=1) 38 | policy_data = torch.cat([policy_state, policy_action, policy_next_state], dim=1) 39 | 40 | alpha = torch.rand(expert_data.size(0), 1).to(expert_data.device) 41 | 42 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 43 | mixup_data.requires_grad = True 44 | 45 | disc = self.trunk(mixup_data) 46 | ones = torch.ones(disc.size()).to(disc.device) 47 | grad = autograd.grad( 48 | outputs=disc, 49 | inputs=mixup_data, 50 | grad_outputs=ones, 51 | create_graph=True, 52 | retain_graph=True, 53 | only_inputs=True)[0] 54 | 55 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 56 | return grad_pen 57 | 58 | def compute_grad_pen_back(self, 59 | expert_state, 60 | expert_action, 61 | expert_next_state, 62 | policy_state, 63 | policy_action, 64 | policy_next_state, 65 | lambda_=10): 66 | # alpha = torch.rand(expert_state.size(0), 1) 67 | expert_data = torch.cat([torch.unsqueeze(expert_state, dim=0), torch.unsqueeze(expert_action, dim=0), torch.unsqueeze(expert_next_state, dim=0)], dim=1) 68 | policy_data = torch.cat([torch.unsqueeze(policy_state, dim=0), torch.unsqueeze(policy_action, dim=0), torch.unsqueeze(policy_next_state, dim=0)], dim=1) 69 | 70 | alpha = torch.rand(expert_data.size(0), 1).to(expert_data.device) 71 | 72 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 73 | mixup_data.requires_grad = True 74 | 75 | disc = self.trunk(mixup_data) 76 | ones = 
torch.ones(disc.size()).to(disc.device) 77 | grad = autograd.grad( 78 | outputs=disc, 79 | inputs=mixup_data, 80 | grad_outputs=ones, 81 | create_graph=True, 82 | retain_graph=True, 83 | only_inputs=True)[0] 84 | 85 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 86 | return grad_pen 87 | 88 | 89 | def update(self, expert_buffer, replay_buffer, batch_size): 90 | self.train() 91 | 92 | expert_state_batch, expert_action_batch, _, expert_next_state_batch, _ = expert_buffer.sample(batch_size=batch_size) 93 | expert_state_batch = torch.FloatTensor(expert_state_batch).to(self.device) 94 | expert_next_state_batch = torch.FloatTensor(expert_next_state_batch).to(self.device) 95 | expert_action_batch = torch.FloatTensor(expert_action_batch).to(self.device) 96 | 97 | state_batch, action_batch, _, next_state_batch, _ = replay_buffer.sample(batch_size=batch_size) 98 | state_batch = torch.FloatTensor(state_batch).to(self.device) 99 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) 100 | action_batch = torch.FloatTensor(action_batch).to(self.device) 101 | 102 | policy_d = self.trunk(torch.cat([state_batch, action_batch, next_state_batch], dim=1)) 103 | expert_d = self.trunk(torch.cat([expert_state_batch, expert_action_batch, expert_next_state_batch], dim=1)) 104 | 105 | policy_loss = F.binary_cross_entropy_with_logits( 106 | policy_d, 107 | torch.zeros(policy_d.size()).to(self.device)) 108 | 109 | expert_loss = F.binary_cross_entropy_with_logits( 110 | expert_d, 111 | torch.ones(expert_d.size()).to(self.device)) 112 | 113 | gail_loss = expert_loss + policy_loss 114 | grad_pen = self.compute_grad_pen(expert_state_batch, expert_action_batch, expert_next_state_batch, 115 | state_batch, action_batch,next_state_batch) 116 | 117 | self.optimizer.zero_grad() 118 | (gail_loss + grad_pen).backward() 119 | self.optimizer.step() 120 | 121 | # self.scedular.step() 122 | 123 | loss = (gail_loss + grad_pen).item() 124 | 125 | return loss 126 | 127 | def update_back(self, expert_buffer, replay_buffer, batch_size): 128 | self.train() 129 | 130 | expert_state_batch, expert_action_batch, _, expert_next_state_batch, _ = expert_buffer.sample(batch_size=batch_size) 131 | expert_state_batch = torch.FloatTensor(expert_state_batch).to(self.device) 132 | expert_next_state_batch = torch.FloatTensor(expert_next_state_batch).to(self.device) 133 | expert_action_batch = torch.FloatTensor(expert_action_batch).to(self.device) 134 | 135 | state_batch, action_batch, _, next_state_batch, _ = replay_buffer.sample(batch_size=batch_size) 136 | state_batch = torch.FloatTensor(state_batch).to(self.device) 137 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) 138 | action_batch = torch.FloatTensor(action_batch).to(self.device) 139 | 140 | loss = 0 141 | n = 0 142 | 143 | for i in range(batch_size): 144 | policy_d = self.trunk(torch.cat([torch.unsqueeze(state_batch[i], dim=0), torch.unsqueeze(action_batch[i], dim=0), torch.unsqueeze(next_state_batch[i], dim=0)], dim = 1)) 145 | expert_d = self.trunk(torch.cat([torch.unsqueeze(expert_state_batch[i],dim=0), torch.unsqueeze(expert_action_batch[i], dim=0), torch.unsqueeze(expert_next_state_batch[i], dim=0)], dim = 1)) 146 | 147 | policy_loss = F.binary_cross_entropy_with_logits( 148 | policy_d, 149 | torch.zeros(policy_d.size()).to(self.device)) 150 | 151 | expert_loss = F.binary_cross_entropy_with_logits( 152 | expert_d, 153 | torch.ones(expert_d.size()).to(self.device)) 154 | 155 | gail_loss = expert_loss + policy_loss 156 | grad_pen = 
self.compute_grad_pen(expert_state_batch[i], expert_action_batch[i], expert_next_state_batch[i], 157 | state_batch[i], action_batch[i],next_state_batch[i]) 158 | 159 | loss += (gail_loss + grad_pen).item() 160 | n += 1 161 | 162 | self.optimizer.zero_grad() 163 | (gail_loss + grad_pen).backward() 164 | self.optimizer.step() 165 | 166 | return loss / n 167 | 168 | def predict_reward(self, state, action, next_state): 169 | with torch.no_grad(): 170 | self.eval() 171 | d = self.trunk(torch.cat([state, action, next_state], dim=1)) 172 | s = torch.sigmoid(d) 173 | reward = s.log() - (1 - s).log() 174 | return reward -------------------------------------------------------------------------------- /main_stationary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gym 4 | import numpy as np 5 | from utils.config import ARGConfig 6 | from utils.default_config import default_config 7 | from model.algo import OMPO 8 | from model.discriminator import Discriminator_SAS 9 | from utils.Replaybuffer import ReplayMemory 10 | import datetime 11 | import itertools 12 | from copy import copy 13 | from torch.utils.tensorboard import SummaryWriter 14 | import shutil 15 | 16 | 17 | def train_loop(config, msg = "default"): 18 | # set seed 19 | env = gym.make(config.env_name) 20 | env.seed(config.seed) 21 | env.action_space.seed(config.seed) 22 | 23 | torch.manual_seed(config.seed) 24 | np.random.seed(config.seed) 25 | 26 | discriminator = Discriminator_SAS(env.observation_space.shape[0], env.action_space.shape[0], config) 27 | 28 | agent = OMPO(env.observation_space.shape[0], env.action_space, config) 29 | 30 | result_path = './results_mujoco_utd/{}/{}/{}/{}_{}_{}_{}'.format(config.env_name, config.algo, msg, 31 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 32 | 'OMPO', config.seed, 33 | "autotune" if config.automatic_entropy_tuning else "") 34 | 35 | checkpoint_path = result_path + '/' + 'checkpoint' 36 | 37 | # training logs 38 | if not os.path.exists(result_path): 39 | os.makedirs(result_path) 40 | if not os.path.exists(checkpoint_path): 41 | os.makedirs(checkpoint_path) 42 | with open(os.path.join(result_path, "config.log"), 'w') as f: 43 | f.write(str(config)) 44 | 45 | writer = SummaryWriter(result_path) 46 | current_path = os.path.dirname(os.path.abspath(__file__)) 47 | files = os.listdir(current_path) 48 | files_to_save = ['main.py', 'envs', 'model', 'utilis'] 49 | ignore_files = [x for x in files if x not in files_to_save] 50 | shutil.copytree('.', result_path + '/code', ignore=shutil.ignore_patterns(*ignore_files)) 51 | 52 | # all memory 53 | memory = ReplayMemory(config.replay_size, config.seed) 54 | initial_state_memory = ReplayMemory(config.replay_size, config.seed) 55 | # expert memory 56 | local_memory = ReplayMemory(config.local_replay_size, config.seed) 57 | # sample from all memory for training 58 | temp_memory = ReplayMemory(config.local_replay_size, config.seed) 59 | 60 | for _ in range(config.batch_size): 61 | state = env.reset() 62 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state 63 | 64 | # Training Loop 65 | total_numsteps = 0 66 | updates_discriminator = 0 67 | updates_agent = 0 68 | best_reward = -1e6 69 | for i_episode in itertools.count(1): 70 | episode_reward = 0 71 | episode_steps = 0 72 | done = False 73 | state = env.reset() 74 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state 75 | 76 | while not done: 77 | if 
config.start_steps > total_numsteps: 78 | action = env.action_space.sample() # Sample random action 79 | else: 80 | action = agent.select_action(state) # Sample action from policy 81 | 82 | if config.start_steps <= total_numsteps: 83 | # use the same history buffer to update the discriminator 84 | if local_memory.__len__() == config.local_replay_size: 85 | for _ in range(10): # default: 10 86 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=config.local_replay_size) 87 | for local_buffer_idx in range(config.local_replay_size): 88 | temp_memory.push(state_batch[local_buffer_idx], action_batch[local_buffer_idx], reward_batch[local_buffer_idx], next_state_batch[local_buffer_idx], mask_batch[local_buffer_idx]) 89 | for _ in range(20): # default: 20 90 | discriminator_loss = discriminator.update(local_memory, temp_memory, config.gail_batch) 91 | writer.add_scalar('loss/discriminator_loss', discriminator_loss, updates_discriminator) 92 | updates_discriminator += 1 93 | 94 | temp_memory = ReplayMemory(config.local_replay_size, config.seed) 95 | # reset the local memory 96 | local_memory = ReplayMemory(config.local_replay_size, config.seed) 97 | 98 | # train the agent 99 | for _ in range(config.updates_per_step): 100 | # Update parameters of all the networks 101 | critic_loss, policy_loss, ent_loss, alpha = agent.update_parameters(initial_state_memory, memory, discriminator, config.batch_size, updates_agent, writer) 102 | 103 | writer.add_scalar('loss/critic', critic_loss, updates_agent) 104 | writer.add_scalar('loss/policy', policy_loss, updates_agent) 105 | writer.add_scalar('loss/entropy_loss', ent_loss, updates_agent) 106 | writer.add_scalar('entropy_temprature/alpha', alpha, updates_agent) 107 | updates_agent += 1 108 | 109 | next_state, reward, done, _ = env.step(action) # Step 110 | episode_steps += 1 111 | total_numsteps += 1 112 | episode_reward += reward 113 | 114 | # Ignore the "done" signal if it comes from hitting the time horizon. 115 | # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py) 116 | mask = 1 if episode_steps == env._max_episode_steps else float(not done) 117 | 118 | memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to global memory 119 | local_memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to local memory 120 | 121 | state = next_state 122 | 123 | if total_numsteps > config.num_steps: 124 | break 125 | 126 | writer.add_scalar('train/reward', episode_reward, total_numsteps) 127 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2))) 128 | 129 | # test agent 130 | if i_episode % config.eval_episodes == 0 and config.eval is True: 131 | avg_reward = 0. 132 | # avg_success = 0. 
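            # Evaluation: config.eval_episodes doubles as both the evaluation interval
            # (in training episodes) and the number of evaluation rollouts. Actions are
            # selected with evaluate=True (presumably the deterministic policy mean),
            # returns are averaged, and the 'best' checkpoint is overwritten whenever
            # that average improves on the best value seen so far.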
133 | for _ in range(config.eval_episodes): 134 | state = env.reset() 135 | episode_reward = 0 136 | done = False 137 | while not done: 138 | action = agent.select_action(state, evaluate=True) 139 | 140 | next_state, reward, done, info = env.step(action) 141 | episode_reward += reward 142 | 143 | state = next_state 144 | avg_reward += episode_reward 145 | # avg_success += float(info['is_success']) 146 | avg_reward /= config.eval_episodes 147 | # avg_success /= config.eval_episodes 148 | if avg_reward >= best_reward and config.save is True: 149 | best_reward = avg_reward 150 | agent.save_checkpoint(checkpoint_path, 'best') 151 | 152 | writer.add_scalar('test/avg_reward', avg_reward, total_numsteps) 153 | # writer.add_scalar('test/avg_success', avg_success, total_numsteps) 154 | 155 | print("----------------------------------------") 156 | print("Env: {}, Test Episodes: {}, Avg. Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2))) 157 | print("----------------------------------------") 158 | 159 | env.close() 160 | 161 | # python main.py --device 2 162 | 163 | if __name__ == "__main__": 164 | arg = ARGConfig() 165 | arg.add_arg("env_name", "Humanoid-v3", "Environment name") 166 | arg.add_arg("device", 0, "Computing device") 167 | arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)") 168 | arg.add_arg("tag", "default", "Experiment tag") 169 | arg.add_arg("start_steps", 5000, "Number of start steps") 170 | arg.add_arg("automatic_entropy_tuning", True, "Automaically adjust α (default: True)") 171 | arg.add_arg("seed", np.random.randint(0, 1000), "experiment seed") 172 | arg.parser() 173 | 174 | config = default_config 175 | config.update(arg) 176 | 177 | print(f">>>> Training OMPO on {config.env_name} environment, on {config.device}") 178 | train_loop(config, msg=config.tag) -------------------------------------------------------------------------------- /main_transfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gym 4 | import numpy as np 5 | from utils.config import ARGConfig 6 | from utils.default_config import default_config 7 | from model.algo import OMPO 8 | from model.discriminator import Discriminator_SAS 9 | from utils.Replaybuffer import ReplayMemory 10 | import datetime 11 | import itertools 12 | from copy import copy 13 | from torch.utils.tensorboard import SummaryWriter 14 | import shutil 15 | import Non_stationary_env 16 | 17 | def train_loop(config, msg = "default"): 18 | # set seed 19 | sim_env = gym.make(config.env_name) 20 | sim_env.seed(config.seed) 21 | sim_env.action_space.seed(config.seed) 22 | 23 | if "Hopper" in config.env_name: 24 | real_env = gym.make(config.env_name, torso_len = 0.4, foot_len = 0.39) 25 | real_env.seed(config.seed) 26 | real_env.action_space.seed(config.seed) 27 | elif "Walker" in config.env_name: 28 | real_env = gym.make(config.env_name, torso_len = 0.4, foot_len = 0.2) 29 | real_env.seed(config.seed) 30 | real_env.action_space.seed(config.seed) 31 | elif "Ant" in config.env_name: 32 | real_env = gym.make(config.env_name, gravity = 19.62, wind = 1) 33 | real_env.seed(config.seed) 34 | real_env.action_space.seed(config.seed) 35 | elif "Humanoid" in config.env_name: 36 | real_env = gym.make(config.env_name, gravity = 19.62, wind = 1) 37 | real_env.seed(config.seed) 38 | real_env.action_space.seed(config.seed) 39 | 40 | torch.manual_seed(config.seed) 41 | np.random.seed(config.seed) 42 | 43 | 
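    # Transfer setting: sim_env keeps the nominal dynamics, while real_env uses shifted
    # dynamics (roughly doubled torso/foot lengths for Hopper and Walker2d, doubled
    # gravity plus a constant wind along x for Ant and Humanoid). The loop below
    # interacts with real_env, additionally collects 10 sim_env episodes per real
    # episode into the shared replay buffer, and evaluates on real_env only.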
discriminator = Discriminator_SAS(sim_env.observation_space.shape[0], sim_env.action_space.shape[0], config) 44 | 45 | agent = OMPO(sim_env.observation_space.shape[0], sim_env.action_space, config) 46 | 47 | result_path = './results/{}/{}/{}/{}_{}_{}_{}'.format(config.env_name, config.algo, msg, 48 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 49 | 'OMPO', config.seed, 50 | "autotune" if config.automatic_entropy_tuning else "") 51 | 52 | checkpoint_path = result_path + '/' + 'checkpoint' 53 | 54 | # training logs 55 | if not os.path.exists(result_path): 56 | os.makedirs(result_path) 57 | if not os.path.exists(checkpoint_path): 58 | os.makedirs(checkpoint_path) 59 | with open(os.path.join(result_path, "config.log"), 'w') as f: 60 | f.write(str(config)) 61 | 62 | writer = SummaryWriter(result_path) 63 | shutil.copytree('.', result_path + '/code', ignore=shutil.ignore_patterns('results', 'results_stationary_test_1', 'results_stationary_test_2')) 64 | 65 | # all memory 66 | memory = ReplayMemory(config.replay_size, config.seed) 67 | initial_state_memory = ReplayMemory(config.replay_size, config.seed) 68 | # expert memory 69 | local_memory = ReplayMemory(config.local_replay_size, config.seed) 70 | # sample from all memory for training 71 | temp_memory = ReplayMemory(config.local_replay_size, config.seed) 72 | 73 | for _ in range(config.batch_size): 74 | state = real_env.reset() 75 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state 76 | 77 | # Training Loop 78 | total_numsteps = 0 79 | updates_discriminator = 0 80 | updates_agent = 0 81 | best_reward = -1e6 82 | for i_episode in itertools.count(1): 83 | episode_reward = 0 84 | episode_steps = 0 85 | done = False 86 | state = real_env.reset() 87 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state 88 | # sample in the real_env 89 | while not done: 90 | if config.start_steps > total_numsteps: 91 | action = real_env.action_space.sample() # Sample random action 92 | else: 93 | action = agent.select_action(state) # Sample action from policy 94 | 95 | if config.start_steps <= total_numsteps: 96 | # use the same history buffer to update the discriminator 97 | if local_memory.__len__() == config.local_replay_size: 98 | for _ in range(10): # default: 10 99 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=config.local_replay_size) 100 | for local_buffer_idx in range(config.local_replay_size): 101 | temp_memory.push(state_batch[local_buffer_idx], action_batch[local_buffer_idx], reward_batch[local_buffer_idx], next_state_batch[local_buffer_idx], mask_batch[local_buffer_idx]) 102 | for _ in range(20): # default: 20 103 | discriminator_loss = discriminator.update(local_memory, temp_memory, config.gail_batch) 104 | writer.add_scalar('loss/discriminator_loss', discriminator_loss, updates_discriminator) 105 | updates_discriminator += 1 106 | 107 | temp_memory = ReplayMemory(config.local_replay_size, config.seed) 108 | # reset the local memory 109 | local_memory = ReplayMemory(config.local_replay_size, config.seed) 110 | 111 | # train the agent 112 | for _ in range(config.updates_per_step*10): 113 | # Update parameters of all the networks 114 | critic_loss, policy_loss, ent_loss, alpha = agent.update_parameters(initial_state_memory, memory, discriminator, config.batch_size, updates_agent) 115 | 116 | writer.add_scalar('loss/critic', critic_loss, updates_agent) 117 | writer.add_scalar('loss/policy', policy_loss, 
updates_agent) 118 | writer.add_scalar('loss/entropy_loss', ent_loss, updates_agent) 119 | writer.add_scalar('entropy_temprature/alpha', alpha, updates_agent) 120 | updates_agent += 1 121 | 122 | next_state, reward, done, _ = real_env.step(action) # Step 123 | episode_steps += 1 124 | total_numsteps += 1 125 | episode_reward += reward 126 | 127 | # Ignore the "done" signal if it comes from hitting the time horizon. 128 | # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py) 129 | mask = 1 if episode_steps == real_env._max_episode_steps else float(not done) 130 | 131 | memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to global memory 132 | local_memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to local memory 133 | 134 | state = next_state 135 | 136 | 137 | # sample in the sim_env 138 | for _ in range(10): 139 | state = sim_env.reset() 140 | done = False 141 | while not done: 142 | if config.start_steps > total_numsteps: 143 | action = sim_env.action_space.sample() # Sample random action 144 | else: 145 | action = agent.select_action(state) 146 | next_state, reward, done, _ = sim_env.step(action) 147 | mask = 1 if episode_steps == real_env._max_episode_steps else float(not done) 148 | memory.push(state, action, reward + config.reward_max, next_state, mask) 149 | state = next_state 150 | 151 | 152 | if total_numsteps > config.num_steps: 153 | break 154 | 155 | writer.add_scalar('train/reward', episode_reward, total_numsteps) 156 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2))) 157 | 158 | # test agent 159 | if i_episode % config.eval_episodes == 0 and config.eval is True: 160 | avg_reward = 0. 161 | # avg_success = 0. 162 | for _ in range(config.eval_episodes): 163 | state = real_env.reset() 164 | episode_reward = 0 165 | done = False 166 | while not done: 167 | action = agent.select_action(state, evaluate=True) 168 | 169 | next_state, reward, done, info = real_env.step(action) 170 | episode_reward += reward 171 | 172 | state = next_state 173 | avg_reward += episode_reward 174 | # avg_success += float(info['is_success']) 175 | avg_reward /= config.eval_episodes 176 | # avg_success /= config.eval_episodes 177 | if avg_reward >= best_reward and config.save is True: 178 | best_reward = avg_reward 179 | agent.save_checkpoint(checkpoint_path, 'best') 180 | 181 | writer.add_scalar('test/avg_reward', avg_reward, total_numsteps) 182 | # writer.add_scalar('test/avg_success', avg_success, total_numsteps) 183 | 184 | print("----------------------------------------") 185 | print("Env: {}, Test Episodes: {}, Avg. 
Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2))) 186 | print("----------------------------------------") 187 | 188 | # env.close() 189 | 190 | # python main.py --device 2 191 | 192 | if __name__ == "__main__": 193 | arg = ARGConfig() 194 | arg.add_arg("env_name", "AntTransfer-v0", "Environment name") 195 | arg.add_arg("device", 0, "Computing device") 196 | arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)") 197 | arg.add_arg("tag", "default", "Experiment tag") 198 | arg.add_arg("start_steps", 1000, "Number of start steps") 199 | arg.add_arg("automatic_entropy_tuning", True, "Automaically adjust α (default: True)") 200 | arg.add_arg("seed", np.random.randint(0, 1000), "experiment seed") 201 | arg.parser() 202 | 203 | config = default_config 204 | config.update(arg) 205 | 206 | print(f">>>> Training OLBO on {config.env_name} environment, on {config.device}") 207 | train_loop(config, msg=config.tag) -------------------------------------------------------------------------------- /main_non_stationary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gym 4 | import numpy as np 5 | from utils.config import ARGConfig 6 | from utils.default_config import default_config 7 | from model.algo import OMPO 8 | from model.discriminator import Discriminator_SAS 9 | from utils.Replaybuffer import ReplayMemory 10 | import datetime 11 | import itertools 12 | from copy import copy 13 | from torch.utils.tensorboard import SummaryWriter 14 | import shutil 15 | import Non_stationary_env 16 | 17 | 18 | 19 | 20 | def train_loop(config, msg = "default"): 21 | # set seed 22 | env = gym.make(config.env_name) 23 | env.seed(config.seed) 24 | env.action_space.seed(config.seed) 25 | 26 | torch.manual_seed(config.seed) 27 | np.random.seed(config.seed) 28 | 29 | discriminator = Discriminator_SAS(env.observation_space.shape[0], env.action_space.shape[0], config) 30 | 31 | agent = OMPO(env.observation_space.shape[0], env.action_space, config) 32 | 33 | result_path = './results/{}/{}/{}/{}_{}_{}_{}'.format(config.env_name, config.algo, msg, 34 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 35 | 'OMPO', config.seed, 36 | "autotune" if config.automatic_entropy_tuning else "") 37 | 38 | checkpoint_path = result_path + '/' + 'checkpoint' 39 | 40 | # training logs 41 | if not os.path.exists(result_path): 42 | os.makedirs(result_path) 43 | if not os.path.exists(checkpoint_path): 44 | os.makedirs(checkpoint_path) 45 | with open(os.path.join(result_path, "config.log"), 'w') as f: 46 | f.write(str(config)) 47 | writer = SummaryWriter(result_path) 48 | shutil.copytree('.', result_path + '/code', ignore=shutil.ignore_patterns('results', 'results_stationary_test_1', 'results_stationary_test_2')) 49 | 50 | # all memory 51 | memory = ReplayMemory(config.replay_size, config.seed) 52 | initial_state_memory = ReplayMemory(config.replay_size, config.seed) 53 | # expert memory 54 | local_memory = ReplayMemory(config.local_replay_size, config.seed) 55 | # sample from all memory for training 56 | temp_memory = ReplayMemory(config.local_replay_size, config.seed) 57 | 58 | for _ in range(config.batch_size): 59 | if "Hopper" in config.env_name: 60 | torso_len = np.random.uniform(0.3, 0.5) 61 | foot_len = np.random.uniform(0.29, 0.49) 62 | state = env.reset() 63 | elif "Walker" in config.env_name: 64 | torso_len = np.random.uniform(0.1, 0.3) 65 | foot_len = np.random.uniform(0.05, 0.15) 66 | 
state = env.reset(torso_len = torso_len, foot_len = foot_len) 67 | elif "Ant" in config.env_name: 68 | gravity = np.random.uniform(9.81, 19.82) 69 | wind = np.random.uniform(0.8, 1.2) 70 | state = env.reset(gravity = gravity, wind = wind) 71 | elif "Humanoid" in config.env_name: 72 | gravity = np.random.uniform(9.81, 19.82) 73 | wind = np.random.uniform(0.5, 1.5) 74 | state = env.reset(gravity = gravity, wind = wind) 75 | else: 76 | state = env.reset() 77 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state 78 | 79 | # Training Loop 80 | total_numsteps = 0 81 | updates_discriminator = 0 82 | updates_agent = 0 83 | best_reward = -1e6 84 | for i_episode in itertools.count(1): 85 | if "Hopper" in config.env_name: 86 | torso_len = 0.4 + 0.1 * np.sin(0.2 * i_episode) 87 | foot_len = 0.39 + 0.1 * np.sin(0.2 * i_episode) 88 | state = env.reset(torso_len = torso_len, foot_len = foot_len) 89 | elif "Walker" in config.env_name: 90 | torso_len = 0.2 + 0.1 * np.sin(0.3 * i_episode) 91 | foot_len = 0.1 + 0.05 * np.sin(0.3 * i_episode) 92 | state = env.reset(torso_len = torso_len, foot_len = foot_len) 93 | elif "Ant" in config.env_name: 94 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode) 95 | wind = 1. + 0.2 * np.sin(0.5 * i_episode) 96 | state = env.reset(gravity = gravity, wind = wind) 97 | elif "Humanoid" in config.env_name: 98 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode) 99 | wind = 1. + 0.5 * np.sin(0.5 * i_episode) 100 | state = env.reset(gravity = gravity, wind = wind) 101 | else: 102 | state = env.reset() 103 | episode_reward = 0 104 | episode_steps = 0 105 | done = False 106 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state 107 | 108 | while not done: 109 | if config.start_steps > total_numsteps: 110 | action = env.action_space.sample() # Sample random action 111 | else: 112 | action = agent.select_action(state) # Sample action from policy 113 | 114 | if config.start_steps <= total_numsteps: 115 | # use the same history buffer to update the discriminator 116 | if local_memory.__len__() == config.local_replay_size: 117 | for _ in range(10): # default: 10 118 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=config.local_replay_size) 119 | for local_buffer_idx in range(config.local_replay_size): 120 | temp_memory.push(state_batch[local_buffer_idx], action_batch[local_buffer_idx], reward_batch[local_buffer_idx], next_state_batch[local_buffer_idx], mask_batch[local_buffer_idx]) 121 | for _ in range(20): # default: 20 122 | discriminator_loss = discriminator.update(local_memory, temp_memory, config.gail_batch) 123 | writer.add_scalar('loss/discriminator_loss', discriminator_loss, updates_discriminator) 124 | updates_discriminator += 1 125 | 126 | temp_memory = ReplayMemory(config.local_replay_size, config.seed) 127 | # reset the local memory 128 | local_memory = ReplayMemory(config.local_replay_size, config.seed) 129 | 130 | # train the agent 131 | for _ in range(config.updates_per_step): 132 | # Update parameters of all the networks 133 | critic_loss, policy_loss, ent_loss, alpha = agent.update_parameters(initial_state_memory, memory, discriminator, config.batch_size, updates_agent, writer) 134 | 135 | writer.add_scalar('loss/critic', critic_loss, updates_agent) 136 | writer.add_scalar('loss/policy', policy_loss, updates_agent) 137 | writer.add_scalar('loss/entropy_loss', ent_loss, updates_agent) 138 | writer.add_scalar('entropy_temprature/alpha', alpha, 
updates_agent) 139 | updates_agent += 1 140 | 141 | next_state, reward, done, _ = env.step(action) # Step 142 | episode_steps += 1 143 | total_numsteps += 1 144 | episode_reward += reward 145 | 146 | # Ignore the "done" signal if it comes from hitting the time horizon. 147 | # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py) 148 | mask = 1 if episode_steps == env._max_episode_steps else float(not done) 149 | 150 | memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to global memory 151 | local_memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to local memory 152 | 153 | state = next_state 154 | 155 | if total_numsteps > config.num_steps: 156 | break 157 | 158 | writer.add_scalar('train/reward', episode_reward, total_numsteps) 159 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2))) 160 | 161 | # test agent 162 | if i_episode % config.eval_episodes == 0 and config.eval is True: 163 | avg_reward = 0. 164 | # avg_success = 0. 165 | for _ in range(config.eval_episodes): 166 | if "Hopper" in config.env_name: 167 | torso_len = 0.4 + 0.1 * np.sin(0.2 * i_episode) 168 | foot_len = 0.39 + 0.1 * np.sin(0.2 * i_episode) 169 | state = env.reset(torso_len = torso_len, foot_len = foot_len) # evaluate under the same scheduled morphology 170 | elif "Walker" in config.env_name: 171 | torso_len = 0.2 + 0.1 * np.sin(0.3 * i_episode) 172 | foot_len = 0.1 + 0.05 * np.sin(0.3 * i_episode) 173 | state = env.reset(torso_len = torso_len, foot_len = foot_len) 174 | elif "Ant" in config.env_name: 175 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode) 176 | wind = 1. + 0.2 * np.sin(0.5 * i_episode) 177 | state = env.reset(gravity = gravity, wind = wind) 178 | elif "Humanoid" in config.env_name: 179 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode) 180 | wind = 1. + 0.5 * np.sin(0.5 * i_episode) 181 | state = env.reset(gravity = gravity, wind = wind) 182 | else: 183 | state = env.reset() 184 | episode_reward = 0 185 | done = False 186 | while not done: 187 | action = agent.select_action(state, evaluate=True) 188 | 189 | next_state, reward, done, info = env.step(action) 190 | episode_reward += reward 191 | 192 | state = next_state 193 | avg_reward += episode_reward 194 | # avg_success += float(info['is_success']) 195 | avg_reward /= config.eval_episodes 196 | # avg_success /= config.eval_episodes 197 | if avg_reward >= best_reward and config.save is True: 198 | best_reward = avg_reward 199 | agent.save_checkpoint(checkpoint_path, 'best') 200 | 201 | writer.add_scalar('test/avg_reward', avg_reward, total_numsteps) 202 | # writer.add_scalar('test/avg_success', avg_success, total_numsteps) 203 | 204 | print("----------------------------------------") 205 | print("Env: {}, Test Episodes: {}, Avg.
Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2))) 206 | print("----------------------------------------") 207 | 208 | env.close() 209 | 210 | # python main.py --device 2 211 | 212 | if __name__ == "__main__": 213 | arg = ARGConfig() 214 | arg.add_arg("env_name", "AntRandom-v0", "Environment name") 215 | arg.add_arg("device", 0, "Computing device") 216 | arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)") 217 | arg.add_arg("tag", "default", "Experiment tag") 218 | arg.add_arg("start_steps", 5000, "Number of start steps") 219 | arg.add_arg("automatic_entropy_tuning", True, "Automaically adjust α (default: True)") 220 | arg.add_arg("seed", np.random.randint(0, 1000), "experiment seed") 221 | arg.parser() 222 | 223 | config = default_config 224 | config.update(arg) 225 | 226 | print(f">>>> Training OMPO on {config.env_name} environment, on {config.device}") 227 | train_loop(config, msg=config.tag) -------------------------------------------------------------------------------- /model/algo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.optim import Adam 5 | from torch.optim.lr_scheduler import MultiStepLR 6 | import copy 7 | from .utils import soft_update, hard_update, orthogonal_regularization 8 | from .model import GaussianPolicy, QNetwork, DeterministicPolicy 9 | 10 | epsilon = 1e-6 11 | 12 | class OMPO(object): 13 | def __init__(self, num_inputs, action_space, args): 14 | 15 | self.gamma = args.gamma 16 | self.tau = args.tau 17 | self.alpha = args.alpha 18 | 19 | self.policy_type = args.policy 20 | self.target_update_interval = args.target_update_interval 21 | self.automatic_entropy_tuning = args.automatic_entropy_tuning 22 | 23 | self.exponent = args.exponent 24 | 25 | self.actor_loss = 0 26 | self.alpha_loss = 0 27 | self.alpha_tlogs = 0 28 | 29 | if self.exponent <= 1: 30 | raise ValueError('Exponent must be greather than 1, but received %f.' % 31 | self.exponent) 32 | 33 | self.tomac_alpha = args.tomac_alpha 34 | 35 | self.f = lambda resid: torch.pow(torch.abs(resid), self.exponent) / self.exponent 36 | clip_resid = lambda resid: torch.clamp(resid, 0.0, 1e6) 37 | self.fgrad = lambda resid: torch.pow(clip_resid(resid), self.exponent - 1) 38 | 39 | self.device = torch.device("cuda:{}".format(str(args.device)) if args.cuda else "cpu") 40 | 41 | self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device) 42 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr) 43 | self.critic_optim_scheduler = MultiStepLR(self.critic_optim, milestones=[200000, 400000], gamma=0.2) 44 | 45 | self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device) 46 | hard_update(self.critic_target, self.critic) 47 | 48 | if self.policy_type == "Gaussian": 49 | # Target Entropy = −dim(A) (e.g. 
, -6 for HalfCheetah-v2) as given in the paper 50 | if self.automatic_entropy_tuning is True: 51 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() 52 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) # initial alpha = 1.0 53 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr) 54 | self.alpha_optim_scheduler = MultiStepLR(self.alpha_optim, milestones=[100000, 200000], gamma=0.2) 55 | 56 | self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) 57 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) 58 | self.policy_optim_scheduler = MultiStepLR(self.policy_optim, milestones=[100000, 200000], gamma=0.2) 59 | 60 | else: 61 | self.alpha = 0 62 | self.automatic_entropy_tuning = False 63 | self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) 64 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) 65 | 66 | def select_action(self, state, evaluate=False): 67 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0) 68 | if evaluate is False: 69 | action, _, _ = self.policy.sample(state) 70 | else: 71 | _, _, action = self.policy.sample(state) 72 | return action.detach().cpu().numpy()[0] 73 | 74 | def critic_mix(self, s, a): 75 | target_q1, target_q2 = self.critic_target(s, a) 76 | target_q = torch.min(target_q1, target_q2) 77 | q1, q2 = self.critic(s, a) 78 | return q1 * 0.05 + target_q * 0.95, q2 * 0.05 + target_q * 0.95 79 | 80 | def update_critic(self, discriminator, states, actions, next_states, rewards, masks, init_states, updates, writer): 81 | init_actions, _, _ = self.policy.sample(init_states) 82 | next_actions, next_log_probs, _ = self.policy.sample(next_states) 83 | 84 | # rewards = torch.clamp(rewards, 0, torch.inf) 85 | rewards = torch.clamp(rewards, epsilon, torch.inf) 86 | 87 | d_sas_rewards = discriminator.predict_reward(states, actions, next_states) 88 | 89 | writer.add_scalar('para/d_sas_reward', torch.mean(d_sas_rewards).item(), updates) 90 | 91 | # compute the reward 92 | # rewards = torch.log(rewards + epsilon * torch.ones(rewards.shape[0]).to(self.device)) - self.tomac_alpha * d_sas_rewards 93 | # rewards = torch.log(rewards) - self.tomac_alpha * d_sas_rewards 94 | rewards = rewards - self.tomac_alpha * d_sas_rewards 95 | 96 | # rewards -= self.tomac_alpha * d_sas_rewards 97 | 98 | with torch.no_grad(): 99 | target_q1, target_q2 = self.critic_mix(next_states, next_actions) 100 | 101 | target_q1 = target_q1 - self.alpha * next_log_probs 102 | target_q2 = target_q2 - self.alpha * next_log_probs 103 | 104 | target_q1 = rewards + self.gamma * masks * target_q1 105 | target_q2 = rewards + self.gamma * masks * target_q2 106 | 107 | q1, q2 = self.critic(states, actions) 108 | init_q1, init_q2 = self.critic(init_states, init_actions) 109 | 110 | critic_loss1 = torch.mean(self.f(target_q1 - q1) + (1 - self.gamma) * init_q1 * self.tomac_alpha) 111 | critic_loss2 = torch.mean(self.f(target_q2 - q2) + (1 - self.gamma) * init_q2 * self.tomac_alpha) 112 | 113 | critic_loss = (critic_loss1 + critic_loss2) 114 | 115 | self.critic_optim.zero_grad() 116 | critic_loss.backward() 117 | self.critic_optim.step() 118 | 119 | return critic_loss.item() 120 | 121 | def update_actor(self, discriminator, states, actions, next_states, rewards, masks, init_states): 122 | init_actions, _, _ = self.policy.sample(init_states) 123 | next_actions, next_log_probs, _ = 
self.policy.sample(next_states) 124 | 125 | # rewards = torch.clamp(rewards, 0, torch.inf) 126 | rewards = torch.clamp(rewards, epsilon, torch.inf) 127 | 128 | d_sas_rewards = discriminator.predict_reward(states, actions, next_states) 129 | 130 | # compute the reward 131 | # rewards = torch.log(rewards + epsilon * torch.ones(rewards.shape[0]).to(self.device)) - self.tomac_alpha * d_sas_rewards 132 | # rewards = torch.log(rewards) - self.tomac_alpha * d_sas_rewards 133 | rewards = rewards - self.tomac_alpha * d_sas_rewards 134 | 135 | # rewards -= self.tomac_alpha * d_sas_rewards 136 | 137 | target_q1, target_q2 = self.critic_mix(next_states, next_actions) 138 | 139 | target_q1 = target_q1 - self.alpha * next_log_probs 140 | target_q2 = target_q2 - self.alpha * next_log_probs 141 | 142 | target_q1 = rewards + self.gamma * masks * target_q1 143 | target_q2 = rewards + self.gamma * masks * target_q2 144 | 145 | q1, q2 = self.critic(states, actions) 146 | init_q1, init_q2 = self.critic(init_states, init_actions) 147 | 148 | actor_loss1 = -torch.mean(self.fgrad(target_q1 - q1).detach() * (target_q1 - q1) + (1 - self.gamma) * init_q1 * self.tomac_alpha) 149 | actor_loss2 = -torch.mean(self.fgrad(target_q2 - q2).detach() * (target_q2 - q2) + (1 - self.gamma) * init_q2 * self.tomac_alpha) 150 | 151 | actor_loss = (actor_loss1 + actor_loss2) / 2.0 152 | # actor_loss += orthogonal_regularization(self.policy) 153 | 154 | self.policy_optim.zero_grad() 155 | actor_loss.backward() 156 | self.policy_optim.step() 157 | 158 | _, log_probs, _ = self.policy.sample(states) 159 | 160 | if self.automatic_entropy_tuning: 161 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() 162 | 163 | self.alpha_optim.zero_grad() 164 | alpha_loss.backward() 165 | self.alpha_optim.step() 166 | 167 | # self.alpha_optim_scheduler.step() 168 | 169 | self.alpha = self.log_alpha.exp() 170 | alpha_tlogs = self.alpha.clone() # For TensorboardX logs 171 | else: 172 | alpha_loss = torch.tensor(0.).to(self.device) 173 | alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs 174 | 175 | return actor_loss.item(), alpha_loss.item(), alpha_tlogs.item() 176 | 177 | def update_parameters(self, initial_state_memory, memory, discriminator, batch_size, updates, writer): 178 | # Sample a batch from initial_state_memory 179 | initial_state_batch, _, _, _, _ = initial_state_memory.sample(batch_size=batch_size) 180 | initial_state_batch = torch.FloatTensor(initial_state_batch).to(self.device) 181 | 182 | # Sample a batch from memory 183 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size) 184 | 185 | state_batch = torch.FloatTensor(state_batch).to(self.device) 186 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) 187 | action_batch = torch.FloatTensor(action_batch).to(self.device) 188 | reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1) 189 | mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1) 190 | 191 | critic_loss = self.update_critic(discriminator, state_batch, action_batch, next_state_batch, reward_batch, mask_batch, initial_state_batch, updates, writer) 192 | # self.critic_optim_scheduler.step() 193 | 194 | if updates % self.target_update_interval == 0: 195 | self.actor_loss, self.alpha_loss, self.alpha_tlogs = self.update_actor(discriminator, state_batch, action_batch, next_state_batch, reward_batch, mask_batch, initial_state_batch) 196 | soft_update(self.critic_target, 
self.critic, self.tau) 197 | # self.policy_optim_scheduler.step() 198 | 199 | return critic_loss, self.actor_loss, self.alpha_loss, self.alpha_tlogs 200 | 201 | # Save model parameters 202 | def save_checkpoint(self, path, i_episode): 203 | ckpt_path = path + '/' + '{}.torch'.format(i_episode) 204 | print('Saving models to {}'.format(ckpt_path)) 205 | torch.save({'policy_state_dict': self.policy.state_dict(), 206 | 'critic_state_dict': self.critic.state_dict(), 207 | 'critic_target_state_dict': self.critic_target.state_dict(), 208 | 'critic_optimizer_state_dict': self.critic_optim.state_dict(), 209 | 'policy_optimizer_state_dict': self.policy_optim.state_dict()}, ckpt_path) 210 | 211 | # Load model parameters 212 | def load_checkpoint(self, path, i_episode, evaluate=False): 213 | ckpt_path = path + '/' + '{}.torch'.format(i_episode) 214 | print('Loading models from {}'.format(ckpt_path)) 215 | if ckpt_path is not None: 216 | checkpoint = torch.load(ckpt_path, map_location=self.device) # map storage to the agent's device so GPU-trained checkpoints also load on CPU 217 | self.policy.load_state_dict(checkpoint['policy_state_dict']) 218 | self.critic.load_state_dict(checkpoint['critic_state_dict']) 219 | self.critic_target.load_state_dict(checkpoint['critic_target_state_dict']) 220 | self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict']) 221 | self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict']) 222 | 223 | if evaluate: 224 | self.policy.eval() 225 | self.critic.eval() 226 | self.critic_target.eval() 227 | else: 228 | self.policy.train() 229 | self.critic.train() 230 | self.critic_target.train() 231 | 232 | --------------------------------------------------------------------------------
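
For completeness, below is a minimal evaluation sketch that is not part of the repository: it shows how the "best" checkpoint written by the training scripts could be reloaded and rolled out in one of the environments registered by Non_stationary_env. The checkpoint directory is a placeholder, and the two config overrides assume the Config object from utils.config accepts attribute-style assignment (its fields are read attribute-style throughout the scripts); adapt those lines to however your Config/ARGConfig setup expects overrides.

import gym
import Non_stationary_env  # registers AntRandom-v0, HopperRandom-v0, etc.
from utils.default_config import default_config
from model.algo import OMPO

config = default_config
config.env_name = "AntRandom-v0"   # assumption: Config supports attribute-style overrides
config.device = 0                  # OMPO reads args.device; set config.cuda = False to run on CPU

env = gym.make(config.env_name)
agent = OMPO(env.observation_space.shape[0], env.action_space, config)
# placeholder path: substitute the results/<env_name>/<algo>/<tag>/<run>/checkpoint folder of your run
agent.load_checkpoint("./results/.../checkpoint", "best", evaluate=True)

state, done, episode_reward = env.reset(), False, 0.0
while not done:  # the TimeLimit wrapper (max_episode_steps=1000) eventually sets done
    action = agent.select_action(state, evaluate=True)  # deterministic action from the Gaussian policy
    state, reward, done, _ = env.step(action)
    episode_reward += reward
print("Evaluation return:", round(episode_reward, 2))

The same pattern applies to the Transfer variants, whose gravity and wind are fixed in their constructors rather than re-sampled at reset.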