├── utils
│   ├── default_config.py
│   ├── Replaybuffer.py
│   └── config.py
├── Non_stationary_env
│   ├── __init__.py
│   ├── ant_random.py
│   ├── humanoid_random.py
│   ├── hopper_random.py
│   ├── walker2d_random.py
│   └── panda_env.py
├── README.md
├── model
│   ├── utils.py
│   ├── model_back.py
│   ├── model.py
│   ├── discriminator.py
│   └── algo.py
├── main_stationary.py
├── main_transfer.py
└── main_non_stationary.py
/utils/default_config.py:
--------------------------------------------------------------------------------
1 | from utils.config import Config
2 |
3 | default_config = Config({
4 | "seed": 0,
5 | "tag": "default",
6 | "start_steps": 5e3,
7 | "cuda": True,
8 | "num_steps": 300001,
9 | "save": True,
10 |
11 | "env_name": "HalfCheetah-v2",
12 | "eval": True,
13 | "eval_episodes": 10,
14 | "eval_times": 10,
15 | "replay_size": 1000000,
16 | "local_replay_size": 1000, # default: 1000
17 |
18 | "algo": "TOMAC",
19 | "policy": "Gaussian", # Policy type: Gaussian | Deterministic (default: Gaussian)
20 | "gamma": 0.99,
21 | "tau": 0.005,
22 | "lr": 0.0003,
23 | "alpha": 0.2,
24 | "automatic_entropy_tuning": True,
25 | "batch_size": 256,
26 | "updates_per_step": 3,
27 | "target_update_interval": 2,
28 | "hidden_size": 256,
29 | "gail_batch": 256,
30 |
31 | "exponent": 1.5, # default: 1.1
32 | "tomac_alpha": 0.001, # default: 0.001
33 | "reward_max": 1.
34 | })
35 |
--------------------------------------------------------------------------------
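For reference, a minimal sketch (not part of the repository) of how these defaults can be read and overridden before training; `Config` is a dict subclass that supports both attribute and key access, and `print(config)` emits the same `key: value,` lines the training scripts write to `config.log`.

```python
# Minimal sketch (not in the repo): reading and overriding the defaults above.
from utils.default_config import default_config

# Attribute and key access are interchangeable on Config.
assert default_config.gamma == default_config["gamma"] == 0.99

default_config.seed = 123          # attribute assignment writes into the dict
default_config["num_steps"] = 1e6  # plain item assignment works as well

print(default_config)              # "key: value," lines, as saved to config.log
```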
/utils/Replaybuffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | import pickle
5 | import random
6 |
7 | class ReplayMemory:
8 | def __init__(self, capacity, seed):
9 | random.seed(seed)
10 | self.capacity = capacity
11 | self.buffer = []
12 | self.position = 0
13 |
14 | def push(self, state, action, reward, next_state, done):
15 | if len(self.buffer) < self.capacity:
16 | self.buffer.append(None)
17 | self.buffer[self.position] = (state, action, reward, next_state, done)
18 | self.position = (self.position + 1) % self.capacity
19 |
20 | def sample(self, batch_size):
21 | batch = random.sample(self.buffer, batch_size)
22 | state, action, reward, next_state, done = map(np.stack, zip(*batch))
23 | return state, action, reward, next_state, done
24 |
25 | def __len__(self):
26 | return len(self.buffer)
27 |
28 | def save_buffer(self, path):
29 | print('Saving buffer to {}'.format(path))
30 |
31 | with open(path, 'wb') as f:
32 | pickle.dump(self.buffer, f)
33 |
34 | def load_buffer(self, save_path):
35 | print('Loading buffer from {}'.format(save_path))
36 |
37 | with open(save_path, "rb") as f:
38 | self.buffer = pickle.load(f)
39 | self.position = len(self.buffer) % self.capacity
--------------------------------------------------------------------------------
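A short usage sketch of `ReplayMemory` (not part of the repository; the 3-dimensional state and 1-dimensional action are placeholders): `push` stores raw transitions and `sample` collates each field with `np.stack`.

```python
# Usage sketch for ReplayMemory; state/action shapes below are illustrative only.
import numpy as np
from utils.Replaybuffer import ReplayMemory

memory = ReplayMemory(capacity=1000, seed=0)
for _ in range(100):
    state = np.random.randn(3)
    action = np.random.randn(1)
    next_state = np.random.randn(3)
    memory.push(state, action, reward=0.0, next_state=next_state, done=False)

states, actions, rewards, next_states, dones = memory.sample(batch_size=32)
print(states.shape, actions.shape, rewards.shape)  # (32, 3) (32, 1) (32,)
```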
/Non_stationary_env/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs import register
2 |
3 | register(
4 | id='HopperRandom-v0',
5 | entry_point='Non_stationary_env.hopper_random:HopperRandomEnv',
6 | max_episode_steps=1000,
7 | reward_threshold=3800.0,
8 | )
9 |
10 | register(
11 | id='HopperTransfer-v0',
12 | entry_point='Non_stationary_env.hopper_random:HopperTransferEnv',
13 | max_episode_steps=1000,
14 | reward_threshold=3800.0,
15 | )
16 |
17 | register(
18 | id='Walker2dRandom-v0',
19 | entry_point='Non_stationary_env.walker2d_random:Walker2dRandomEnv',
20 | max_episode_steps=1000
21 | )
22 |
23 | register(
24 | id='Walker2dTransfer-v0',
25 | entry_point='Non_stationary_env.walker2d_random:Walker2dTransferEnv',
26 | max_episode_steps=1000
27 | )
28 |
29 | register(
30 | id='AntTransfer-v0',
31 | entry_point='Non_stationary_env.ant_random:AntTransferEnv',
32 | max_episode_steps=1000,
33 | reward_threshold=6000.0,
34 | )
35 |
36 | register(
37 | id='AntRandom-v0',
38 | entry_point='Non_stationary_env.ant_random:AntRandomEnv',
39 | max_episode_steps=1000,
40 | reward_threshold=6000.0,
41 | )
42 |
43 | register(
44 | id='HumanoidTransfer-v0',
45 | entry_point='Non_stationary_env.humanoid_random:HumanoidTransferEnv',
46 | max_episode_steps=1000
47 | )
48 |
49 | register(
50 | id='HumanoidRandom-v0',
51 | entry_point='Non_stationary_env.humanoid_random:HumanoidRandomEnv',
52 | max_episode_steps=1000
53 | )
--------------------------------------------------------------------------------
/Non_stationary_env/ant_random.py:
--------------------------------------------------------------------------------
1 | import mujoco_py
2 | import gym
3 | import numpy as np
4 | import random
5 |
6 | from gym.envs.mujoco.ant_v3 import AntEnv
7 |
8 | class AntTransferEnv(AntEnv):
9 | def __init__(self, gravity = 9.81, wind = 0, **kwargs):
10 | super().__init__(**kwargs)
11 | # self.reset(gravity = gravity, wind = wind)
12 | self.model.opt.viscosity = 0.00002
13 | self.model.opt.density = 1.2
14 | self.model.opt.gravity[:] = np.array([0., 0., -gravity])
15 | self.model.opt.wind[:] = np.array([-wind, 0., 0.])
16 |
17 | class AntRandomEnv(AntEnv):
18 |
19 | def __init__(self, **kwargs):
20 | super().__init__(**kwargs)
21 | self.model.opt.viscosity = 0.00002
22 | self.model.opt.density = 1.2
23 |
24 |
25 | def step_with_random(self, action, gravity = 9.81, wind = 0):
26 | # print("Step with new gravity = ", gravity, " wind = ", wind)
27 |
28 | self.model.opt.gravity[:] = np.array([0., 0., -gravity])
29 | self.model.opt.wind[:] = np.array([-wind, 0., 0.])
30 |
31 | return self.step(action)
32 |
33 | def reset(self, gravity = 9.81, wind = 0):
34 | '''Must be called like env.reset(gravity=XXX, wind=XXX)'''
35 | # print("Reset with new gravity = ", gravity, " wind = ", wind)
36 |
37 | self.model.opt.gravity[:] = np.array([0., 0., -gravity])
38 | self.model.opt.wind[:] = np.array([-wind, 0., 0.])
39 |
40 | return super().reset()
41 |
--------------------------------------------------------------------------------
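A hedged example (not in the repository) of how the gravity/wind environments are meant to be used once `Non_stationary_env` has been imported and its `register` calls have run; it mirrors the `__main__` demo in `hopper_random.py`, with the dynamics re-randomised between short rollouts.

```python
# Sketch: re-randomising gravity/wind on AntRandom-v0 between short rollouts.
import random
import gym
import Non_stationary_env  # importing the package runs the register() calls

env = gym.make("AntRandom-v0")
env.reset(gravity=9.81, wind=0.0)

for _ in range(5):
    for _ in range(50):
        env.step(env.action_space.sample())
    # draw new dynamics, as the non-stationary training loop does per episode
    gravity = 9.81 + 9.81 * random.random()
    wind = random.random()
    env.reset(gravity=gravity, wind=wind)

env.close()
```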
/Non_stationary_env/humanoid_random.py:
--------------------------------------------------------------------------------
1 | import mujoco_py
2 | import gym
3 | import numpy as np
4 | import random
5 |
6 | from gym.envs.mujoco.humanoid_v3 import HumanoidEnv
7 |
8 | class HumanoidTransferEnv(HumanoidEnv):
9 | def __init__(self, gravity = 9.81, wind = 0, **kwargs):
10 | super().__init__(**kwargs)
11 | # self.reset(gravity = gravity, wind = wind)
12 | self.model.opt.viscosity = 0.00002
13 | self.model.opt.density = 1.2
14 | self.model.opt.gravity[:] = np.array([0., 0., -gravity])
15 | self.model.opt.wind[:] = np.array([-wind, 0., 0.])
16 |
17 | class HumanoidRandomEnv(HumanoidEnv):
18 |
19 | def __init__(self, **kwargs):
20 | super().__init__(**kwargs)
21 | self.model.opt.viscosity = 0.00002
22 | self.model.opt.density = 1.2
23 |
24 |
25 | def step_with_random(self, action, gravity = 9.81, wind = 0):
26 | # print("Step with new gravity = ", gravity, " wind = ", wind)
27 |
28 | self.model.opt.gravity[:] = np.array([0., 0., -gravity])
29 | self.model.opt.wind[:] = np.array([-wind, 0., 0.])
30 |
31 | return self.step(action)
32 |
33 | def reset(self, gravity = 9.81, wind = 0):
34 | '''Must be called like env.reset(gravity=XXX, wind=XXX)'''
35 | # print("Reset with new gravity = ", gravity, " wind = ", wind)
36 |
37 | self.model.opt.gravity[:] = np.array([0., 0., -gravity])
38 | self.model.opt.wind[:] = np.array([-wind, 0., 0.])
39 |
40 | return super().reset()
41 |
--------------------------------------------------------------------------------
/Non_stationary_env/hopper_random.py:
--------------------------------------------------------------------------------
1 | import mujoco_py
2 | import gym
3 | import numpy as np
4 | import random
5 |
6 | from gym.envs.mujoco.hopper_v3 import HopperEnv
7 |
8 | class HopperTransferEnv(HopperEnv):
9 | def __init__(self, torso_len: float = 0.2, foot_len: float = 0.195, **kwargs):
10 | super().__init__(**kwargs)
11 | self.model.body_pos[1][2] = 1.05 + torso_len
12 | self.model.body_pos[2][2] = -torso_len
13 | self.model.geom_size[1][1] = torso_len
14 |
15 | self.model.geom_size[4][1] = foot_len
16 |
17 |
18 |
19 | class HopperRandomEnv(HopperEnv):
20 | def __init__(self, **kwargs):
21 | super().__init__(**kwargs)
22 |
23 | def reset(self, torso_len: float = 0.2, foot_len: float = 0.195):
24 | '''Must be called like env.reset(torso_len=XXX, foot_len=XXX)'''
25 |
26 | self.model.body_pos[1][2] = 1.05 + torso_len
27 | self.model.body_pos[2][2] = -torso_len
28 | self.model.geom_size[1][1] = torso_len
29 |
30 | self.model.geom_size[4][1] = foot_len
31 |
32 | return super().reset()
33 |
34 | if __name__ == "__main__":
35 | env = gym.make("HopperRandom-v0", torso_len = 0.2, foot_len = 0.195)
36 | env.reset(torso_len = 0.2, foot_len = 0.195)
37 |
38 | for _ in range(10):
39 | for i in range(50):
40 | env.step(np.random.rand(3))
41 | env.render()
42 |
43 | l1 = 0.2 + 0.15 * random.random()
44 | l2 = 0.195 + 0.1 * random.random()
45 | env.reset(torso_len = l1, foot_len = l2)
46 |
47 | env.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OMPO
2 |
3 | Official implementation of
4 |
5 | `OMPO: A Unified Framework for RL under Policy and Dynamics Shifts` by
6 |
7 | Yu Luo, Tianying Ji, Fuchun Sun, Jianwei Zhang, Huazhe Xu and Xianyuan Zhan
8 |
9 | ## Getting started
10 |
11 | We provide examples of how to train and evaluate the **OMPO** agent.
12 |
13 | ### Training
14 |
15 | See the example below on how to train OMPO on a single task.
16 |
17 | ```bash
18 | python main_stationary.py --env_name YOUR_TASK
19 | ```
20 |
21 | We recommend using default hyperparameters. See `utils/default_config.py` for a full list of arguments.
22 |
23 | ## Citation
24 |
25 | If you find our work useful, please consider citing our paper as follows:
26 |
27 | ```
28 | @inproceedings{Luo2024ompo,
29 | title={OMPO: A Unified Framework for RL under Policy and Dynamics Shifts},
30 | author={Yu Luo and Tianying Ji and Fuchun Sun and Jianwei Zhang and Huazhe Xu and Xianyuan Zhan},
31 | booktitle={International Conference on Machine Learning},
32 | year={2024}
33 | }
34 | ```
35 |
36 | ----
37 |
38 | ## Contributing
39 |
40 | Please feel free to participate in our project by opening issues or sending pull requests for any enhancements or bug reports you might have. We’re striving to develop a codebase that’s easily expandable to different settings and tasks, and your feedback on how it’s working is greatly appreciated!
41 |
42 | ----
43 |
44 | ## License
45 |
46 | This project is licensed under the MIT License - see the `LICENSE` file for details. Note that the repository relies on third-party components, which are subject to their respective licenses.
47 |
--------------------------------------------------------------------------------
/Non_stationary_env/walker2d_random.py:
--------------------------------------------------------------------------------
1 | import mujoco_py
2 | import gym
3 | import numpy as np
4 | import random
5 |
6 | from gym.envs.mujoco.walker2d_v3 import Walker2dEnv
7 |
8 | class Walker2dTransferEnv(Walker2dEnv):
9 | def __init__(self, torso_len: float = 0.2, foot_len: float = 0.1, **kwargs):
10 | super().__init__(**kwargs)
11 | self.model.body_pos[1][2] = 1.05 + torso_len
12 | self.model.body_pos[2][2] = -torso_len
13 | self.model.body_pos[5][2] = -torso_len
14 | self.model.geom_size[1][1] = torso_len
15 |
16 | self.model.geom_size[4][1] = foot_len
17 | self.model.geom_size[7][1] = foot_len
18 |
19 |
20 | class Walker2dRandomEnv(Walker2dEnv):
21 | def __init__(self, **kwargs):
22 | super().__init__(**kwargs)
23 |
24 | def reset(self, torso_len: float = 0.2, foot_len: float = 0.1):
25 | '''Must be called like env.reset(torso_len=XXX, foot_len=XXX)'''
26 |
27 | self.model.body_pos[1][2] = 1.05 + torso_len
28 | self.model.body_pos[2][2] = -torso_len
29 | self.model.body_pos[5][2] = -torso_len
30 | self.model.geom_size[1][1] = torso_len
31 |
32 | self.model.geom_size[4][1] = foot_len
33 | self.model.geom_size[7][1] = foot_len
34 |
35 | return super().reset()
36 |
37 | if __name__ == "__main__":
38 | env = gym.make("Walker2dRandom-v0", torso_len = 0.2, foot_len = 0.1)
39 | env.reset(torso_len = 0.2, foot_len = 0.1)
40 |
41 | for _ in range(10):
42 | for i in range(50):
43 | env.step(np.random.rand(6))
44 | env.render()
45 |
46 | l1 = 0.2 + 0.15 * random.random()
47 | l2 = 0.195 + 0.1 * random.random()
48 | env.reset(torso_len = l1, foot_len = l2)
49 |
50 | env.close()
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | def create_log_gaussian(mean, log_std, t):
5 | quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
6 | l = mean.shape
7 | log_z = log_std
8 | z = l[-1] * math.log(2 * math.pi)
9 | log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
10 | return log_p
11 |
12 | def logsumexp(inputs, dim=None, keepdim=False):
13 | if dim is None:
14 | inputs = inputs.view(-1)
15 | dim = 0
16 | s, _ = torch.max(inputs, dim=dim, keepdim=True)
17 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
18 | if not keepdim:
19 | outputs = outputs.squeeze(dim)
20 | return outputs
21 |
22 | def soft_update(target, source, tau):
23 | for target_param, param in zip(target.parameters(), source.parameters()):
24 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
25 |
26 | def hard_update(target, source):
27 | for target_param, param in zip(target.parameters(), source.parameters()):
28 | target_param.data.copy_(param.data)
29 |
30 | def orthogonal_regularization(model, reg_coef=1e-4):
31 | """Orthogonal regularization v2.
32 |
33 | See equation (3) in https://arxiv.org/abs/1809.11096.
34 |
35 | Args:
36 | model: A PyTorch model to apply regularization for.
37 | reg_coef: Orthogonal regularization coefficient.
38 |
39 | Returns:
40 | A regularization loss term.
41 | """
42 |
43 | reg = 0.0
44 | for module in model.modules():
45 | if isinstance(module, torch.nn.Linear):
46 | weight = module.weight
47 | prod = torch.matmul(weight.t(), weight)
48 | eye_matrix = torch.eye(prod.shape[0], device=weight.device)
49 | reg += torch.sum(torch.square(prod * (1 - eye_matrix)))
50 |
51 | return reg * reg_coef
52 |
--------------------------------------------------------------------------------
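A brief sketch (assumption: any two identically structured networks) of how `hard_update` and `soft_update` are typically combined for target networks, presumably as used by `model/algo.py`.

```python
# Sketch: Polyak-averaged target network maintained with hard_update/soft_update.
import torch.nn as nn
from model.utils import hard_update, soft_update

critic = nn.Linear(8, 1)
critic_target = nn.Linear(8, 1)

hard_update(critic_target, critic)             # copy weights once at the start
# ... after every gradient step on `critic`:
soft_update(critic_target, critic, tau=0.005)  # target <- (1 - tau)*target + tau*online
```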
/Non_stationary_env/panda_env.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import panda_gym
3 | from gym.wrappers import TimeLimit
4 | import numpy as np
5 |
6 | class PandaWrapper(gymnasium.Wrapper):
7 | def __init__(self, env):
8 | gymnasium.Wrapper.__init__(self, env)
9 | def reset(self):
10 | state, info = self.env.reset()
11 | return state
12 | def step(self, action):
13 | next_state, reward, done, flag, info = self.env.step(action)
14 | return next_state, reward, done, info
15 |
16 | class PandaNoiseWrapper(gymnasium.Wrapper):
17 | def __init__(self, env):
18 | gymnasium.Wrapper.__init__(self, env)
19 | def reset(self):
20 | state, info = self.env.reset()
21 | return state
22 |
23 | def step(self, action):
24 | size = self.env.action_space.shape[0]
25 | # action = 0.02 * np.ones(size) + np.random.normal(0, np.sqrt(0.01), size=(size,)) + action
26 | next_state, reward, done, flag, info = self.env.step(action)
27 | return next_state, reward, done, info
28 |
29 |
30 | def panda_make_env(env_name, control_type="joints", reward_type="sparse", render_mode="rgb_array"):
31 | env = gymnasium.make(env_name, control_type=control_type, reward_type=reward_type, render_mode=render_mode)
32 | env = PandaWrapper(gymnasium.wrappers.FlattenObservation(env))
33 | if 'Stack' in env_name:
34 | env = TimeLimit(env, max_episode_steps=50)
35 | else:
36 | env = TimeLimit(env, max_episode_steps=100)
37 | return env
38 |
39 | def panda_make_noise_env(env_name, control_type="joints", reward_type="sparse", render_mode="rgb_array"):
40 | env = gymnasium.make(env_name, control_type=control_type, reward_type=reward_type, render_mode=render_mode)
41 | env = PandaNoiseWrapper(gymnasium.wrappers.FlattenObservation(env))
42 | if 'Stack' in env_name:
43 | env = TimeLimit(env, max_episode_steps=50)
44 | else:
45 | env = TimeLimit(env, max_episode_steps=100)
46 | return env
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | import argparse
4 |
5 |
6 | class Config(dict):
7 | def __init__(self, seq=None, **kwargs):
8 | if seq is None:
9 | seq = {}
10 | super(Config, self).__init__(seq, **kwargs)
11 |
12 | def __setattr__(self, key, value):
13 | self[key] = value
14 |
15 | def __getattr__(self, item):
16 | return self[item]
17 |
18 | def __str__(self):
19 | disc = []
20 | for k in self:
21 | if k.startswith("_"):
22 | continue
23 | disc.append(f"{k}: {repr(self[k])},\n")
24 | return "".join(disc)
25 |
26 | def copy(self):
27 | return Config(self)
28 |
29 | def load_saved(self, path):
30 | if not os.path.isfile(path):
31 | raise FileNotFoundError(f"Error: file {path} does not exist")
32 | lines = open(path, 'r').readlines()
33 | dic = {}
34 | for l in lines:
35 | key, value = l.strip().split(':', 1)
36 | if value == "":
37 | break
38 | key = key.strip()
39 | value = value.strip().rstrip(',')
40 | dic[key] = ast.literal_eval(value)
41 | self.update(dic)
42 | return self
43 |
44 |
45 | class ARGConfig(Config):
46 | def __init__(self, seq=None, **kwargs):
47 | seq = {} if seq is None else seq
48 | super(ARGConfig, self).__init__(seq, **kwargs)
49 | self._arg_dict = dict(seq, **kwargs)
50 | self._arg_help = dict()
51 |
52 | def add_arg(self, key, value, help_str=""):
53 | self._arg_dict[key] = value
54 | self._arg_help[key] = f"{help_str} (default: {value})"
55 | self[key] = value
56 |
57 | def parser(self, desc=""):
58 | # compiling arg-parser
59 | parser = argparse.ArgumentParser(description=desc)
60 | for k in self._arg_dict:
61 | arg_name = k.replace(' ', '_').replace('-', '_')
62 | help_msg = self._arg_help[k] if k in self._arg_help else ""
63 | parser.add_argument(f"--{arg_name}", type=str,
64 | default=self._arg_dict[k] if isinstance(self._arg_dict[k], str) else repr(self._arg_dict[k]),
65 | help=help_msg)
66 |
67 | parsed_args = parser.parse_args().__dict__
68 |
69 | for k in self._arg_dict:
70 | arg_name = k.replace(' ', '_').replace('-', '_')
71 | self[k] = self._value_from_string(parsed_args[arg_name], type(self[k]))
72 |
73 | @staticmethod
74 | def _value_from_string(string: str, typeinst: type):
75 | if typeinst == str:
76 | return string
77 | elif typeinst == int:
78 | return int(string)
79 | elif typeinst == float:
80 | return float(string)
81 | elif typeinst == bool:
82 | return string.lower() == "true"
83 | elif typeinst == tuple or typeinst == list:
84 | return typeinst(ast.literal_eval(string))
85 | else:
86 | raise TypeError(f"unknown type (str, tuple, list, int, float, bool), but get {typeinst}")
87 |
--------------------------------------------------------------------------------
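A small sketch (not in the repository) of the logging round-trip these classes support: `__str__` emits `key: value,` lines and `load_saved` parses them back with `ast.literal_eval`.

```python
# Sketch: write a Config the way the training scripts write config.log, then reload it.
from utils.config import Config

cfg = Config({"seed": 0, "lr": 3e-4, "env_name": "HalfCheetah-v2"})
with open("config.log", "w") as f:
    f.write(str(cfg))                     # "seed: 0,\nlr: 0.0003,\n..." lines

restored = Config().load_saved("config.log")
assert restored.lr == 3e-4
assert restored.env_name == "HalfCheetah-v2"
```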
/model/model_back.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal
5 |
6 | LOG_SIG_MAX = 2
7 | LOG_SIG_MIN = -20
8 | epsilon = 1e-6
9 |
10 | # Initialize Policy weights
11 | def weights_init_(m):
12 | if isinstance(m, nn.Linear):
13 | torch.nn.init.xavier_uniform_(m.weight, gain=1)
14 | torch.nn.init.constant_(m.bias, 0)
15 |
16 |
17 | class ValueNetwork(nn.Module):
18 | def __init__(self, num_inputs, hidden_dim):
19 | super(ValueNetwork, self).__init__()
20 |
21 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
22 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
23 | self.linear3 = nn.Linear(hidden_dim, 1)
24 |
25 | self.apply(weights_init_)
26 |
27 | def forward(self, state):
28 | x = F.relu(self.linear1(state))
29 | x = F.relu(self.linear2(x))
30 | x = self.linear3(x)
31 | return x
32 |
33 |
34 | class QNetwork(nn.Module):
35 | def __init__(self, num_inputs, num_actions, hidden_dim):
36 | super(QNetwork, self).__init__()
37 |
38 | # Q1 architecture
39 | self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
40 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
41 | self.linear3 = nn.Linear(hidden_dim, 1)
42 |
43 | # Q2 architecture
44 | self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
45 | self.linear5 = nn.Linear(hidden_dim, hidden_dim)
46 | self.linear6 = nn.Linear(hidden_dim, 1)
47 |
48 | self.apply(weights_init_)
49 |
50 | def forward(self, state, action):
51 | xu = torch.cat([state, action], 1)
52 |
53 | x1 = F.relu(self.linear1(xu))
54 | x1 = F.relu(self.linear2(x1))
55 | x1 = self.linear3(x1)
56 |
57 | x2 = F.relu(self.linear4(xu))
58 | x2 = F.relu(self.linear5(x2))
59 | x2 = self.linear6(x2)
60 |
61 | return x1, x2
62 |
63 |
64 | class GaussianPolicy(nn.Module):
65 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
66 | super(GaussianPolicy, self).__init__()
67 |
68 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
69 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
70 |
71 | self.mean_linear = nn.Linear(hidden_dim, num_actions)
72 | self.log_std_linear = nn.Linear(hidden_dim, num_actions)
73 |
74 | self.apply(weights_init_)
75 |
76 | # action rescaling
77 | if action_space is None:
78 | self.action_scale = torch.tensor(1.)
79 | self.action_bias = torch.tensor(0.)
80 | else:
81 | self.action_scale = torch.FloatTensor(
82 | (action_space.high - action_space.low) / 2.)
83 | self.action_bias = torch.FloatTensor(
84 | (action_space.high + action_space.low) / 2.)
85 |
86 | def forward(self, state):
87 | x = F.relu(self.linear1(state))
88 | x = F.relu(self.linear2(x))
89 | mean = self.mean_linear(x)
90 | log_std = self.log_std_linear(x)
91 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
92 | return mean, log_std
93 |
94 | def sample(self, state):
95 | mean, log_std = self.forward(state)
96 | std = log_std.exp()
97 | normal = Normal(mean, std)
98 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
99 | y_t = torch.tanh(x_t)
100 | action = y_t * self.action_scale + self.action_bias
101 | log_prob = normal.log_prob(x_t)
102 | # Enforcing Action Bound
103 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
104 | log_prob = log_prob.sum(1, keepdim=True)
105 | mean = torch.tanh(mean) * self.action_scale + self.action_bias
106 | return action, log_prob, mean
107 |
108 | def to(self, device):
109 | self.action_scale = self.action_scale.to(device)
110 | self.action_bias = self.action_bias.to(device)
111 | return super(GaussianPolicy, self).to(device)
112 |
113 |
114 | class DeterministicPolicy(nn.Module):
115 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
116 | super(DeterministicPolicy, self).__init__()
117 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
118 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
119 |
120 | self.mean = nn.Linear(hidden_dim, num_actions)
121 | self.noise = torch.Tensor(num_actions)
122 |
123 | self.apply(weights_init_)
124 |
125 | # action rescaling
126 | if action_space is None:
127 | self.action_scale = 1.
128 | self.action_bias = 0.
129 | else:
130 | self.action_scale = torch.FloatTensor(
131 | (action_space.high - action_space.low) / 2.)
132 | self.action_bias = torch.FloatTensor(
133 | (action_space.high + action_space.low) / 2.)
134 |
135 | def forward(self, state):
136 | x = F.relu(self.linear1(state))
137 | x = F.relu(self.linear2(x))
138 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
139 | return mean
140 |
141 | def sample(self, state):
142 | mean = self.forward(state)
143 | noise = self.noise.normal_(0., std=0.1)
144 | noise = noise.clamp(-0.25, 0.25)
145 | action = mean + noise
146 | return action, torch.tensor(0.), mean
147 |
148 | def to(self, device):
149 | self.action_scale = self.action_scale.to(device)
150 | self.action_bias = self.action_bias.to(device)
151 | self.noise = self.noise.to(device)
152 | return super(DeterministicPolicy, self).to(device)
153 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal, TransformedDistribution, Distribution, Categorical, Bernoulli
5 | from torch.distributions.transforms import AffineTransform, SigmoidTransform
6 |
7 | LOG_SIG_MAX = 2
8 | LOG_SIG_MIN = -20
9 | epsilon = 1e-6
10 |
11 | # Initialize Policy weights
12 | def weights_init_(m):
13 | if isinstance(m, nn.Linear):
14 | torch.nn.init.xavier_uniform_(m.weight, gain=1)
15 | torch.nn.init.constant_(m.bias, 0)
16 |
17 |
18 | class ValueNetwork(nn.Module):
19 | def __init__(self, num_inputs, hidden_dim):
20 | super(ValueNetwork, self).__init__()
21 |
22 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
23 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
24 | self.LayerNorm = nn.LayerNorm(hidden_dim)
25 | self.linear3 = nn.Linear(hidden_dim, 1)
26 |
27 | self.apply(weights_init_)
28 |
29 | def forward(self, state):
30 | x = torch.tanh(self.LayerNorm((self.linear1(state))))
31 | x = F.elu(self.linear2(x))
32 | x = self.linear3(x)
33 | return x
34 |
35 |
36 | class QNetwork(nn.Module):
37 | def __init__(self, num_inputs, num_actions, hidden_dim):
38 | super(QNetwork, self).__init__()
39 |
40 | # Q1 architecture
41 | self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
42 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
43 | self.LayerNorm1 = nn.LayerNorm(hidden_dim)
44 | self.linear3 = nn.Linear(hidden_dim, 1)
45 |
46 | # Q2 architecture
47 | self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
48 | self.linear5 = nn.Linear(hidden_dim, hidden_dim)
49 | self.LayerNorm2 = nn.LayerNorm(hidden_dim)
50 | self.linear6 = nn.Linear(hidden_dim, 1)
51 |
52 | self.apply(weights_init_)
53 |
54 | def forward(self, state, action):
55 | xu = torch.cat([state, action], 1)
56 |
57 | x1 = torch.tanh(self.LayerNorm1(self.linear1(xu)))
58 | x1 = F.elu(self.linear2(x1))
59 | x1 = self.linear3(x1)
60 |
61 | x2 = torch.tanh(self.LayerNorm2(self.linear4(xu)))
62 | x2 = F.elu(self.linear5(x2))
63 | x2 = self.linear6(x2)
64 |
65 | return x1, x2
66 |
67 |
68 | class GaussianPolicy(nn.Module):
69 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
70 | super(GaussianPolicy, self).__init__()
71 |
72 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
73 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
74 | self.LayerNorm = nn.LayerNorm(hidden_dim)
75 |
76 | self.mean_linear = nn.Linear(hidden_dim, num_actions)
77 | self.log_std_linear = nn.Linear(hidden_dim, num_actions)
78 |
79 | self.apply(weights_init_)
80 |
81 | # action rescaling
82 | if action_space is None:
83 | self.action_scale = torch.tensor(1.)
84 | self.action_bias = torch.tensor(0.)
85 | else:
86 | self.action_scale = torch.FloatTensor(
87 | (action_space.high - action_space.low) / 2.)
88 | self.action_bias = torch.FloatTensor(
89 | (action_space.high + action_space.low) / 2.)
90 |
91 | def forward(self, state):
92 | x = torch.tanh(self.LayerNorm((self.linear1(state))))
93 | x = F.elu(self.linear2(x))
94 | mean = self.mean_linear(x)
95 | log_std = self.log_std_linear(x)
96 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
97 | return mean, log_std
98 |
99 | def sample(self, state):
100 | mean, log_std = self.forward(state)
101 | std = log_std.exp()
102 | normal = Normal(mean, std)
103 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
104 | y_t = torch.tanh(x_t)
105 | action = y_t * self.action_scale + self.action_bias
106 | log_prob = normal.log_prob(x_t)
107 | # Enforcing Action Bound
108 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
109 | log_prob = log_prob.sum(1, keepdim=True)
110 | mean = torch.tanh(mean) * self.action_scale + self.action_bias
111 | return action, log_prob, mean
112 |
113 |
114 | def to(self, device):
115 | self.action_scale = self.action_scale.to(device)
116 | self.action_bias = self.action_bias.to(device)
117 | return super(GaussianPolicy, self).to(device)
118 |
119 |
120 | class DeterministicPolicy(nn.Module):
121 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
122 | super(DeterministicPolicy, self).__init__()
123 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
124 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
125 |
126 | self.mean = nn.Linear(hidden_dim, num_actions)
127 | self.noise = torch.Tensor(num_actions)
128 |
129 | self.apply(weights_init_)
130 |
131 | # action rescaling
132 | if action_space is None:
133 | self.action_scale = 1.
134 | self.action_bias = 0.
135 | else:
136 | self.action_scale = torch.FloatTensor(
137 | (action_space.high - action_space.low) / 2.)
138 | self.action_bias = torch.FloatTensor(
139 | (action_space.high + action_space.low) / 2.)
140 |
141 | def forward(self, state):
142 | x = F.relu(self.linear1(state))
143 | x = F.relu(self.linear2(x))
144 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
145 | return mean
146 |
147 | def sample(self, state):
148 | mean = self.forward(state)
149 | noise = self.noise.normal_(0., std=0.1)
150 | noise = noise.clamp(-0.25, 0.25)
151 | action = mean + noise
152 | return action, torch.tensor(0.), mean
153 |
154 | def to(self, device):
155 | self.action_scale = self.action_scale.to(device)
156 | self.action_bias = self.action_bias.to(device)
157 | self.noise = self.noise.to(device)
158 | return super(DeterministicPolicy, self).to(device)
159 |
--------------------------------------------------------------------------------
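A hedged sketch of sampling from `GaussianPolicy`; the state/action dimensions and the `Box` bounds below are placeholders chosen for illustration, not values taken from the repository.

```python
# Sketch: sampling tanh-squashed actions from GaussianPolicy (dims are placeholders).
import gym.spaces
import torch
from model.model import GaussianPolicy

action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,))
policy = GaussianPolicy(num_inputs=17, num_actions=6, hidden_dim=256,
                        action_space=action_space)

state = torch.randn(4, 17)                  # batch of 4 states
action, log_prob, mean = policy.sample(state)
print(action.shape, log_prob.shape)         # torch.Size([4, 6]) torch.Size([4, 1])
# actions are tanh-squashed, then rescaled into [action_space.low, action_space.high]
```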
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.optim import Adam
6 | from torch import autograd
7 | from torch.optim.lr_scheduler import MultiStepLR
8 |
9 | class Discriminator_SAS(nn.Module):
10 | def __init__(self, state_dim, action_dim, args):
11 | super(Discriminator_SAS, self).__init__()
12 |
13 | self.device = torch.device("cuda:{}".format(str(args.device)) if args.cuda else "cpu")
14 |
15 | self.hidden_size = args.hidden_size
16 |
17 | self.trunk = nn.Sequential(
18 | nn.Linear(state_dim + action_dim + state_dim, self.hidden_size), nn.Tanh(),
19 | nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh(),
20 | nn.Linear(self.hidden_size, 1)).to(self.device)
21 |
22 | self.trunk.train()
23 |
24 | self.optimizer = Adam(self.trunk.parameters(), lr = args.lr)
25 | self.scheduler = MultiStepLR(self.optimizer, milestones=[20000, 40000], gamma=0.3)
26 |
27 |
28 | def compute_grad_pen(self,
29 | expert_state,
30 | expert_action,
31 | expert_next_state,
32 | policy_state,
33 | policy_action,
34 | policy_next_state,
35 | lambda_=10):
36 | # alpha = torch.rand(expert_state.size(0), 1)
37 | expert_data = torch.cat([expert_state, expert_action, expert_next_state], dim=1)
38 | policy_data = torch.cat([policy_state, policy_action, policy_next_state], dim=1)
39 |
40 | alpha = torch.rand(expert_data.size(0), 1).to(expert_data.device)
41 |
42 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data
43 | mixup_data.requires_grad = True
44 |
45 | disc = self.trunk(mixup_data)
46 | ones = torch.ones(disc.size()).to(disc.device)
47 | grad = autograd.grad(
48 | outputs=disc,
49 | inputs=mixup_data,
50 | grad_outputs=ones,
51 | create_graph=True,
52 | retain_graph=True,
53 | only_inputs=True)[0]
54 |
55 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
56 | return grad_pen
57 |
58 | def compute_grad_pen_back(self,
59 | expert_state,
60 | expert_action,
61 | expert_next_state,
62 | policy_state,
63 | policy_action,
64 | policy_next_state,
65 | lambda_=10):
66 | # alpha = torch.rand(expert_state.size(0), 1)
67 | expert_data = torch.cat([torch.unsqueeze(expert_state, dim=0), torch.unsqueeze(expert_action, dim=0), torch.unsqueeze(expert_next_state, dim=0)], dim=1)
68 | policy_data = torch.cat([torch.unsqueeze(policy_state, dim=0), torch.unsqueeze(policy_action, dim=0), torch.unsqueeze(policy_next_state, dim=0)], dim=1)
69 |
70 | alpha = torch.rand(expert_data.size(0), 1).to(expert_data.device)
71 |
72 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data
73 | mixup_data.requires_grad = True
74 |
75 | disc = self.trunk(mixup_data)
76 | ones = torch.ones(disc.size()).to(disc.device)
77 | grad = autograd.grad(
78 | outputs=disc,
79 | inputs=mixup_data,
80 | grad_outputs=ones,
81 | create_graph=True,
82 | retain_graph=True,
83 | only_inputs=True)[0]
84 |
85 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
86 | return grad_pen
87 |
88 |
89 | def update(self, expert_buffer, replay_buffer, batch_size):
90 | self.train()
91 |
92 | expert_state_batch, expert_action_batch, _, expert_next_state_batch, _ = expert_buffer.sample(batch_size=batch_size)
93 | expert_state_batch = torch.FloatTensor(expert_state_batch).to(self.device)
94 | expert_next_state_batch = torch.FloatTensor(expert_next_state_batch).to(self.device)
95 | expert_action_batch = torch.FloatTensor(expert_action_batch).to(self.device)
96 |
97 | state_batch, action_batch, _, next_state_batch, _ = replay_buffer.sample(batch_size=batch_size)
98 | state_batch = torch.FloatTensor(state_batch).to(self.device)
99 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
100 | action_batch = torch.FloatTensor(action_batch).to(self.device)
101 |
102 | policy_d = self.trunk(torch.cat([state_batch, action_batch, next_state_batch], dim=1))
103 | expert_d = self.trunk(torch.cat([expert_state_batch, expert_action_batch, expert_next_state_batch], dim=1))
104 |
105 | policy_loss = F.binary_cross_entropy_with_logits(
106 | policy_d,
107 | torch.zeros(policy_d.size()).to(self.device))
108 |
109 | expert_loss = F.binary_cross_entropy_with_logits(
110 | expert_d,
111 | torch.ones(expert_d.size()).to(self.device))
112 |
113 | gail_loss = expert_loss + policy_loss
114 | grad_pen = self.compute_grad_pen(expert_state_batch, expert_action_batch, expert_next_state_batch,
115 | state_batch, action_batch,next_state_batch)
116 |
117 | self.optimizer.zero_grad()
118 | (gail_loss + grad_pen).backward()
119 | self.optimizer.step()
120 |
121 | # self.scheduler.step()
122 |
123 | loss = (gail_loss + grad_pen).item()
124 |
125 | return loss
126 |
127 | def update_back(self, expert_buffer, replay_buffer, batch_size):
128 | self.train()
129 |
130 | expert_state_batch, expert_action_batch, _, expert_next_state_batch, _ = expert_buffer.sample(batch_size=batch_size)
131 | expert_state_batch = torch.FloatTensor(expert_state_batch).to(self.device)
132 | expert_next_state_batch = torch.FloatTensor(expert_next_state_batch).to(self.device)
133 | expert_action_batch = torch.FloatTensor(expert_action_batch).to(self.device)
134 |
135 | state_batch, action_batch, _, next_state_batch, _ = replay_buffer.sample(batch_size=batch_size)
136 | state_batch = torch.FloatTensor(state_batch).to(self.device)
137 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
138 | action_batch = torch.FloatTensor(action_batch).to(self.device)
139 |
140 | loss = 0
141 | n = 0
142 |
143 | for i in range(batch_size):
144 | policy_d = self.trunk(torch.cat([torch.unsqueeze(state_batch[i], dim=0), torch.unsqueeze(action_batch[i], dim=0), torch.unsqueeze(next_state_batch[i], dim=0)], dim = 1))
145 | expert_d = self.trunk(torch.cat([torch.unsqueeze(expert_state_batch[i],dim=0), torch.unsqueeze(expert_action_batch[i], dim=0), torch.unsqueeze(expert_next_state_batch[i], dim=0)], dim = 1))
146 |
147 | policy_loss = F.binary_cross_entropy_with_logits(
148 | policy_d,
149 | torch.zeros(policy_d.size()).to(self.device))
150 |
151 | expert_loss = F.binary_cross_entropy_with_logits(
152 | expert_d,
153 | torch.ones(expert_d.size()).to(self.device))
154 |
155 | gail_loss = expert_loss + policy_loss
156 | grad_pen = self.compute_grad_pen(expert_state_batch[i], expert_action_batch[i], expert_next_state_batch[i],
157 | state_batch[i], action_batch[i],next_state_batch[i])
158 |
159 | loss += (gail_loss + grad_pen).item()
160 | n += 1
161 |
162 | self.optimizer.zero_grad()
163 | (gail_loss + grad_pen).backward()
164 | self.optimizer.step()
165 |
166 | return loss / n
167 |
168 | def predict_reward(self, state, action, next_state):
169 | with torch.no_grad():
170 | self.eval()
171 | d = self.trunk(torch.cat([state, action, next_state], dim=1))
172 | s = torch.sigmoid(d)
173 | reward = s.log() - (1 - s).log()
174 | return reward
--------------------------------------------------------------------------------
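A short sketch (shapes and hyperparameters are illustrative) of scoring transitions with `Discriminator_SAS`: `predict_reward` returns `log D - log(1 - D)`, i.e. the raw logit, an estimate of the log density ratio between the two buffers the discriminator was trained to separate.

```python
# Sketch: scoring a batch of (s, a, s') transitions with Discriminator_SAS on CPU.
import torch
from utils.config import Config
from model.discriminator import Discriminator_SAS

args = Config({"device": 0, "cuda": False, "hidden_size": 256, "lr": 3e-4})
disc = Discriminator_SAS(state_dim=11, action_dim=3, args=args)

state = torch.randn(32, 11)
action = torch.randn(32, 3)
next_state = torch.randn(32, 11)

reward = disc.predict_reward(state, action, next_state)  # shape (32, 1)
```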
/main_stationary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gym
4 | import numpy as np
5 | from utils.config import ARGConfig
6 | from utils.default_config import default_config
7 | from model.algo import OMPO
8 | from model.discriminator import Discriminator_SAS
9 | from utils.Replaybuffer import ReplayMemory
10 | import datetime
11 | import itertools
12 | from copy import copy
13 | from torch.utils.tensorboard import SummaryWriter
14 | import shutil
15 |
16 |
17 | def train_loop(config, msg = "default"):
18 | # set seed
19 | env = gym.make(config.env_name)
20 | env.seed(config.seed)
21 | env.action_space.seed(config.seed)
22 |
23 | torch.manual_seed(config.seed)
24 | np.random.seed(config.seed)
25 |
26 | discriminator = Discriminator_SAS(env.observation_space.shape[0], env.action_space.shape[0], config)
27 |
28 | agent = OMPO(env.observation_space.shape[0], env.action_space, config)
29 |
30 | result_path = './results_mujoco_utd/{}/{}/{}/{}_{}_{}_{}'.format(config.env_name, config.algo, msg,
31 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
32 | 'OMPO', config.seed,
33 | "autotune" if config.automatic_entropy_tuning else "")
34 |
35 | checkpoint_path = result_path + '/' + 'checkpoint'
36 |
37 | # training logs
38 | if not os.path.exists(result_path):
39 | os.makedirs(result_path)
40 | if not os.path.exists(checkpoint_path):
41 | os.makedirs(checkpoint_path)
42 | with open(os.path.join(result_path, "config.log"), 'w') as f:
43 | f.write(str(config))
44 |
45 | writer = SummaryWriter(result_path)
46 | current_path = os.path.dirname(os.path.abspath(__file__))
47 | files = os.listdir(current_path)
48 | files_to_save = ['main_stationary.py', 'main_transfer.py', 'main_non_stationary.py', 'Non_stationary_env', 'model', 'utils']
49 | ignore_files = [x for x in files if x not in files_to_save]
50 | shutil.copytree('.', result_path + '/code', ignore=shutil.ignore_patterns(*ignore_files))
51 |
52 | # all memory
53 | memory = ReplayMemory(config.replay_size, config.seed)
54 | initial_state_memory = ReplayMemory(config.replay_size, config.seed)
55 | # expert memory
56 | local_memory = ReplayMemory(config.local_replay_size, config.seed)
57 | # sample from all memory for training
58 | temp_memory = ReplayMemory(config.local_replay_size, config.seed)
59 |
60 | for _ in range(config.batch_size):
61 | state = env.reset()
62 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state
63 |
64 | # Training Loop
65 | total_numsteps = 0
66 | updates_discriminator = 0
67 | updates_agent = 0
68 | best_reward = -1e6
69 | for i_episode in itertools.count(1):
70 | episode_reward = 0
71 | episode_steps = 0
72 | done = False
73 | state = env.reset()
74 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state
75 |
76 | while not done:
77 | if config.start_steps > total_numsteps:
78 | action = env.action_space.sample() # Sample random action
79 | else:
80 | action = agent.select_action(state) # Sample action from policy
81 |
82 | if config.start_steps <= total_numsteps:
83 | # use the same history buffer to update the discriminator
84 | if local_memory.__len__() == config.local_replay_size:
85 | for _ in range(10): # default: 10
86 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=config.local_replay_size)
87 | for local_buffer_idx in range(config.local_replay_size):
88 | temp_memory.push(state_batch[local_buffer_idx], action_batch[local_buffer_idx], reward_batch[local_buffer_idx], next_state_batch[local_buffer_idx], mask_batch[local_buffer_idx])
89 | for _ in range(20): # default: 20
90 | discriminator_loss = discriminator.update(local_memory, temp_memory, config.gail_batch)
91 | writer.add_scalar('loss/discriminator_loss', discriminator_loss, updates_discriminator)
92 | updates_discriminator += 1
93 |
94 | temp_memory = ReplayMemory(config.local_replay_size, config.seed)
95 | # reset the local memory
96 | local_memory = ReplayMemory(config.local_replay_size, config.seed)
97 |
98 | # train the agent
99 | for _ in range(config.updates_per_step):
100 | # Update parameters of all the networks
101 | critic_loss, policy_loss, ent_loss, alpha = agent.update_parameters(initial_state_memory, memory, discriminator, config.batch_size, updates_agent, writer)
102 |
103 | writer.add_scalar('loss/critic', critic_loss, updates_agent)
104 | writer.add_scalar('loss/policy', policy_loss, updates_agent)
105 | writer.add_scalar('loss/entropy_loss', ent_loss, updates_agent)
106 | writer.add_scalar('entropy_temperature/alpha', alpha, updates_agent)
107 | updates_agent += 1
108 |
109 | next_state, reward, done, _ = env.step(action) # Step
110 | episode_steps += 1
111 | total_numsteps += 1
112 | episode_reward += reward
113 |
114 | # Ignore the "done" signal if it comes from hitting the time horizon.
115 | # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
116 | mask = 1 if episode_steps == env._max_episode_steps else float(not done)
117 |
118 | memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to global memory
119 | local_memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to local memory
120 |
121 | state = next_state
122 |
123 | if total_numsteps > config.num_steps:
124 | break
125 |
126 | writer.add_scalar('train/reward', episode_reward, total_numsteps)
127 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))
128 |
129 | # test agent
130 | if i_episode % config.eval_episodes == 0 and config.eval is True:
131 | avg_reward = 0.
132 | # avg_success = 0.
133 | for _ in range(config.eval_episodes):
134 | state = env.reset()
135 | episode_reward = 0
136 | done = False
137 | while not done:
138 | action = agent.select_action(state, evaluate=True)
139 |
140 | next_state, reward, done, info = env.step(action)
141 | episode_reward += reward
142 |
143 | state = next_state
144 | avg_reward += episode_reward
145 | # avg_success += float(info['is_success'])
146 | avg_reward /= config.eval_episodes
147 | # avg_success /= config.eval_episodes
148 | if avg_reward >= best_reward and config.save is True:
149 | best_reward = avg_reward
150 | agent.save_checkpoint(checkpoint_path, 'best')
151 |
152 | writer.add_scalar('test/avg_reward', avg_reward, total_numsteps)
153 | # writer.add_scalar('test/avg_success', avg_success, total_numsteps)
154 |
155 | print("----------------------------------------")
156 | print("Env: {}, Test Episodes: {}, Avg. Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2)))
157 | print("----------------------------------------")
158 |
159 | env.close()
160 |
161 | # python main_stationary.py --device 2
162 |
163 | if __name__ == "__main__":
164 | arg = ARGConfig()
165 | arg.add_arg("env_name", "Humanoid-v3", "Environment name")
166 | arg.add_arg("device", 0, "Computing device")
167 | arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)")
168 | arg.add_arg("tag", "default", "Experiment tag")
169 | arg.add_arg("start_steps", 5000, "Number of start steps")
170 | arg.add_arg("automatic_entropy_tuning", True, "Automatically adjust α (default: True)")
171 | arg.add_arg("seed", np.random.randint(0, 1000), "experiment seed")
172 | arg.parser()
173 |
174 | config = default_config
175 | config.update(arg)
176 |
177 | print(f">>>> Training OMPO on {config.env_name} environment, device {config.device}")
178 | train_loop(config, msg=config.tag)
--------------------------------------------------------------------------------
/main_transfer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gym
4 | import numpy as np
5 | from utils.config import ARGConfig
6 | from utils.default_config import default_config
7 | from model.algo import OMPO
8 | from model.discriminator import Discriminator_SAS
9 | from utils.Replaybuffer import ReplayMemory
10 | import datetime
11 | import itertools
12 | from copy import copy
13 | from torch.utils.tensorboard import SummaryWriter
14 | import shutil
15 | import Non_stationary_env
16 |
17 | def train_loop(config, msg = "default"):
18 | # set seed
19 | sim_env = gym.make(config.env_name)
20 | sim_env.seed(config.seed)
21 | sim_env.action_space.seed(config.seed)
22 |
23 | if "Hopper" in config.env_name:
24 | real_env = gym.make(config.env_name, torso_len = 0.4, foot_len = 0.39)
25 | real_env.seed(config.seed)
26 | real_env.action_space.seed(config.seed)
27 | elif "Walker" in config.env_name:
28 | real_env = gym.make(config.env_name, torso_len = 0.4, foot_len = 0.2)
29 | real_env.seed(config.seed)
30 | real_env.action_space.seed(config.seed)
31 | elif "Ant" in config.env_name:
32 | real_env = gym.make(config.env_name, gravity = 19.62, wind = 1)
33 | real_env.seed(config.seed)
34 | real_env.action_space.seed(config.seed)
35 | elif "Humanoid" in config.env_name:
36 | real_env = gym.make(config.env_name, gravity = 19.62, wind = 1)
37 | real_env.seed(config.seed)
38 | real_env.action_space.seed(config.seed)
39 |
40 | torch.manual_seed(config.seed)
41 | np.random.seed(config.seed)
42 |
43 | discriminator = Discriminator_SAS(sim_env.observation_space.shape[0], sim_env.action_space.shape[0], config)
44 |
45 | agent = OMPO(sim_env.observation_space.shape[0], sim_env.action_space, config)
46 |
47 | result_path = './results/{}/{}/{}/{}_{}_{}_{}'.format(config.env_name, config.algo, msg,
48 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
49 | 'OMPO', config.seed,
50 | "autotune" if config.automatic_entropy_tuning else "")
51 |
52 | checkpoint_path = result_path + '/' + 'checkpoint'
53 |
54 | # training logs
55 | if not os.path.exists(result_path):
56 | os.makedirs(result_path)
57 | if not os.path.exists(checkpoint_path):
58 | os.makedirs(checkpoint_path)
59 | with open(os.path.join(result_path, "config.log"), 'w') as f:
60 | f.write(str(config))
61 |
62 | writer = SummaryWriter(result_path)
63 | shutil.copytree('.', result_path + '/code', ignore=shutil.ignore_patterns('results', 'results_stationary_test_1', 'results_stationary_test_2'))
64 |
65 | # all memory
66 | memory = ReplayMemory(config.replay_size, config.seed)
67 | initial_state_memory = ReplayMemory(config.replay_size, config.seed)
68 | # expert memory
69 | local_memory = ReplayMemory(config.local_replay_size, config.seed)
70 | # sample from all memory for training
71 | temp_memory = ReplayMemory(config.local_replay_size, config.seed)
72 |
73 | for _ in range(config.batch_size):
74 | state = real_env.reset()
75 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state
76 |
77 | # Training Loop
78 | total_numsteps = 0
79 | updates_discriminator = 0
80 | updates_agent = 0
81 | best_reward = -1e6
82 | for i_episode in itertools.count(1):
83 | episode_reward = 0
84 | episode_steps = 0
85 | done = False
86 | state = real_env.reset()
87 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state
88 | # sample in the real_env
89 | while not done:
90 | if config.start_steps > total_numsteps:
91 | action = real_env.action_space.sample() # Sample random action
92 | else:
93 | action = agent.select_action(state) # Sample action from policy
94 |
95 | if config.start_steps <= total_numsteps:
96 | # use the same history buffer to update the discriminator
97 | if local_memory.__len__() == config.local_replay_size:
98 | for _ in range(10): # default: 10
99 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=config.local_replay_size)
100 | for local_buffer_idx in range(config.local_replay_size):
101 | temp_memory.push(state_batch[local_buffer_idx], action_batch[local_buffer_idx], reward_batch[local_buffer_idx], next_state_batch[local_buffer_idx], mask_batch[local_buffer_idx])
102 | for _ in range(20): # default: 20
103 | discriminator_loss = discriminator.update(local_memory, temp_memory, config.gail_batch)
104 | writer.add_scalar('loss/discriminator_loss', discriminator_loss, updates_discriminator)
105 | updates_discriminator += 1
106 |
107 | temp_memory = ReplayMemory(config.local_replay_size, config.seed)
108 | # reset the local memory
109 | local_memory = ReplayMemory(config.local_replay_size, config.seed)
110 |
111 | # train the agent
112 | for _ in range(config.updates_per_step*10):
113 | # Update parameters of all the networks
114 | critic_loss, policy_loss, ent_loss, alpha = agent.update_parameters(initial_state_memory, memory, discriminator, config.batch_size, updates_agent)
115 |
116 | writer.add_scalar('loss/critic', critic_loss, updates_agent)
117 | writer.add_scalar('loss/policy', policy_loss, updates_agent)
118 | writer.add_scalar('loss/entropy_loss', ent_loss, updates_agent)
119 | writer.add_scalar('entropy_temperature/alpha', alpha, updates_agent)
120 | updates_agent += 1
121 |
122 | next_state, reward, done, _ = real_env.step(action) # Step
123 | episode_steps += 1
124 | total_numsteps += 1
125 | episode_reward += reward
126 |
127 | # Ignore the "done" signal if it comes from hitting the time horizon.
128 | # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
129 | mask = 1 if episode_steps == real_env._max_episode_steps else float(not done)
130 |
131 | memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to global memory
132 | local_memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to local memory
133 |
134 | state = next_state
135 |
136 |
137 | # sample in the sim_env
138 | for _ in range(10):
139 | state = sim_env.reset()
140 | done = False
141 | while not done:
142 | if config.start_steps > total_numsteps:
143 | action = sim_env.action_space.sample() # Sample random action
144 | else:
145 | action = agent.select_action(state)
146 | next_state, reward, done, _ = sim_env.step(action)
147 | mask = 1 if episode_steps == real_env._max_episode_steps else float(not done)
148 | memory.push(state, action, reward + config.reward_max, next_state, mask)
149 | state = next_state
150 |
151 |
152 | if total_numsteps > config.num_steps:
153 | break
154 |
155 | writer.add_scalar('train/reward', episode_reward, total_numsteps)
156 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))
157 |
158 | # test agent
159 | if i_episode % config.eval_episodes == 0 and config.eval is True:
160 | avg_reward = 0.
161 | # avg_success = 0.
162 | for _ in range(config.eval_episodes):
163 | state = real_env.reset()
164 | episode_reward = 0
165 | done = False
166 | while not done:
167 | action = agent.select_action(state, evaluate=True)
168 |
169 | next_state, reward, done, info = real_env.step(action)
170 | episode_reward += reward
171 |
172 | state = next_state
173 | avg_reward += episode_reward
174 | # avg_success += float(info['is_success'])
175 | avg_reward /= config.eval_episodes
176 | # avg_success /= config.eval_episodes
177 | if avg_reward >= best_reward and config.save is True:
178 | best_reward = avg_reward
179 | agent.save_checkpoint(checkpoint_path, 'best')
180 |
181 | writer.add_scalar('test/avg_reward', avg_reward, total_numsteps)
182 | # writer.add_scalar('test/avg_success', avg_success, total_numsteps)
183 |
184 | print("----------------------------------------")
185 | print("Env: {}, Test Episodes: {}, Avg. Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2)))
186 | print("----------------------------------------")
187 |
188 | # env.close()
189 |
190 | # python main_transfer.py --device 2
191 |
192 | if __name__ == "__main__":
193 | arg = ARGConfig()
194 | arg.add_arg("env_name", "AntTransfer-v0", "Environment name")
195 | arg.add_arg("device", 0, "Computing device")
196 | arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)")
197 | arg.add_arg("tag", "default", "Experiment tag")
198 | arg.add_arg("start_steps", 1000, "Number of start steps")
199 | arg.add_arg("automatic_entropy_tuning", True, "Automatically adjust α (default: True)")
200 | arg.add_arg("seed", np.random.randint(0, 1000), "experiment seed")
201 | arg.parser()
202 |
203 | config = default_config
204 | config.update(arg)
205 |
206 | print(f">>>> Training OMPO on {config.env_name} environment, device {config.device}")
207 | train_loop(config, msg=config.tag)
--------------------------------------------------------------------------------
/main_non_stationary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gym
4 | import numpy as np
5 | from utils.config import ARGConfig
6 | from utils.default_config import default_config
7 | from model.algo import OMPO
8 | from model.discriminator import Discriminator_SAS
9 | from utils.Replaybuffer import ReplayMemory
10 | import datetime
11 | import itertools
12 | from copy import copy
13 | from torch.utils.tensorboard import SummaryWriter
14 | import shutil
15 | import Non_stationary_env
16 |
17 |
18 |
19 |
20 | def train_loop(config, msg = "default"):
21 | # set seed
22 | env = gym.make(config.env_name)
23 | env.seed(config.seed)
24 | env.action_space.seed(config.seed)
25 |
26 | torch.manual_seed(config.seed)
27 | np.random.seed(config.seed)
28 |
29 | discriminator = Discriminator_SAS(env.observation_space.shape[0], env.action_space.shape[0], config)
30 |
31 | agent = OMPO(env.observation_space.shape[0], env.action_space, config)
32 |
33 | result_path = './results/{}/{}/{}/{}_{}_{}_{}'.format(config.env_name, config.algo, msg,
34 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
35 | 'OMPO', config.seed,
36 | "autotune" if config.automatic_entropy_tuning else "")
37 |
38 | checkpoint_path = result_path + '/' + 'checkpoint'
39 |
40 | # training logs
41 | if not os.path.exists(result_path):
42 | os.makedirs(result_path)
43 | if not os.path.exists(checkpoint_path):
44 | os.makedirs(checkpoint_path)
45 | with open(os.path.join(result_path, "config.log"), 'w') as f:
46 | f.write(str(config))
47 | writer = SummaryWriter(result_path)
48 | shutil.copytree('.', result_path + '/code', ignore=shutil.ignore_patterns('results', 'results_stationary_test_1', 'results_stationary_test_2'))
49 |
50 | # all memory
51 | memory = ReplayMemory(config.replay_size, config.seed)
52 | initial_state_memory = ReplayMemory(config.replay_size, config.seed)
53 | # expert memory
54 | local_memory = ReplayMemory(config.local_replay_size, config.seed)
55 | # sample from all memory for training
56 | temp_memory = ReplayMemory(config.local_replay_size, config.seed)
57 |
58 | for _ in range(config.batch_size):
59 | if "Hopper" in config.env_name:
60 | torso_len = np.random.uniform(0.3, 0.5)
61 | foot_len = np.random.uniform(0.29, 0.49)
62 | state = env.reset(torso_len = torso_len, foot_len = foot_len)
63 | elif "Walker" in config.env_name:
64 | torso_len = np.random.uniform(0.1, 0.3)
65 | foot_len = np.random.uniform(0.05, 0.15)
66 | state = env.reset(torso_len = torso_len, foot_len = foot_len)
67 | elif "Ant" in config.env_name:
68 | gravity = np.random.uniform(9.81, 19.82)
69 | wind = np.random.uniform(0.8, 1.2)
70 | state = env.reset(gravity = gravity, wind = wind)
71 | elif "Humanoid" in config.env_name:
72 | gravity = np.random.uniform(9.81, 19.82)
73 | wind = np.random.uniform(0.5, 1.5)
74 | state = env.reset(gravity = gravity, wind = wind)
75 | else:
76 | state = env.reset()
77 | initial_state_memory.push(state, 0, 0, 0, 0) # Note the initial buffer only contains state
78 |
79 | # Training Loop
80 | total_numsteps = 0
81 | updates_discriminator = 0
82 | updates_agent = 0
83 | best_reward = -1e6
84 |     for i_episode in itertools.count(1):  # environment parameters drift sinusoidally with the episode index (non-stationary setting)
85 | if "Hopper" in config.env_name:
86 | torso_len = 0.4 + 0.1 * np.sin(0.2 * i_episode)
87 | foot_len = 0.39 + 0.1 * np.sin(0.2 * i_episode)
88 | state = env.reset(torso_len = torso_len, foot_len = foot_len)
89 | elif "Walker" in config.env_name:
90 | torso_len = 0.2 + 0.1 * np.sin(0.3 * i_episode)
91 | foot_len = 0.1 + 0.05 * np.sin(0.3 * i_episode)
92 | state = env.reset(torso_len = torso_len, foot_len = foot_len)
93 | elif "Ant" in config.env_name:
94 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode)
95 | wind = 1. + 0.2 * np.sin(0.5 * i_episode)
96 | state = env.reset(gravity = gravity, wind = wind)
97 | elif "Humanoid" in config.env_name:
98 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode)
99 | wind = 1. + 0.5 * np.sin(0.5 * i_episode)
100 | state = env.reset(gravity = gravity, wind = wind)
101 | else:
102 | state = env.reset()
103 | episode_reward = 0
104 | episode_steps = 0
105 | done = False
106 |         initial_state_memory.push(state, 0, 0, 0, 0)  # the initial-state buffer only needs the state; the other fields are placeholders
107 |
108 | while not done:
109 | if config.start_steps > total_numsteps:
110 | action = env.action_space.sample() # Sample random action
111 | else:
112 | action = agent.select_action(state) # Sample action from policy
113 |
114 | if config.start_steps <= total_numsteps:
115 | # use the same history buffer to update the discriminator
116 |                 if len(local_memory) == config.local_replay_size:
117 | for _ in range(10): # default: 10
118 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=config.local_replay_size)
119 | for local_buffer_idx in range(config.local_replay_size):
120 | temp_memory.push(state_batch[local_buffer_idx], action_batch[local_buffer_idx], reward_batch[local_buffer_idx], next_state_batch[local_buffer_idx], mask_batch[local_buffer_idx])
121 | for _ in range(20): # default: 20
122 | discriminator_loss = discriminator.update(local_memory, temp_memory, config.gail_batch)
123 | writer.add_scalar('loss/discriminator_loss', discriminator_loss, updates_discriminator)
124 | updates_discriminator += 1
125 |
126 | temp_memory = ReplayMemory(config.local_replay_size, config.seed)
127 | # reset the local memory
128 | local_memory = ReplayMemory(config.local_replay_size, config.seed)
129 |
130 | # train the agent
131 | for _ in range(config.updates_per_step):
132 | # Update parameters of all the networks
133 | critic_loss, policy_loss, ent_loss, alpha = agent.update_parameters(initial_state_memory, memory, discriminator, config.batch_size, updates_agent, writer)
134 |
135 | writer.add_scalar('loss/critic', critic_loss, updates_agent)
136 | writer.add_scalar('loss/policy', policy_loss, updates_agent)
137 | writer.add_scalar('loss/entropy_loss', ent_loss, updates_agent)
138 |                     writer.add_scalar('entropy_temperature/alpha', alpha, updates_agent)
139 | updates_agent += 1
140 |
141 | next_state, reward, done, _ = env.step(action) # Step
142 | episode_steps += 1
143 | total_numsteps += 1
144 | episode_reward += reward
145 |
146 | # Ignore the "done" signal if it comes from hitting the time horizon.
147 | # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
148 | mask = 1 if episode_steps == env._max_episode_steps else float(not done)
149 |
150 | memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to global memory
151 | local_memory.push(state, action, reward + config.reward_max, next_state, mask) # Append transition to local memory
152 |
153 | state = next_state
154 |
155 |         if total_numsteps > config.num_steps:  # stop training once the total step budget is exhausted
156 |             break
157 |
158 | writer.add_scalar('train/reward', episode_reward, total_numsteps)
159 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))
160 |
161 | # test agent
162 | if i_episode % config.eval_episodes == 0 and config.eval is True:
163 | avg_reward = 0.
164 | # avg_success = 0.
165 | for _ in range(config.eval_episodes):
166 | if "Hopper" in config.env_name:
167 | torso_len = 0.4 + 0.1 * np.sin(0.2 * i_episode)
168 | foot_len = 0.39 + 0.1 * np.sin(0.2 * i_episode)
169 |                     state = env.reset(torso_len = torso_len, foot_len = foot_len)
170 | elif "Walker" in config.env_name:
171 | torso_len = 0.2 + 0.1 * np.sin(0.3 * i_episode)
172 | foot_len = 0.1 + 0.05 * np.sin(0.3 * i_episode)
173 | state = env.reset(torso_len = torso_len, foot_len = foot_len)
174 | elif "Ant" in config.env_name:
175 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode)
176 | wind = 1. + 0.2 * np.sin(0.5 * i_episode)
177 | state = env.reset(gravity = gravity, wind = wind)
178 | elif "Humanoid" in config.env_name:
179 | gravity = 14.715 + 4.905 * np.sin(0.5 * i_episode)
180 | wind = 1. + 0.5 * np.sin(0.5 * i_episode)
181 | state = env.reset(gravity = gravity, wind = wind)
182 | else:
183 | state = env.reset()
184 | episode_reward = 0
185 | done = False
186 | while not done:
187 | action = agent.select_action(state, evaluate=True)
188 |
189 | next_state, reward, done, info = env.step(action)
190 | episode_reward += reward
191 |
192 | state = next_state
193 | avg_reward += episode_reward
194 | # avg_success += float(info['is_success'])
195 | avg_reward /= config.eval_episodes
196 | # avg_success /= config.eval_episodes
197 | if avg_reward >= best_reward and config.save is True:
198 | best_reward = avg_reward
199 | agent.save_checkpoint(checkpoint_path, 'best')
200 |
201 | writer.add_scalar('test/avg_reward', avg_reward, total_numsteps)
202 | # writer.add_scalar('test/avg_success', avg_success, total_numsteps)
203 |
204 | print("----------------------------------------")
205 | print("Env: {}, Test Episodes: {}, Avg. Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2)))
206 | print("----------------------------------------")
207 |
208 | env.close()
209 |
210 | # python main_non_stationary.py --device 2
211 |
212 | if __name__ == "__main__":
213 | arg = ARGConfig()
214 | arg.add_arg("env_name", "AntRandom-v0", "Environment name")
215 | arg.add_arg("device", 0, "Computing device")
216 | arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)")
217 | arg.add_arg("tag", "default", "Experiment tag")
218 | arg.add_arg("start_steps", 5000, "Number of start steps")
219 |     arg.add_arg("automatic_entropy_tuning", True, "Automatically adjust α (default: True)")
220 | arg.add_arg("seed", np.random.randint(0, 1000), "experiment seed")
221 | arg.parser()
222 |
223 | config = default_config
224 | config.update(arg)
225 |
226 |     print(f">>>> Training OMPO on {config.env_name}, device {config.device}")
227 | train_loop(config, msg=config.tag)
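228 | 
229 | # Example launches (flag syntax assumed from ARGConfig.parser(); any argument
230 | # registered above can be overridden the same way):
231 | #   python main_non_stationary.py --device 0 --env_name AntRandom-v0 --seed 123
232 | #   python main_non_stationary.py --env_name AntRandom-v0 --tag ant_run --start_steps 10000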
--------------------------------------------------------------------------------
/model/algo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.optim import Adam
5 | from torch.optim.lr_scheduler import MultiStepLR
6 | import copy
7 | from .utils import soft_update, hard_update, orthogonal_regularization
8 | from .model import GaussianPolicy, QNetwork, DeterministicPolicy
9 |
10 | epsilon = 1e-6
11 |
12 | class OMPO(object):
13 | def __init__(self, num_inputs, action_space, args):
14 |
15 | self.gamma = args.gamma
16 | self.tau = args.tau
17 | self.alpha = args.alpha
18 |
19 | self.policy_type = args.policy
20 | self.target_update_interval = args.target_update_interval
21 | self.automatic_entropy_tuning = args.automatic_entropy_tuning
22 |
23 | self.exponent = args.exponent
24 |
25 | self.actor_loss = 0
26 | self.alpha_loss = 0
27 | self.alpha_tlogs = 0
28 |
29 | if self.exponent <= 1:
30 |             raise ValueError('Exponent must be greater than 1, but received %f.' %
31 | self.exponent)
32 |
33 | self.tomac_alpha = args.tomac_alpha
34 |
35 |         self.f = lambda resid: torch.pow(torch.abs(resid), self.exponent) / self.exponent  # f(r) = |r|^p / p
36 |         clip_resid = lambda resid: torch.clamp(resid, 0.0, 1e6)
37 |         self.fgrad = lambda resid: torch.pow(clip_resid(resid), self.exponent - 1)  # f'(r) = max(r, 0)^(p - 1), used to weight the residuals in the actor loss
38 |
39 | self.device = torch.device("cuda:{}".format(str(args.device)) if args.cuda else "cpu")
40 |
41 | self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
42 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
43 | self.critic_optim_scheduler = MultiStepLR(self.critic_optim, milestones=[200000, 400000], gamma=0.2)
44 |
45 | self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
46 | hard_update(self.critic_target, self.critic)
47 |
48 | if self.policy_type == "Gaussian":
49 |             # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
50 | if self.automatic_entropy_tuning is True:
51 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
52 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) # initial alpha = 1.0
53 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
54 | self.alpha_optim_scheduler = MultiStepLR(self.alpha_optim, milestones=[100000, 200000], gamma=0.2)
55 |
56 | self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
57 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
58 | self.policy_optim_scheduler = MultiStepLR(self.policy_optim, milestones=[100000, 200000], gamma=0.2)
59 |
60 | else:
61 | self.alpha = 0
62 | self.automatic_entropy_tuning = False
63 | self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
64 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
65 |
66 | def select_action(self, state, evaluate=False):
67 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
68 | if evaluate is False:
69 | action, _, _ = self.policy.sample(state)
70 | else:
71 | _, _, action = self.policy.sample(state)
72 | return action.detach().cpu().numpy()[0]
73 |
74 | def critic_mix(self, s, a):
75 | target_q1, target_q2 = self.critic_target(s, a)
76 | target_q = torch.min(target_q1, target_q2)
77 | q1, q2 = self.critic(s, a)
78 |         return q1 * 0.05 + target_q * 0.95, q2 * 0.05 + target_q * 0.95  # mix: 5% online critic, 95% target critic
79 |
80 | def update_critic(self, discriminator, states, actions, next_states, rewards, masks, init_states, updates, writer):
81 | init_actions, _, _ = self.policy.sample(init_states)
82 | next_actions, next_log_probs, _ = self.policy.sample(next_states)
83 |
84 | # rewards = torch.clamp(rewards, 0, torch.inf)
85 | rewards = torch.clamp(rewards, epsilon, torch.inf)
86 |
87 | d_sas_rewards = discriminator.predict_reward(states, actions, next_states)
88 |
89 | writer.add_scalar('para/d_sas_reward', torch.mean(d_sas_rewards).item(), updates)
90 |
91 | # compute the reward
92 | # rewards = torch.log(rewards + epsilon * torch.ones(rewards.shape[0]).to(self.device)) - self.tomac_alpha * d_sas_rewards
93 | # rewards = torch.log(rewards) - self.tomac_alpha * d_sas_rewards
94 | rewards = rewards - self.tomac_alpha * d_sas_rewards
95 |
96 | # rewards -= self.tomac_alpha * d_sas_rewards
97 |
98 | with torch.no_grad():
99 | target_q1, target_q2 = self.critic_mix(next_states, next_actions)
100 |
101 | target_q1 = target_q1 - self.alpha * next_log_probs
102 | target_q2 = target_q2 - self.alpha * next_log_probs
103 |
104 | target_q1 = rewards + self.gamma * masks * target_q1
105 | target_q2 = rewards + self.gamma * masks * target_q2
106 |
107 | q1, q2 = self.critic(states, actions)
108 | init_q1, init_q2 = self.critic(init_states, init_actions)
109 |
110 | critic_loss1 = torch.mean(self.f(target_q1 - q1) + (1 - self.gamma) * init_q1 * self.tomac_alpha)
111 | critic_loss2 = torch.mean(self.f(target_q2 - q2) + (1 - self.gamma) * init_q2 * self.tomac_alpha)
112 |
113 | critic_loss = (critic_loss1 + critic_loss2)
114 |
115 | self.critic_optim.zero_grad()
116 | critic_loss.backward()
117 | self.critic_optim.step()
118 |
119 | return critic_loss.item()
120 |
121 | def update_actor(self, discriminator, states, actions, next_states, rewards, masks, init_states):
122 | init_actions, _, _ = self.policy.sample(init_states)
123 | next_actions, next_log_probs, _ = self.policy.sample(next_states)
124 |
125 | # rewards = torch.clamp(rewards, 0, torch.inf)
126 | rewards = torch.clamp(rewards, epsilon, torch.inf)
127 |
128 | d_sas_rewards = discriminator.predict_reward(states, actions, next_states)
129 |
130 | # compute the reward
131 | # rewards = torch.log(rewards + epsilon * torch.ones(rewards.shape[0]).to(self.device)) - self.tomac_alpha * d_sas_rewards
132 | # rewards = torch.log(rewards) - self.tomac_alpha * d_sas_rewards
133 | rewards = rewards - self.tomac_alpha * d_sas_rewards
134 |
135 | # rewards -= self.tomac_alpha * d_sas_rewards
136 |
137 | target_q1, target_q2 = self.critic_mix(next_states, next_actions)
138 |
139 | target_q1 = target_q1 - self.alpha * next_log_probs
140 | target_q2 = target_q2 - self.alpha * next_log_probs
141 |
142 | target_q1 = rewards + self.gamma * masks * target_q1
143 | target_q2 = rewards + self.gamma * masks * target_q2
144 |
145 | q1, q2 = self.critic(states, actions)
146 | init_q1, init_q2 = self.critic(init_states, init_actions)
147 |
148 | actor_loss1 = -torch.mean(self.fgrad(target_q1 - q1).detach() * (target_q1 - q1) + (1 - self.gamma) * init_q1 * self.tomac_alpha)
149 | actor_loss2 = -torch.mean(self.fgrad(target_q2 - q2).detach() * (target_q2 - q2) + (1 - self.gamma) * init_q2 * self.tomac_alpha)
150 |
151 | actor_loss = (actor_loss1 + actor_loss2) / 2.0
152 | # actor_loss += orthogonal_regularization(self.policy)
153 |
154 | self.policy_optim.zero_grad()
155 | actor_loss.backward()
156 | self.policy_optim.step()
157 |
158 | _, log_probs, _ = self.policy.sample(states)
159 |
160 | if self.automatic_entropy_tuning:
161 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
162 |
163 | self.alpha_optim.zero_grad()
164 | alpha_loss.backward()
165 | self.alpha_optim.step()
166 |
167 | # self.alpha_optim_scheduler.step()
168 |
169 | self.alpha = self.log_alpha.exp()
170 | alpha_tlogs = self.alpha.clone() # For TensorboardX logs
171 | else:
172 | alpha_loss = torch.tensor(0.).to(self.device)
173 | alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
174 |
175 | return actor_loss.item(), alpha_loss.item(), alpha_tlogs.item()
176 |
177 | def update_parameters(self, initial_state_memory, memory, discriminator, batch_size, updates, writer):
178 | # Sample a batch from initial_state_memory
179 | initial_state_batch, _, _, _, _ = initial_state_memory.sample(batch_size=batch_size)
180 | initial_state_batch = torch.FloatTensor(initial_state_batch).to(self.device)
181 |
182 | # Sample a batch from memory
183 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
184 |
185 | state_batch = torch.FloatTensor(state_batch).to(self.device)
186 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
187 | action_batch = torch.FloatTensor(action_batch).to(self.device)
188 | reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
189 | mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
190 |
191 | critic_loss = self.update_critic(discriminator, state_batch, action_batch, next_state_batch, reward_batch, mask_batch, initial_state_batch, updates, writer)
192 | # self.critic_optim_scheduler.step()
193 |
194 | if updates % self.target_update_interval == 0:
195 | self.actor_loss, self.alpha_loss, self.alpha_tlogs = self.update_actor(discriminator, state_batch, action_batch, next_state_batch, reward_batch, mask_batch, initial_state_batch)
196 | soft_update(self.critic_target, self.critic, self.tau)
197 | # self.policy_optim_scheduler.step()
198 |
199 | return critic_loss, self.actor_loss, self.alpha_loss, self.alpha_tlogs
200 |
201 | # Save model parameters
202 | def save_checkpoint(self, path, i_episode):
203 | ckpt_path = path + '/' + '{}.torch'.format(i_episode)
204 | print('Saving models to {}'.format(ckpt_path))
205 | torch.save({'policy_state_dict': self.policy.state_dict(),
206 | 'critic_state_dict': self.critic.state_dict(),
207 | 'critic_target_state_dict': self.critic_target.state_dict(),
208 | 'critic_optimizer_state_dict': self.critic_optim.state_dict(),
209 | 'policy_optimizer_state_dict': self.policy_optim.state_dict()}, ckpt_path)
210 |
211 | # Load model parameters
212 | def load_checkpoint(self, path, i_episode, evaluate=False):
213 | ckpt_path = path + '/' + '{}.torch'.format(i_episode)
214 | print('Loading models from {}'.format(ckpt_path))
215 | if ckpt_path is not None:
216 | checkpoint = torch.load(ckpt_path)
217 | self.policy.load_state_dict(checkpoint['policy_state_dict'])
218 | self.critic.load_state_dict(checkpoint['critic_state_dict'])
219 | self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
220 | self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
221 | self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])
222 |
223 | if evaluate:
224 | self.policy.eval()
225 | self.critic.eval()
226 | self.critic_target.eval()
227 | else:
228 | self.policy.train()
229 | self.critic.train()
230 | self.critic_target.train()
231 |
232 |
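233 | # Minimal usage sketch (mirrors how main_non_stationary.py drives this class;
234 | # kept as a comment so importing the module has no side effects):
235 | #
236 | #   agent = OMPO(env.observation_space.shape[0], env.action_space, config)
237 | #   action = agent.select_action(state)                 # stochastic action for data collection
238 | #   action = agent.select_action(state, evaluate=True)  # deterministic action for evaluation
239 | #   critic_loss, actor_loss, alpha_loss, alpha = agent.update_parameters(
240 | #       initial_state_memory, memory, discriminator, config.batch_size, updates, writer)
241 | #   agent.save_checkpoint(checkpoint_path, 'best')
242 | #   agent.load_checkpoint(checkpoint_path, 'best', evaluate=True)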
--------------------------------------------------------------------------------