├── LICENSE ├── README.md ├── bppo.py ├── buffer.py ├── critic.py ├── main.py ├── net.py ├── ppo.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dragon-Zhuang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Behavior Proximal Policy Optimization 2 | 3 | Author's PyTorch implementation of the [ICLR 2023 paper](https://openreview.net/forum?id=3c13LptpIph&referrer=%5Bthe%20profile%20of%20Kun%20LEI%5D(%2Fprofile%3Fid%3D~Kun_LEI1)) **B**ehavior **P**roximal **P**olicy **O**ptimization (BPPO). BPPO uses the loss function from Proximal Policy Optimization (PPO) to improve the behavior policy estimated by behavior cloning. 4 | 5 | ## The difference between BPPO and PPO 6 | 7 | Compared to the loss function of PPO, BPPO does not introduce any extra constraint or regularization. The only difference is the advantage approximation, corresponding to the code difference between `ppo.py` lines 88-89 and `bppo.py` lines 151-155. 8 | 9 | 10 | ## Overview of the Code 11 | The code consists of 7 Python scripts; the file `main.py` contains the various parameter settings, which are explained in our paper. 12 | ### Requirements 13 | - `torch 1.12.0` 14 | - `mujoco 2.2.1` 15 | - `mujoco-py 2.1.2.14` 16 | - `d4rl 1.1` 17 | ### Running the code 18 | - `python main.py`: trains the network, storing checkpoints along the way.
19 | - `Example`: 20 | ```bash 21 | python main.py --env hopper-medium-v2 22 | ``` 23 | ## Citation 24 | If you use BPPO, please cite our paper as follows: 25 | ``` 26 | @article{zhuang2023behavior, 27 | title={Behavior proximal policy optimization}, 28 | author={Zhuang, Zifeng and Lei, Kun and Liu, Jinxin and Wang, Donglin and Guo, Yilang}, 29 | journal={arXiv preprint arXiv:2302.11312}, 30 | year={2023} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /bppo.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import numpy as np 4 | from buffer import OnlineReplayBuffer 5 | from net import GaussPolicyMLP 6 | from critic import ValueLearner, QLearner 7 | from ppo import ProximalPolicyOptimization 8 | from utils import CONST_EPS, log_prob_func, orthogonal_initWeights 9 | 10 | 11 | class BehaviorCloning: 12 | _device: torch.device 13 | _policy: GaussPolicyMLP 14 | _optimizer: torch.optim 15 | _policy_lr: float 16 | _batch_size: int 17 | def __init__( 18 | self, 19 | device: torch.device, 20 | state_dim: int, 21 | hidden_dim: int, 22 | depth: int, 23 | action_dim: int, 24 | policy_lr: float, 25 | batch_size: int 26 | ) -> None: 27 | super().__init__() 28 | self._device = device 29 | self._policy = GaussPolicyMLP(state_dim, hidden_dim, depth, action_dim).to(device) 30 | orthogonal_initWeights(self._policy) 31 | self._optimizer = torch.optim.Adam( 32 | self._policy.parameters(), 33 | lr = policy_lr 34 | ) 35 | self._lr = policy_lr 36 | self._batch_size = batch_size 37 | 38 | 39 | def loss( 40 | self, replay_buffer: OnlineReplayBuffer, 41 | ) -> torch.Tensor: 42 | s, a, _, _, _, _, _, _ = replay_buffer.sample(self._batch_size) 43 | dist = self._policy(s) 44 | log_prob = log_prob_func(dist, a) 45 | loss = (-log_prob).mean() 46 | 47 | return loss 48 | 49 | 50 | def update( 51 | self, replay_buffer: OnlineReplayBuffer, 52 | ) -> float: 53 | policy_loss = self.loss(replay_buffer) 54 | 55 | self._optimizer.zero_grad() 56 | policy_loss.backward() 57 | self._optimizer.step() 58 | 59 | return policy_loss.item() 60 | 61 | 62 | def select_action( 63 | self, s: torch.Tensor, is_sample: bool 64 | ) -> torch.Tensor: 65 | dist = self._policy(s) 66 | if is_sample: 67 | action = dist.sample() 68 | else: 69 | action = dist.mean 70 | # clip 71 | action = action.clamp(-1., 1.) 
72 | return action 73 | 74 | 75 | def offline_evaluate( 76 | self, 77 | env_name: str, 78 | seed: int, 79 | mean: np.ndarray, 80 | std: np.ndarray, 81 | eval_episodes: int=10 82 | ) -> float: 83 | env = gym.make(env_name) 84 | env.seed(seed) 85 | 86 | total_reward = 0 87 | for _ in range(eval_episodes): 88 | s, done = env.reset(), False 89 | while not done: 90 | s = torch.FloatTensor((np.array(s).reshape(1, -1) - mean) / std).to(self._device) 91 | a = self.select_action(s, is_sample=False).cpu().data.numpy().flatten() 92 | s, r, done, _ = env.step(a) 93 | total_reward += r 94 | 95 | avg_reward = total_reward / eval_episodes 96 | d4rl_score = env.get_normalized_score(avg_reward) * 100 97 | return d4rl_score 98 | 99 | 100 | def save( 101 | self, path: str 102 | ) -> None: 103 | torch.save(self._policy.state_dict(), path) 104 | print('Behavior policy parameters saved in {}'.format(path)) 105 | 106 | 107 | def load( 108 | self, path: str 109 | ) -> None: 110 | self._policy.load_state_dict(torch.load(path, map_location=self._device)) 111 | print('Behavior policy parameters loaded') 112 | 113 | 114 | 115 | class BehaviorProximalPolicyOptimization(ProximalPolicyOptimization): 116 | 117 | def __init__( 118 | self, 119 | device: torch.device, 120 | state_dim: int, 121 | hidden_dim: int, 122 | depth: int, 123 | action_dim: int, 124 | policy_lr: float, 125 | clip_ratio: float, 126 | entropy_weight: float, 127 | decay: float, 128 | omega: float, 129 | batch_size: int 130 | ) -> None: 131 | super().__init__( 132 | device = device, 133 | state_dim = state_dim, 134 | hidden_dim = hidden_dim, 135 | depth = depth, 136 | action_dim = action_dim, 137 | policy_lr = policy_lr, 138 | clip_ratio = clip_ratio, 139 | entropy_weight = entropy_weight, 140 | decay = decay, 141 | omega = omega, 142 | batch_size = batch_size) 143 | 144 | 145 | def loss( 146 | self, 147 | replay_buffer: OnlineReplayBuffer, 148 | Q: QLearner, 149 | value: ValueLearner, 150 | is_clip_decay: bool, 151 | ) -> torch.Tensor: 152 | # -------------------------------------Advantage------------------------------------- 153 | s, _, _, _, _, _, _, _ = replay_buffer.sample(self._batch_size) 154 | old_dist = self._old_policy(s) 155 | a = old_dist.rsample() 156 | advantage = Q(s, a) - value(s) 157 | advantage = (advantage - advantage.mean()) / (advantage.std() + CONST_EPS) 158 | # -------------------------------------Advantage------------------------------------- 159 | new_dist = self._policy(s) 160 | 161 | new_log_prob = log_prob_func(new_dist, a) 162 | old_log_prob = log_prob_func(old_dist, a) 163 | ratio = (new_log_prob - old_log_prob).exp() 164 | 165 | advantage = self.weighted_advantage(advantage) 166 | 167 | loss1 = ratio * advantage 168 | 169 | if is_clip_decay: 170 | self._clip_ratio = self._clip_ratio * self._decay 171 | else: 172 | self._clip_ratio = self._clip_ratio 173 | 174 | loss2 = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) * advantage 175 | 176 | entropy_loss = new_dist.entropy().sum(-1, keepdim=True) * self._entropy_weight 177 | 178 | loss = -(torch.min(loss1, loss2) + entropy_loss).mean() 179 | 180 | return loss 181 | 182 | 183 | def offline_evaluate( 184 | self, 185 | env_name: str, 186 | seed: int, 187 | mean: np.ndarray, 188 | std: np.ndarray, 189 | eval_episodes: int=10 190 | ) -> float: 191 | env = gym.make(env_name) 192 | avg_reward = self.evaluate(env_name, seed, mean, std, eval_episodes) 193 | d4rl_score = env.get_normalized_score(avg_reward) * 100 194 | return d4rl_score 195 | 
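To make the README's point concrete, the sketch below contrasts the advantage term used in `BehaviorProximalPolicyOptimization.loss()` above with the one used in `ProximalPolicyOptimization.loss()` (see `ppo.py` below): PPO reuses the GAE advantage precomputed by the buffer together with the dataset action, whereas BPPO draws a fresh action from the old policy and forms Q(s, a) - V(s). This is a minimal, self-contained illustration with stand-in linear networks and random tensors; `V_net`, `Q_net` and the shapes are illustrative only and not part of the repository.

```python
import torch
from torch.distributions import Normal

state_dim, action_dim, batch = 11, 3, 8
V_net = torch.nn.Linear(state_dim, 1)               # stand-in for ValueLearner
Q_net = torch.nn.Linear(state_dim + action_dim, 1)  # stand-in for QLearner

s = torch.randn(batch, state_dim)                   # states sampled from the buffer
gae_advantage = torch.randn(batch, 1)               # placeholder for buffer.compute_advantage() output

# PPO-style (ppo.py): the advantage comes straight from the buffer sample.
adv_ppo = gae_advantage

# BPPO-style (bppo.py): sample an action from the old policy and use Q(s, a) - V(s).
old_dist = Normal(torch.zeros(batch, action_dim), torch.ones(batch, action_dim))
a_pi = old_dist.rsample()
adv_bppo = Q_net(torch.cat([s, a_pi], dim=1)) - V_net(s)
adv_bppo = (adv_bppo - adv_bppo.mean()) / (adv_bppo.std() + 1e-10)  # CONST_EPS in utils.py
```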
-------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from utils import CONST_EPS 6 | 7 | 8 | 9 | class OnlineReplayBuffer: 10 | _device: torch.device 11 | _state: np.ndarray 12 | _action: np.ndarray 13 | _reward: np.ndarray 14 | _next_state: np.ndarray 15 | _next_action: np.ndarray 16 | _not_done: np.ndarray 17 | _return: np.ndarray 18 | _size: int 19 | 20 | 21 | def __init__( 22 | self, 23 | device: torch.device, 24 | state_dim: int, action_dim: int, max_size: int 25 | ) -> None: 26 | 27 | self._device = device 28 | 29 | self._state = np.zeros((max_size, state_dim)) 30 | self._action = np.zeros((max_size, action_dim)) 31 | self._reward = np.zeros((max_size, 1)) 32 | self._next_state = np.zeros((max_size, state_dim)) 33 | self._next_action = np.zeros((max_size, action_dim)) 34 | self._not_done = np.zeros((max_size, 1)) 35 | self._return = np.zeros((max_size, 1)) 36 | self._advantage = np.zeros((max_size, 1)) 37 | 38 | self._size = 0 39 | 40 | 41 | def store( 42 | self, 43 | s: np.ndarray, 44 | a: np.ndarray, 45 | r: np.ndarray, 46 | s_p: np.ndarray, 47 | a_p: np.ndarray, 48 | not_done: bool 49 | ) -> None: 50 | 51 | self._state[self._size] = s 52 | self._action[self._size] = a 53 | self._reward[self._size] = r 54 | self._next_state[self._size] = s_p 55 | self._next_action[self._size] = a_p 56 | self._not_done[self._size] = not_done 57 | self._size += 1 58 | 59 | 60 | def compute_return( 61 | self, gamma: float 62 | ) -> None: 63 | 64 | pre_return = 0 65 | for i in tqdm(reversed(range(self._size)), desc='Computing the returns'): 66 | self._return[i] = self._reward[i] + gamma * pre_return * self._not_done[i] 67 | pre_return = self._return[i] 68 | 69 | 70 | def compute_advantage( 71 | self, gamma:float, lamda: float, value 72 | ) -> None: 73 | delta = np.zeros_like(self._reward) 74 | 75 | pre_value = 0 76 | pre_advantage = 0 77 | 78 | for i in tqdm(reversed(range(self._size)), 'Computing the advantage'): 79 | current_state = torch.FloatTensor(self._state[i]).to(self._device) 80 | current_value = value(current_state).cpu().data.numpy().flatten() 81 | 82 | delta[i] = self._reward[i] + gamma * pre_value * self._not_done[i] - current_value 83 | self._advantage[i] = delta[i] + gamma * lamda * pre_advantage * self._not_done[i] 84 | 85 | pre_value = current_value 86 | pre_advantage = self._advantage[i] 87 | 88 | self._advantage = (self._advantage - self._advantage.mean()) / (self._advantage.std() + CONST_EPS) 89 | 90 | 91 | def sample( 92 | self, batch_size: int 93 | ) -> tuple: 94 | 95 | ind = np.random.randint(0, self._size, size=batch_size) 96 | 97 | return ( 98 | torch.FloatTensor(self._state[ind]).to(self._device), 99 | torch.FloatTensor(self._action[ind]).to(self._device), 100 | torch.FloatTensor(self._reward[ind]).to(self._device), 101 | torch.FloatTensor(self._next_state[ind]).to(self._device), 102 | torch.FloatTensor(self._next_action[ind]).to(self._device), 103 | torch.FloatTensor(self._not_done[ind]).to(self._device), 104 | torch.FloatTensor(self._return[ind]).to(self._device), 105 | torch.FloatTensor(self._advantage[ind]).to(self._device) 106 | ) 107 | 108 | 109 | 110 | class OfflineReplayBuffer(OnlineReplayBuffer): 111 | 112 | def __init__( 113 | self, device: torch.device, 114 | state_dim: int, action_dim: int, max_size: int 115 | ) -> None: 116 | super().__init__(device, state_dim, 
action_dim, max_size) 117 | 118 | 119 | def load_dataset( 120 | self, dataset: dict 121 | ) -> None: 122 | 123 | self._state = dataset['observations'][:-1, :] 124 | self._action = dataset['actions'][:-1, :] 125 | self._reward = dataset['rewards'].reshape(-1, 1)[:-1, :] 126 | self._next_state = dataset['observations'][1:, :] 127 | self._next_action = dataset['actions'][1:, :] 128 | self._not_done = 1. - (dataset['terminals'].reshape(-1, 1)[:-1, :] | dataset['timeouts'].reshape(-1, 1)[:-1, :]) 129 | 130 | self._size = len(dataset['actions']) - 1 131 | 132 | 133 | def normalize_state( 134 | self 135 | ) -> tuple: 136 | 137 | mean = self._state.mean(0, keepdims=True) 138 | std = self._state.std(0, keepdims=True) + CONST_EPS 139 | self._state = (self._state - mean) / std 140 | self._next_state = (self._next_state - mean) / std 141 | return (mean, std) 142 | -------------------------------------------------------------------------------- /critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from net import ValueMLP, QMLP 5 | from buffer import OnlineReplayBuffer 6 | 7 | 8 | class ValueLearner: 9 | _device: torch.device 10 | _value: ValueMLP 11 | _optimizer: torch.optim 12 | _batch_size: int 13 | 14 | def __init__( 15 | self, 16 | device: torch.device, 17 | state_dim: int, 18 | hidden_dim: int, 19 | depth: int, 20 | value_lr: float, 21 | batch_size: int 22 | ) -> None: 23 | super().__init__() 24 | self._device = device 25 | self._value = ValueMLP(state_dim, hidden_dim, depth).to(device) 26 | self._optimizer = torch.optim.Adam( 27 | self._value.parameters(), 28 | lr=value_lr, 29 | ) 30 | self._batch_size = batch_size 31 | 32 | 33 | def __call__( 34 | self, s: torch.Tensor 35 | ) -> torch.Tensor: 36 | return self._value(s) 37 | 38 | 39 | def update( 40 | self, replay_buffer: OnlineReplayBuffer 41 | ) -> float: 42 | s, _, _, _, _, _, Return, _ = replay_buffer.sample(self._batch_size) 43 | value_loss = F.mse_loss(self._value(s), Return) 44 | 45 | self._optimizer.zero_grad() 46 | value_loss.backward() 47 | self._optimizer.step() 48 | 49 | return value_loss.item() 50 | 51 | 52 | def save( 53 | self, path: str 54 | ) -> None: 55 | torch.save(self._value.state_dict(), path) 56 | print('Value parameters saved in {}'.format(path)) 57 | 58 | 59 | def load( 60 | self, path: str 61 | ) -> None: 62 | self._value.load_state_dict(torch.load(path, map_location=self._device)) 63 | print('Value parameters loaded') 64 | 65 | 66 | 67 | class QLearner: 68 | _device: torch.device 69 | _Q: QMLP 70 | _optimizer: torch.optim 71 | _target_Q: QMLP 72 | _total_update_step: int 73 | _target_update_freq: int 74 | _tau: float 75 | _gamma: float 76 | _batch_size: int 77 | 78 | def __init__( 79 | self, 80 | device: torch.device, 81 | state_dim: int, 82 | action_dim: int, 83 | hidden_dim: int, 84 | depth: int, 85 | Q_lr: float, 86 | target_update_freq: int, 87 | tau: float, 88 | gamma: float, 89 | batch_size: int 90 | ) -> None: 91 | super().__init__() 92 | self._device = device 93 | self._Q = QMLP(state_dim, action_dim, hidden_dim, depth).to(device) 94 | self._optimizer = torch.optim.Adam( 95 | self._Q.parameters(), 96 | lr=Q_lr, 97 | ) 98 | 99 | self._target_Q = QMLP(state_dim, action_dim, hidden_dim, depth).to(device) 100 | self._target_Q.load_state_dict(self._Q.state_dict()) 101 | self._total_update_step = 0 102 | self._target_update_freq = target_update_freq 103 | self._tau = tau 104 | 105 | self._gamma = gamma 106 | 
self._batch_size = batch_size 107 | 108 | 109 | def __call__( 110 | self, s: torch.Tensor, a: torch.Tensor 111 | ) -> torch.Tensor: 112 | return self._Q(s, a) 113 | 114 | 115 | def loss( 116 | self, replay_buffer: OnlineReplayBuffer, pi 117 | ) -> torch.Tensor: 118 | raise NotImplementedError 119 | 120 | 121 | def update( 122 | self, replay_buffer: OnlineReplayBuffer, pi 123 | ) -> float: 124 | Q_loss = self.loss(replay_buffer, pi) 125 | self._optimizer.zero_grad() 126 | Q_loss.backward() 127 | self._optimizer.step() 128 | 129 | self._total_update_step += 1 130 | if self._total_update_step % self._target_update_freq == 0: 131 | for param, target_param in zip(self._Q.parameters(), self._target_Q.parameters()): 132 | target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data) 133 | 134 | return Q_loss.item() 135 | 136 | 137 | def save( 138 | self, path: str 139 | ) -> None: 140 | torch.save(self._Q.state_dict(), path) 141 | print('Q function parameters saved in {}'.format(path)) 142 | 143 | 144 | def load( 145 | self, path: str 146 | ) -> None: 147 | self._Q.load_state_dict(torch.load(path, map_location=self._device)) 148 | self._target_Q.load_state_dict(self._Q.state_dict()) 149 | print('Q function parameters loaded') 150 | 151 | 152 | 153 | class QSarsaLearner(QLearner): 154 | def __init__( 155 | self, 156 | device: torch.device, 157 | state_dim: int, 158 | action_dim: int, 159 | hidden_dim: int, 160 | depth: int, 161 | Q_lr: float, 162 | target_update_freq: int, 163 | tau: float, 164 | gamma: float, 165 | batch_size: int 166 | ) -> None: 167 | super().__init__( 168 | device = device, 169 | state_dim = state_dim, 170 | action_dim = action_dim, 171 | hidden_dim = hidden_dim, 172 | depth = depth, 173 | Q_lr = Q_lr, 174 | target_update_freq = target_update_freq, 175 | tau = tau, 176 | gamma = gamma, 177 | batch_size = batch_size 178 | ) 179 | 180 | 181 | def loss( 182 | self, replay_buffer: OnlineReplayBuffer, pi 183 | ) -> torch.Tensor: 184 | s, a, r, s_p, a_p, not_done, _, _ = replay_buffer.sample(self._batch_size) 185 | with torch.no_grad(): 186 | target_Q_value = r + not_done * self._gamma * self._target_Q(s_p, a_p) 187 | 188 | Q = self._Q(s, a) 189 | loss = F.mse_loss(Q, target_Q_value) 190 | 191 | return loss 192 | 193 | 194 | 195 | class QPiLearner(QLearner): 196 | def __init__( 197 | self, 198 | device: torch.device, 199 | state_dim: int, 200 | action_dim: int, 201 | hidden_dim: int, 202 | depth: int, 203 | Q_lr: float, 204 | target_update_freq: int, 205 | tau: float, 206 | gamma: float, 207 | batch_size: int 208 | ) -> None: 209 | super().__init__( 210 | device = device, 211 | state_dim = state_dim, 212 | action_dim = action_dim, 213 | hidden_dim = hidden_dim, 214 | depth = depth, 215 | Q_lr = Q_lr, 216 | target_update_freq = target_update_freq, 217 | tau = tau, 218 | gamma = gamma, 219 | batch_size = batch_size 220 | ) 221 | 222 | 223 | def loss( 224 | self, replay_buffer: OnlineReplayBuffer, pi 225 | ) -> torch.Tensor: 226 | s, a, r, s_p, _, not_done, _, _ = replay_buffer.sample(self._batch_size) 227 | a_p = pi.select_action(s_p, is_sample=True) 228 | with torch.no_grad(): 229 | target_Q_value = r + not_done * self._gamma * self._target_Q(s_p, a_p) 230 | 231 | Q = self._Q(s, a) 232 | loss = F.mse_loss(Q, target_Q_value) 233 | 234 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import d4rl 3 | import 
torch 4 | import numpy as np 5 | import os 6 | import time 7 | from tqdm import tqdm 8 | import argparse 9 | from tensorboardX import SummaryWriter 10 | 11 | from buffer import OfflineReplayBuffer 12 | from critic import ValueLearner, QPiLearner, QSarsaLearner 13 | from bppo import BehaviorCloning, BehaviorProximalPolicyOptimization 14 | 15 | #===========================================================Welcome to use BPPO================================================================== 16 | #Tips 17 | #for hopper-medium-v2 and walker2d-medium-replay-v2, use 5e-4/2e-5/2e-5 for bc/q/v; use 5e-5/2e-6/2e-6 for the others (see the scale of each dataset in d4rl). 18 | #for hopper-medium-v2, do not use state normalization (state normalization is a trick in PPO). 19 | #===========================================================Welcome to use BPPO================================================================== 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | # Experiment 23 | parser.add_argument("--env", default="hopper-medium-v2") 24 | parser.add_argument("--seed", default=8, type=int) 25 | parser.add_argument("--gpu", default=0, type=int) 26 | parser.add_argument("--log_freq", default=int(2e3), type=int) 27 | parser.add_argument("--path", default="logs", type=str) 28 | # For Value 29 | parser.add_argument("--v_steps", default=int(2e6), type=int) 30 | parser.add_argument("--v_hidden_dim", default=512, type=int) 31 | parser.add_argument("--v_depth", default=3, type=int) 32 | parser.add_argument("--v_lr", default=1e-4, type=float) 33 | parser.add_argument("--v_batch_size", default=512, type=int) 34 | # For Q 35 | parser.add_argument("--q_bc_steps", default=int(2e6), type=int) 36 | parser.add_argument("--q_pi_steps", default=10, type=int) 37 | parser.add_argument("--q_hidden_dim", default=1024, type=int) 38 | parser.add_argument("--q_depth", default=2, type=int) 39 | parser.add_argument("--q_lr", default=1e-4, type=float) 40 | parser.add_argument("--q_batch_size", default=512, type=int) 41 | parser.add_argument("--target_update_freq", default=2, type=int) 42 | parser.add_argument("--tau", default=0.005, type=float) 43 | parser.add_argument("--gamma", default=0.99, type=float) 44 | parser.add_argument("--is_offpolicy_update", default=False, type=bool) # note: with type=bool, any non-empty command-line value parses as True 45 | # For BehaviorCloning 46 | parser.add_argument("--bc_steps", default=int(5e5), type=int) # try reducing the bc/q/v steps if it works poorly, e.g. 5e-4/2e-5/2e-5 for bc/q/v 47 | parser.add_argument("--bc_hidden_dim", default=1024, type=int) 48 | parser.add_argument("--bc_depth", default=2, type=int) 49 | parser.add_argument("--bc_lr", default=1e-4, type=float) 50 | parser.add_argument("--bc_batch_size", default=512, type=int) 51 | # For BPPO 52 | parser.add_argument("--bppo_steps", default=int(1e3), type=int) 53 | parser.add_argument("--bppo_hidden_dim", default=1024, type=int) 54 | parser.add_argument("--bppo_depth", default=2, type=int) 55 | parser.add_argument("--bppo_lr", default=1e-4, type=float) 56 | parser.add_argument("--bppo_batch_size", default=512, type=int) 57 | parser.add_argument("--clip_ratio", default=0.25, type=float) 58 | parser.add_argument("--entropy_weight", default=0.0, type=float) # for ()-medium-() tasks, try using the entropy loss with weight == 0.01 59 | parser.add_argument("--decay", default=0.96, type=float) 60 | parser.add_argument("--omega", default=0.9, type=float) 61 | parser.add_argument("--is_clip_decay", default=True, type=bool) 62 | parser.add_argument("--is_bppo_lr_decay", default=True,
type=bool) 63 | parser.add_argument("--is_update_old_policy", default=True, type=bool) 64 | parser.add_argument("--is_state_norm", default=False, type=bool) 65 | 66 | args = parser.parse_args() 67 | print(f'------current env {args.env} and current seed {args.seed}------') 68 | # path 69 | current_time = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime()) 70 | path = os.path.join(args.path, args.env, str(args.seed)) 71 | os.makedirs(os.path.join(path, current_time)) 72 | # save args 73 | config_path = os.path.join(path, current_time, 'config.txt') 74 | config = vars(args) 75 | with open(config_path, 'w') as f: 76 | for k, v in config.items(): 77 | f.writelines(f"{k:20} : {v} \n") 78 | 79 | 80 | env = gym.make(args.env) 81 | # seed 82 | env.seed(args.seed) 83 | env.action_space.seed(args.seed) 84 | torch.manual_seed(args.seed) 85 | torch.cuda.manual_seed(args.seed) 86 | np.random.seed(args.seed) 87 | # dim of state and action 88 | state_dim = env.observation_space.shape[0] 89 | action_dim = env.action_space.shape[0] 90 | # device 91 | device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") 92 | 93 | 94 | # offline dataset to replay buffer 95 | dataset = env.get_dataset() 96 | replay_buffer = OfflineReplayBuffer(device, state_dim, action_dim, len(dataset['actions'])) 97 | replay_buffer.load_dataset(dataset=dataset) 98 | replay_buffer.compute_return(args.gamma) 99 | 100 | # for the hopper-medium-v2 task, do not use state normalization 101 | if args.is_state_norm: 102 | mean, std = replay_buffer.normalize_state() 103 | else: 104 | mean, std = 0., 1. 105 | 106 | # SummaryWriter logger 107 | comment = args.env + '_' + str(args.seed) 108 | logger_path = os.path.join(path, current_time) 109 | logger = SummaryWriter(log_dir=logger_path, comment=comment) 110 | 111 | 112 | # initialize 113 | value = ValueLearner(device, state_dim, args.v_hidden_dim, args.v_depth, args.v_lr, args.v_batch_size) 114 | Q_bc = QSarsaLearner(device, state_dim, action_dim, args.q_hidden_dim, args.q_depth, args.q_lr, args.target_update_freq, args.tau, args.gamma, args.q_batch_size) 115 | if args.is_offpolicy_update: 116 | Q_pi = QPiLearner(device, state_dim, action_dim, args.q_hidden_dim, args.q_depth, args.q_lr, args.target_update_freq, args.tau, args.gamma, args.q_batch_size) 117 | bc = BehaviorCloning(device, state_dim, args.bc_hidden_dim, args.bc_depth, action_dim, args.bc_lr, args.bc_batch_size) 118 | bppo = BehaviorProximalPolicyOptimization(device, state_dim, args.bppo_hidden_dim, args.bppo_depth, action_dim, args.bppo_lr, args.clip_ratio, args.entropy_weight, args.decay, args.omega, args.bppo_batch_size) 119 | 120 | 121 | # value training 122 | value_path = os.path.join(path, 'value.pt') 123 | if os.path.exists(value_path): 124 | value.load(value_path) 125 | else: 126 | for step in tqdm(range(int(args.v_steps)), desc='value updating ......'): 127 | value_loss = value.update(replay_buffer) 128 | 129 | if step % int(args.log_freq) == 0: 130 | print(f"Step: {step}, Loss: {value_loss:.4f}") 131 | logger.add_scalar('value_loss', value_loss, global_step=(step+1)) 132 | 133 | value.save(value_path) 134 | 135 | # Q_bc training 136 | Q_bc_path = os.path.join(path, 'Q_bc.pt') 137 | if os.path.exists(Q_bc_path): 138 | Q_bc.load(Q_bc_path) 139 | else: 140 | for step in tqdm(range(int(args.q_bc_steps)), desc='Q_bc updating ......'): 141 | Q_bc_loss = Q_bc.update(replay_buffer, pi=None) 142 | 143 | if step % int(args.log_freq) == 0: 144 | print(f"Step: {step}, Loss: {Q_bc_loss:.4f}") 145 |
logger.add_scalar('Q_bc_loss', Q_bc_loss, global_step=(step+1)) 146 | 147 | Q_bc.save(Q_bc_path) 148 | 149 | if args.is_offpolicy_update: 150 | Q_pi.load(Q_bc_path) 151 | 152 | # bc training 153 | best_bc_path = os.path.join(path, 'bc_best.pt') 154 | if os.path.exists(best_bc_path): 155 | bc.load(best_bc_path) 156 | else: 157 | best_bc_score = 0 158 | for step in tqdm(range(int(args.bc_steps)), desc='bc updating ......'): 159 | bc_loss = bc.update(replay_buffer) 160 | 161 | if step % int(args.log_freq) == 0: 162 | current_bc_score = bc.offline_evaluate(args.env, args.seed, mean, std) 163 | if current_bc_score > best_bc_score: 164 | best_bc_score = current_bc_score 165 | bc.save(best_bc_path) 166 | np.savetxt(os.path.join(path, 'best_bc.csv'), [best_bc_score], fmt='%f', delimiter=',') 167 | print(f"Step: {step}, Loss: {bc_loss:.4f}, Score: {current_bc_score:.4f}") 168 | logger.add_scalar('bc_loss', bc_loss, global_step=(step+1)) 169 | logger.add_scalar('bc_score', current_bc_score, global_step=(step+1)) 170 | 171 | bc.save(os.path.join(path, 'bc_last.pt')) 172 | bc.load(best_bc_path) 173 | 174 | 175 | # bppo training 176 | bppo.load(best_bc_path) 177 | best_bppo_path = os.path.join(path, current_time, 'bppo_best.pt') 178 | Q = Q_bc 179 | 180 | best_bppo_score = bppo.offline_evaluate(args.env, args.seed, mean, std) 181 | print('best_bppo_score:',best_bppo_score,'-------------------------') 182 | 183 | for step in tqdm(range(int(args.bppo_steps)), desc='bppo updating ......'): 184 | if step > 200: 185 | args.is_clip_decay = False 186 | args.is_bppo_lr_decay = False 187 | bppo_loss = bppo.update(replay_buffer, Q, value, args.is_clip_decay, args.is_bppo_lr_decay) 188 | current_bppo_score = bppo.offline_evaluate(args.env, args.seed, mean, std) 189 | 190 | if current_bppo_score > best_bppo_score: 191 | best_bppo_score = current_bppo_score 192 | print('best_bppo_score:',best_bppo_score,'-------------------------') 193 | bppo.save(best_bppo_path) 194 | np.savetxt(os.path.join(path, current_time, 'best_bppo.csv'), [best_bppo_score], fmt='%f', delimiter=',') 195 | 196 | if args.is_update_old_policy: 197 | bppo.set_old_policy() 198 | 199 | if args.is_offpolicy_update: 200 | for _ in tqdm(range(int(args.q_pi_steps)), desc='Q_pi updating ......'): 201 | Q_pi_loss = Q_pi.update(replay_buffer, bppo) 202 | 203 | Q = Q_pi 204 | 205 | print(f"Step: {step}, Loss: {bppo_loss:.4f}, Score: {current_bppo_score:.4f}") 206 | logger.add_scalar('bppo_loss', bppo_loss, global_step=(step+1)) 207 | logger.add_scalar('bppo_score', current_bppo_score, global_step=(step+1)) 208 | 209 | logger.close() 210 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Normal 4 | from typing import Tuple 5 | 6 | 7 | def soft_clamp( 8 | x: torch.Tensor, bound: tuple 9 | ) -> torch.Tensor: 10 | low, high = bound 11 | #x = torch.tanh(x) 12 | x = low + 0.5 * (high - low) * (x + 1) 13 | return x 14 | 15 | 16 | def MLP( 17 | input_dim: int, 18 | hidden_dim: int, 19 | depth: int, 20 | output_dim: int, 21 | final_activation: str 22 | ) -> torch.nn.modules.container.Sequential: 23 | 24 | layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()] 25 | for _ in range(depth -1): 26 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 27 | layers.append(nn.ReLU()) 28 | layers.append(nn.Linear(hidden_dim, output_dim)) 29 | if final_activation == 'relu': 30 | 
layers.append(nn.ReLU()) 31 | elif final_activation == 'tanh': 32 | layers.append(nn.Tanh()) 33 | 34 | return nn.Sequential(*layers) 35 | 36 | 37 | 38 | class ValueMLP(nn.Module): 39 | _net: torch.nn.modules.container.Sequential 40 | 41 | def __init__( 42 | self, state_dim: int, hidden_dim: int, depth: int 43 | ) -> None: 44 | super().__init__() 45 | self._net = MLP(state_dim, hidden_dim, depth, 1, 'relu') 46 | 47 | def forward( 48 | self, s: torch.Tensor 49 | ) -> torch.Tensor: 50 | return self._net(s) 51 | 52 | 53 | 54 | class QMLP(nn.Module): 55 | _net: torch.nn.modules.container.Sequential 56 | 57 | def __init__( 58 | self, 59 | state_dim: int, action_dim: int, hidden_dim: int, depth:int 60 | ) -> None: 61 | super().__init__() 62 | self._net = MLP((state_dim + action_dim), hidden_dim, depth, 1, 'relu') 63 | 64 | def forward( 65 | self, s: torch.Tensor, a: torch.Tensor 66 | ) -> Tuple[torch.Tensor, torch.Tensor]: 67 | sa = torch.cat([s, a], dim=1) 68 | return self._net(sa) 69 | 70 | 71 | 72 | class GaussPolicyMLP(nn.Module): 73 | _net: torch.nn.modules.container.Sequential 74 | _log_std_bound: tuple 75 | 76 | def __init__( 77 | self, 78 | state_dim: int, hidden_dim: int, depth: int, action_dim: int, 79 | ) -> None: 80 | super().__init__() 81 | self._net = MLP(state_dim, hidden_dim, depth, (2 * action_dim), 'tanh') 82 | self._log_std_bound = (-5., 0.) 83 | 84 | 85 | def forward( 86 | self, s: torch.Tensor 87 | ) -> torch.distributions: 88 | 89 | mu, log_std = self._net(s).chunk(2, dim=-1) 90 | log_std = soft_clamp(log_std, self._log_std_bound) 91 | std = log_std.exp() 92 | 93 | dist = Normal(mu, std) 94 | return dist 95 | -------------------------------------------------------------------------------- /ppo.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from buffer import OnlineReplayBuffer 7 | from net import GaussPolicyMLP 8 | from critic import ValueLearner, QLearner 9 | from utils import orthogonal_initWeights, log_prob_func 10 | 11 | 12 | 13 | class ProximalPolicyOptimization: 14 | _device: torch.device 15 | _policy: GaussPolicyMLP 16 | _optimizer: torch.optim 17 | _policy_lr: float 18 | _old_policy: GaussPolicyMLP 19 | _scheduler: torch.optim 20 | _clip_ratio: float 21 | _entropy_weight: float 22 | _decay: float 23 | _omega: float 24 | _batch_size: int 25 | 26 | 27 | def __init__( 28 | self, 29 | device: torch.device, 30 | state_dim: int, 31 | hidden_dim: int, 32 | depth: int, 33 | action_dim: int, 34 | policy_lr: float, 35 | clip_ratio: float, 36 | entropy_weight: float, 37 | decay: float, 38 | omega: float, 39 | batch_size: int 40 | ) -> None: 41 | super().__init__() 42 | self._device = device 43 | self._policy = GaussPolicyMLP(state_dim, hidden_dim, depth, action_dim).to(device) 44 | orthogonal_initWeights(self._policy) 45 | self._optimizer = torch.optim.Adam( 46 | self._policy.parameters(), 47 | lr=policy_lr 48 | ) 49 | self._policy_lr = policy_lr 50 | self._old_policy = deepcopy(self._policy) 51 | self._scheduler = torch.optim.lr_scheduler.StepLR( 52 | self._optimizer, 53 | step_size=2, 54 | gamma=0.98 55 | ) 56 | 57 | self._clip_ratio = clip_ratio 58 | self._entropy_weight = entropy_weight 59 | self._decay = decay 60 | self._omega = omega 61 | self._batch_size = batch_size 62 | 63 | 64 | def weighted_advantage( 65 | self, 66 | advantage: torch.Tensor 67 | ) -> torch.Tensor: 68 | if self._omega == 0.5: 69 | return advantage 70 | else: 71 | weight = 
torch.zeros_like(advantage) 72 | index = torch.where(advantage > 0)[0] 73 | weight[index] = self._omega 74 | weight[torch.where(weight == 0)[0]] = 1 - self._omega 75 | weight.to(self._device) 76 | return weight * advantage 77 | 78 | 79 | def loss( 80 | self, 81 | replay_buffer: OnlineReplayBuffer, 82 | Q: QLearner, 83 | value: ValueLearner, 84 | is_clip_decay: bool, 85 | ) -> torch.Tensor: 86 | # -------------------------------------Advantage------------------------------------- 87 | s, a, _, _, _, _, _, advantage = replay_buffer.sample(self._batch_size) 88 | old_dist = self._old_policy(s) 89 | # -------------------------------------Advantage------------------------------------- 90 | new_dist = self._policy(s) 91 | 92 | new_log_prob = log_prob_func(new_dist, a) 93 | old_log_prob = log_prob_func(old_dist, a) 94 | ratio = (new_log_prob - old_log_prob).exp() 95 | 96 | advantage = self.weighted_advantage(advantage) 97 | 98 | loss1 = ratio * advantage 99 | 100 | if is_clip_decay: 101 | self._clip_ratio = self._clip_ratio * self._decay 102 | else: 103 | self._clip_ratio = self._clip_ratio 104 | 105 | loss2 = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) * advantage 106 | 107 | entropy_loss = new_dist.entropy().sum(-1, keepdim=True) * self._entropy_weight 108 | 109 | loss = -(torch.min(loss1, loss2) + entropy_loss).mean() 110 | 111 | return loss 112 | 113 | 114 | def update( 115 | self, 116 | replay_buffer: OnlineReplayBuffer, 117 | Q: QLearner, 118 | value: ValueLearner, 119 | is_clip_decay: bool, 120 | is_lr_decay: bool 121 | ) -> float: 122 | policy_loss = self.loss(replay_buffer, Q, value, is_clip_decay) 123 | 124 | self._optimizer.zero_grad() 125 | policy_loss.backward() 126 | torch.nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5) 127 | self._optimizer.step() 128 | 129 | if is_lr_decay: 130 | self._scheduler.step() 131 | return policy_loss.item() 132 | 133 | 134 | def select_action( 135 | self, s: torch.Tensor, is_sample: bool 136 | ) -> torch.Tensor: 137 | dist = self._policy(s) 138 | if is_sample: 139 | action = dist.sample() 140 | else: 141 | action = dist.mean 142 | # clip 143 | action = action.clamp(-1., 1.) 
144 | return action 145 | 146 | 147 | def evaluate( 148 | self, 149 | env_name: str, 150 | seed: int, 151 | mean: np.ndarray, 152 | std: np.ndarray, 153 | eval_episodes: int=10 154 | ) -> float: 155 | env = gym.make(env_name) 156 | env.seed(seed) 157 | 158 | total_reward = 0 159 | for _ in range(eval_episodes): 160 | s, done = env.reset(), False 161 | while not done: 162 | s = torch.FloatTensor((np.array(s).reshape(1, -1) - mean) / std).to(self._device) 163 | a = self.select_action(s, is_sample=False).cpu().data.numpy().flatten() 164 | s, r, done, _ = env.step(a) 165 | total_reward += r 166 | 167 | avg_reward = total_reward / eval_episodes 168 | return avg_reward 169 | 170 | 171 | def save( 172 | self, path: str 173 | ) -> None: 174 | torch.save(self._policy.state_dict(), path) 175 | print('Policy parameters saved in {}'.format(path)) 176 | 177 | 178 | def load( 179 | self, path: str 180 | ) -> None: 181 | self._policy.load_state_dict(torch.load(path, map_location=self._device)) 182 | self._old_policy.load_state_dict(self._policy.state_dict()) 183 | #self._optimizer = torch.optim.Adam(self._policy.parameters(), lr=self._policy_lr) 184 | print('Policy parameters loaded') 185 | 186 | def set_old_policy( 187 | self, 188 | ) -> None: 189 | self._old_policy.load_state_dict(self._policy.state_dict()) 190 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Distribution 4 | 5 | CONST_EPS = 1e-10 6 | 7 | 8 | def orthogonal_initWeights( 9 | net: nn.Module, 10 | ) -> None: 11 | for e in net.parameters(): 12 | if len(e.size()) >= 2: 13 | nn.init.orthogonal_(e) 14 | 15 | 16 | def log_prob_func( 17 | dist: Distribution, action: torch.Tensor 18 | ) -> torch.Tensor: 19 | log_prob = dist.log_prob(action) 20 | if len(log_prob.shape) == 1: 21 | return log_prob 22 | else: 23 | return log_prob.sum(-1, keepdim=True) --------------------------------------------------------------------------------
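To complement the training command in the README, here is a minimal evaluation sketch built only from classes defined in this repository. The checkpoint path is a placeholder (`main.py` writes `bppo_best.pt` under `logs/<env>/<seed>/<timestamp>/`), the hyperparameters mirror the `main.py` defaults, and `mean, std = 0., 1.` matches the setting without state normalization; adjust all of these to your own run.

```python
import gym
import torch
import d4rl  # noqa: F401  (registers the D4RL offline environments)

from bppo import BehaviorProximalPolicyOptimization

env_name, seed = 'hopper-medium-v2', 8
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Hyperparameters below mirror the main.py defaults.
bppo = BehaviorProximalPolicyOptimization(
    device, state_dim, hidden_dim=1024, depth=2, action_dim=action_dim,
    policy_lr=1e-4, clip_ratio=0.25, entropy_weight=0.0,
    decay=0.96, omega=0.9, batch_size=512)

# Placeholder path: point this at a checkpoint produced by your own run.
bppo.load('logs/hopper-medium-v2/8/<timestamp>/bppo_best.pt')

# main.py passes mean, std = 0., 1. when --is_state_norm is left at False.
score = bppo.offline_evaluate(env_name, seed, mean=0., std=1.)
print(f'normalized D4RL score: {score:.1f}')
```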