├── LICENSE ├── Old_Version-PPO-Continuous-with-Gym0.18.3.zip ├── PPO.py ├── README.md ├── main.py ├── model ├── PV1_actor100.pth └── PV1_q_critic100.pth ├── ppo_result.jpg ├── render_gif ├── PV1.gif ├── lldcV2.gif ├── lldc_ppoc.png └── pv1_ppoc.png ├── runs ├── BWv3 2021-11-04 20_14 │ └── events.out.tfevents.1636028042.localhost.localdomain.31736.0 ├── HCv2 2021-11-04 19_33 │ └── events.out.tfevents.1636025607.localhost.localdomain.23247.0 ├── HCv4 2023-11-17 22_42 │ └── events.out.tfevents.1700232152.kim.6672.0 ├── HCv4 2023-11-18 01_58 │ └── events.out.tfevents.1700243918.kim.7102.0 ├── HCv4 2023-11-18 05_14 │ └── events.out.tfevents.1700255698.kim.7476.0 ├── HCv4 2023-11-18 08_30 │ └── events.out.tfevents.1700267414.kim.8013.0 ├── Humanv2 2021-11-04 19_33 │ └── events.out.tfevents.1636025604.localhost.localdomain.23206.0 ├── Humanv4 2023-11-17 20_13 │ └── events.out.tfevents.1700223188.kim.6379.0 ├── Humanv4 2023-11-17 23_37 │ └── events.out.tfevents.1700235436.kim.6843.0 ├── Humanv4 2023-11-18 02_53 │ └── events.out.tfevents.1700247217.kim.7267.0 ├── Humanv4 2023-11-18 06_09 │ └── events.out.tfevents.1700258999.kim.7648.0 ├── Humanv4 2023-11-18 09_25 │ └── events.out.tfevents.1700270700.kim.8175.0 ├── Lch_Cv2 2021-11-04 19_32 │ └── events.out.tfevents.1636025576.localhost.localdomain.23162.0 ├── Lch_Cv2 2021-11-04 19_48 │ └── events.out.tfevents.1636026535.localhost.localdomain.26058.0 └── PV0 2021-11-04 19_33 │ └── events.out.tfevents.1636025600.localhost.localdomain.23193.0 └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 XinJingHao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Old_Version-PPO-Continuous-with-Gym0.18.3.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/Old_Version-PPO-Continuous-with-Gym0.18.3.zip -------------------------------------------------------------------------------- /PPO.py: -------------------------------------------------------------------------------- 1 | from utils import BetaActor, GaussianActor_musigma, GaussianActor_mu, Critic 2 | import numpy as np 3 | import copy 4 | import torch 5 | import math 6 | 7 | 8 | class PPO_agent(object): 9 | def __init__(self, **kwargs): 10 | # Init hyperparameters for the PPO agent, e.g. "self.gamma = opt.gamma, self.lambd = opt.lambd, ..." 11 | self.__dict__.update(kwargs) 12 | 13 | # Choose distribution for the actor 14 | if self.Distribution == 'Beta': 15 | self.actor = BetaActor(self.state_dim, self.action_dim, self.net_width).to(self.dvc) 16 | elif self.Distribution == 'GS_ms': 17 | self.actor = GaussianActor_musigma(self.state_dim, self.action_dim, self.net_width).to(self.dvc) 18 | elif self.Distribution == 'GS_m': 19 | self.actor = GaussianActor_mu(self.state_dim, self.action_dim, self.net_width).to(self.dvc) 20 | else: print('Dist Error') 21 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.a_lr) 22 | 23 | # Build Critic 24 | self.critic = Critic(self.state_dim, self.net_width).to(self.dvc) 25 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.c_lr) 26 | 27 | # Build Trajectory holder 28 | self.s_hoder = np.zeros((self.T_horizon, self.state_dim),dtype=np.float32) 29 | self.a_hoder = np.zeros((self.T_horizon, self.action_dim),dtype=np.float32) 30 | self.r_hoder = np.zeros((self.T_horizon, 1),dtype=np.float32) 31 | self.s_next_hoder = np.zeros((self.T_horizon, self.state_dim),dtype=np.float32) 32 | self.logprob_a_hoder = np.zeros((self.T_horizon, self.action_dim),dtype=np.float32) 33 | self.done_hoder = np.zeros((self.T_horizon, 1),dtype=np.bool_) 34 | self.dw_hoder = np.zeros((self.T_horizon, 1),dtype=np.bool_) 35 | 36 | def select_action(self, state, deterministic): 37 | with torch.no_grad(): 38 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.dvc) 39 | if deterministic: 40 | # only used when evaluating the policy; makes the performance more stable 41 | a = self.actor.deterministic_act(state) 42 | return a.cpu().numpy()[0], None # action is in shape (adim,) 43 | else: 44 | # only used when interacting with the env 45 | dist = self.actor.get_dist(state) 46 | a = dist.sample() 47 | a = torch.clamp(a, 0, 1) 48 | logprob_a = dist.log_prob(a).cpu().numpy().flatten() 49 | return a.cpu().numpy()[0], logprob_a # both are in shape (adim,)
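# Note on the action convention (illustrative numbers, assuming Pendulum-v1 where max_action = 2.0):
# select_action() always returns an action in [0, 1]^action_dim -- the Beta actor's support is [0, 1],
# and the Gaussian actors are sigmoid-squashed / clamped to the same range in select_action().
# main.py then maps it to the environment's native range via Action_adapter from utils.py,
# act = 2 * (a - 0.5) * max_action, so e.g. a = 0.25 -> act = -1.0 and a = 1.0 -> act = +2.0.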
50 | 51 | 52 | def train(self): 53 | self.entropy_coef*=self.entropy_coef_decay 54 | 55 | '''Prepare PyTorch data from Numpy data''' 56 | s = torch.from_numpy(self.s_hoder).to(self.dvc) 57 | a = torch.from_numpy(self.a_hoder).to(self.dvc) 58 | r = torch.from_numpy(self.r_hoder).to(self.dvc) 59 | s_next = torch.from_numpy(self.s_next_hoder).to(self.dvc) 60 | logprob_a = torch.from_numpy(self.logprob_a_hoder).to(self.dvc) 61 | done = torch.from_numpy(self.done_hoder).to(self.dvc) 62 | dw = torch.from_numpy(self.dw_hoder).to(self.dvc) 63 | 64 | ''' Use TD+GAE+LongTrajectory to compute Advantage and TD target''' 65 | with torch.no_grad(): 66 | vs = self.critic(s) 67 | vs_ = self.critic(s_next) 68 | 69 | '''dw for TD_target and Adv''' 70 | deltas = r + self.gamma * vs_ * (~dw) - vs 71 | deltas = deltas.cpu().flatten().numpy() 72 | adv = [0] 73 | 74 | '''done for GAE''' 75 | for dlt, mask in zip(deltas[::-1], done.cpu().flatten().numpy()[::-1]): 76 | advantage = dlt + self.gamma * self.lambd * adv[-1] * (~mask) 77 | adv.append(advantage) 78 | adv.reverse() 79 | adv = copy.deepcopy(adv[0:-1]) 80 | adv = torch.tensor(adv).unsqueeze(1).float().to(self.dvc) 81 | td_target = adv + vs 82 | adv = (adv - adv.mean()) / ((adv.std()+1e-4)) # normalizing the advantage sometimes helps 83 | 84 | 85 | """Slice the long trajectory into short trajectories and perform mini-batch PPO updates""" 86 | a_optim_iter_num = int(math.ceil(s.shape[0] / self.a_optim_batch_size)) 87 | c_optim_iter_num = int(math.ceil(s.shape[0] / self.c_optim_batch_size)) 88 | for i in range(self.K_epochs): 89 | 90 | # Shuffle the trajectory; good for training 91 | perm = np.arange(s.shape[0]) 92 | np.random.shuffle(perm) 93 | perm = torch.LongTensor(perm).to(self.dvc) 94 | s, a, td_target, adv, logprob_a = \ 95 | s[perm].clone(), a[perm].clone(), td_target[perm].clone(), adv[perm].clone(), logprob_a[perm].clone() 96 | 97 | '''update the actor''' 98 | for i in range(a_optim_iter_num): 99 | index = slice(i * self.a_optim_batch_size, min((i + 1) * self.a_optim_batch_size, s.shape[0])) 100 | distribution = self.actor.get_dist(s[index]) 101 | dist_entropy = distribution.entropy().sum(1, keepdim=True) 102 | logprob_a_now = distribution.log_prob(a[index]) 103 | ratio = torch.exp(logprob_a_now.sum(1,keepdim=True) - logprob_a[index].sum(1,keepdim=True)) # a/b == exp(log(a)-log(b)) 104 | 105 | surr1 = ratio * adv[index] 106 | surr2 = torch.clamp(ratio, 1 - self.clip_rate, 1 + self.clip_rate) * adv[index] 107 | a_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy 108 | 109 | self.actor_optimizer.zero_grad() 110 | a_loss.mean().backward() 111 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 40) 112 | self.actor_optimizer.step() 113 | 114 | '''update the critic''' 115 | for i in range(c_optim_iter_num): 116 | index = slice(i * self.c_optim_batch_size, min((i + 1) * self.c_optim_batch_size, s.shape[0])) 117 | c_loss = (self.critic(s[index]) - td_target[index]).pow(2).mean() 118 | for name,param in self.critic.named_parameters(): 119 | if 'weight' in name: 120 | c_loss += param.pow(2).sum() * self.l2_reg 121 | 122 | self.critic_optimizer.zero_grad() 123 | c_loss.backward() 124 | self.critic_optimizer.step() 125 | 126 | def put_data(self, s, a, r, s_next, logprob_a, done, dw, idx): 127 | self.s_hoder[idx] = s 128 | self.a_hoder[idx] = a 129 | self.r_hoder[idx] = r 130 | self.s_next_hoder[idx] = s_next 131 | self.logprob_a_hoder[idx] = logprob_a 132 | self.done_hoder[idx] = done 133 | self.dw_hoder[idx] = dw 134 | 135 | def save(self,EnvName, timestep): 136 | torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep)) 137 | torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep)) 138 | 139 | def load(self,EnvName, timestep): 140 | self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc)) 141 | self.critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc)) 142 | 143 | 144 | 145 |
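# --- Standalone illustration (not used by PPO_agent above) ---------------------
# A toy, numpy-only sketch of the advantage computation performed in train():
# the TD errors (deltas) are accumulated backwards with factor gamma*lambd, and
# the accumulation is reset wherever done is True. The numbers below are made up
# purely for illustration.
if __name__ == '__main__':
    import numpy as np
    gamma, lambd = 0.99, 0.95
    r    = np.array([1.0, 1.0, 1.0], dtype=np.float32)   # rewards r_t
    vs   = np.array([0.5, 0.6, 0.7], dtype=np.float32)   # critic V(s_t)
    vs_  = np.array([0.6, 0.7, 0.0], dtype=np.float32)   # critic V(s_{t+1})
    dw   = np.array([False, False, True])                # dead/win: cut the bootstrap
    done = np.array([False, False, True])                # episode boundary for GAE

    deltas = r + gamma * vs_ * (~dw) - vs                # TD errors
    adv, running = np.zeros_like(deltas), 0.0
    for t in reversed(range(len(deltas))):               # backward pass, reset at done
        running = deltas[t] + gamma * lambd * running * (~done[t])
        adv[t] = running
    td_target = adv + vs                                 # regression target for the critic
    print('adv:', adv, 'td_target:', td_target)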
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PPO-Continuous-Pytorch 2 | **A clean and robust PyTorch implementation of PPO for continuous action spaces**: 3 | 4 | 5 | Pendulum | LunarLanderContinuous 6 | :-----------------------:|:-----------------------:| 7 | `render_gif/PV1.gif` | `render_gif/lldcV2.gif` 8 | `render_gif/pv1_ppoc.png` | `render_gif/lldc_ppoc.png` 9 | 10 | 11 | 12 | **Other RL algorithms implemented in PyTorch can be found [here](https://github.com/XinJingHao/RL-Algorithms-by-Pytorch).** 13 | 14 | ## Dependencies 15 | ```python 16 | gymnasium==0.29.1 17 | numpy==1.26.1 18 | pytorch==2.1.0 19 | 20 | python==3.11.5 21 | ``` 22 | 23 | ## How to use my code 24 | ### Train from scratch 25 | ```bash 26 | python main.py 27 | ``` 28 | where the default environment is 'Pendulum-v1'. 29 | 30 | ### Play with trained model 31 | ```bash 32 | python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 100 33 | ``` 34 | which will render 'Pendulum-v1' with the pretrained model in the 'model' folder. (A sketch of loading a saved model in your own script is given at the end of this README.) 35 | 36 | ### Change Environment 37 | If you want to train on different environments, just run 38 | ```bash 39 | python main.py --EnvIdex 1 40 | ``` 41 | The ```--EnvIdex``` can be set from 0 to 5, where 42 | ```bash 43 | '--EnvIdex 0' for 'Pendulum-v1' 44 | '--EnvIdex 1' for 'LunarLanderContinuous-v2' 45 | '--EnvIdex 2' for 'Humanoid-v4' 46 | '--EnvIdex 3' for 'HalfCheetah-v4' 47 | '--EnvIdex 4' for 'BipedalWalker-v3' 48 | '--EnvIdex 5' for 'BipedalWalkerHardcore-v3' 49 | ``` 50 | 51 | Note: if you want to train on **BipedalWalker, BipedalWalkerHardcore, or LunarLanderContinuous**, you need to install [box2d-py](https://gymnasium.farama.org/environments/box2d/) first. You can install box2d-py via: 52 | ```bash 53 | pip install gymnasium[box2d] 54 | ``` 55 | 56 | If you want to train on **Humanoid or HalfCheetah**, you need to install [MuJoCo](https://gymnasium.farama.org/environments/mujoco/) first. You can install MuJoCo via: 57 | ```bash 58 | pip install mujoco 59 | pip install gymnasium[mujoco] 60 | ``` 61 | 62 | ### Visualize the training curve 63 | You can use [tensorboard](https://pytorch.org/docs/stable/tensorboard.html) to record and visualize the training curves. 64 | 65 | - Installation (please make sure PyTorch is installed already): 66 | ```bash 67 | pip install tensorboard 68 | pip install packaging 69 | ``` 70 | - Record (the training curves will be saved in the '**runs**' folder): 71 | ```bash 72 | python main.py --write True 73 | ``` 74 | 75 | - Visualization: 76 | ```bash 77 | tensorboard --logdir runs 78 | ``` 79 | 80 | ### Hyperparameter Setting 81 | For more details of the hyperparameter settings, please check 'main.py'. 82 | 83 | ### References 84 | [Proximal Policy Optimization Algorithms](https://arxiv.org/pdf/1707.06347.pdf) 85 | [Emergence of Locomotion Behaviours in Rich Environments](https://arxiv.org/pdf/1707.02286.pdf) 86 | 87 | ## Training Curves 88 | ![avatar](https://github.com/XinJingHao/PPO-Continuous-Pytorch/blob/main/ppo_result.jpg) 89 | All the experiments are trained with the same hyperparameters (see main.py).
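## Loading a trained model in your own script

`python main.py --render True --Loadmodel True` is the intended way to watch a trained agent, but if you would rather embed a saved actor in your own code, the minimal sketch below shows the idea. It assumes the default hyperparameters (`--Distribution Beta`, `--net_width 150`) and the bundled Pendulum checkpoint `model/PV1_actor100.pth`; adjust the paths, dimensions, and actor class if your checkpoint was trained with different settings.

```python
import torch
import gymnasium as gym
from utils import BetaActor, Action_adapter

env = gym.make('Pendulum-v1', render_mode='human')
state_dim = env.observation_space.shape[0]    # 3 for Pendulum-v1
action_dim = env.action_space.shape[0]        # 1 for Pendulum-v1
max_action = float(env.action_space.high[0])  # 2.0 for Pendulum-v1

# net_width=150 is the default in main.py; it must match the checkpoint
actor = BetaActor(state_dim, action_dim, net_width=150)
actor.load_state_dict(torch.load('./model/PV1_actor100.pth', map_location='cpu'))

s, info = env.reset()
done = False
while not done:
    with torch.no_grad():
        a = actor.deterministic_act(torch.FloatTensor(s).unsqueeze(0))[0].numpy()  # a in [0,1]
    s, r, dw, tr, info = env.step(Action_adapter(a, max_action))  # map [0,1] to [-max,max]
    done = dw or tr
env.close()
```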
90 | 91 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os, shutil 3 | import argparse 4 | import torch 5 | import gymnasium as gym 6 | 7 | from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy 8 | from PPO import PPO_agent 9 | 10 | 11 | '''Hyperparameter Setting''' 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dvc', type=str, default='cuda', help='running device: cuda or cpu') 14 | parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, LLdV2, Humanv4, HCv4, BWv3, BWHv3') 15 | parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training') 16 | parser.add_argument('--render', type=str2bool, default=False, help='Render or Not') 17 | parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not') 18 | parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load') 19 | 20 | parser.add_argument('--seed', type=int, default=0, help='random seed') 21 | parser.add_argument('--T_horizon', type=int, default=2048, help='length of long trajectory') 22 | parser.add_argument('--Distribution', type=str, default='Beta', help='Should be one of Beta ; GS_ms ; GS_m') 23 | parser.add_argument('--Max_train_steps', type=int, default=int(5e7), help='Max training steps') 24 | parser.add_argument('--save_interval', type=int, default=int(5e5), help='Model saving interval, in steps.') 25 | parser.add_argument('--eval_interval', type=int, default=int(5e3), help='Model evaluating interval, in steps.') 26 | 27 | parser.add_argument('--gamma', type=float, default=0.99, help='Discount Factor') 28 | parser.add_argument('--lambd', type=float, default=0.95, help='GAE Factor') 29 | parser.add_argument('--clip_rate', type=float, default=0.2, help='PPO Clip rate') 30 | parser.add_argument('--K_epochs', type=int, default=10, help='PPO update times') 31 | parser.add_argument('--net_width', type=int, default=150, help='Hidden net width') 32 | parser.add_argument('--a_lr', type=float, default=2e-4, help='Learning rate of actor') 33 | parser.add_argument('--c_lr', type=float, default=2e-4, help='Learning rate of critic') 34 | parser.add_argument('--l2_reg', type=float, default=1e-3, help='L2 regularization coefficient for Critic') 35 | parser.add_argument('--a_optim_batch_size', type=int, default=64, help='length of sliced trajectory (mini-batch size) of actor') 36 | parser.add_argument('--c_optim_batch_size', type=int, default=64, help='length of sliced trajectory (mini-batch size) of critic') 37 | parser.add_argument('--entropy_coef', type=float, default=1e-3, help='Entropy coefficient of Actor') 38 | parser.add_argument('--entropy_coef_decay', type=float, default=0.99, help='Decay rate of entropy_coef') 39 | opt = parser.parse_args() 40 | opt.dvc = torch.device(opt.dvc) # from str to torch.device 41 | print(opt) 42 | 43 | 44 | def main(): 45 | EnvName = ['Pendulum-v1','LunarLanderContinuous-v2','Humanoid-v4','HalfCheetah-v4','BipedalWalker-v3','BipedalWalkerHardcore-v3'] 46 | BrifEnvName = ['PV1', 'LLdV2', 'Humanv4', 'HCv4','BWv3', 'BWHv3'] 47 | 48 | # Build Env 49 | env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None) 50 | eval_env = gym.make(EnvName[opt.EnvIdex]) 51 | opt.state_dim = env.observation_space.shape[0] 52 | opt.action_dim = env.action_space.shape[0] 53 | opt.max_action = 
float(env.action_space.high[0]) 54 | opt.max_steps = env._max_episode_steps 55 | print('Env:',EnvName[opt.EnvIdex],' state_dim:',opt.state_dim,' action_dim:',opt.action_dim, 56 | ' max_a:',opt.max_action,' min_a:',env.action_space.low[0], 'max_steps', opt.max_steps) 57 | 58 | # Seed Everything 59 | env_seed = opt.seed 60 | torch.manual_seed(opt.seed) 61 | torch.cuda.manual_seed(opt.seed) 62 | torch.backends.cudnn.deterministic = True 63 | torch.backends.cudnn.benchmark = False 64 | print("Random Seed: {}".format(opt.seed)) 65 | 66 | # Use tensorboard to record training curves 67 | if opt.write: 68 | from torch.utils.tensorboard import SummaryWriter 69 | timenow = str(datetime.now())[0:-10] 70 | timenow = ' ' + timenow[0:13] + '_' + timenow[-2::] 71 | writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow 72 | if os.path.exists(writepath): shutil.rmtree(writepath) 73 | writer = SummaryWriter(log_dir=writepath) 74 | 75 | # Beta dist maybe need larger learning rate, Sometimes helps 76 | # if Dist[distnum] == 'Beta' : 77 | # kwargs["a_lr"] *= 2 78 | # kwargs["c_lr"] *= 4 79 | 80 | if not os.path.exists('model'): os.mkdir('model') 81 | agent = PPO_agent(**vars(opt)) # transfer opt to dictionary, and use it to init PPO_agent 82 | if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex) 83 | 84 | if opt.render: 85 | while True: 86 | ep_r = evaluate_policy(env, agent, opt.max_action, 1) 87 | print(f'Env:{EnvName[opt.EnvIdex]}, Episode Reward:{ep_r}') 88 | else: 89 | traj_lenth, total_steps = 0, 0 90 | while total_steps < opt.Max_train_steps: 91 | s, info = env.reset(seed=env_seed) # Do not use opt.seed directly, or it can overfit to opt.seed 92 | env_seed += 1 93 | done = False 94 | 95 | '''Interact & trian''' 96 | while not done: 97 | '''Interact with Env''' 98 | a, logprob_a = agent.select_action(s, deterministic=False) # use stochastic when training 99 | act = Action_adapter(a,opt.max_action) #[0,1] to [-max,max] 100 | s_next, r, dw, tr, info = env.step(act) # dw: dead&win; tr: truncated 101 | r = Reward_adapter(r, opt.EnvIdex) 102 | done = (dw or tr) 103 | 104 | '''Store the current transition''' 105 | agent.put_data(s, a, r, s_next, logprob_a, done, dw, idx = traj_lenth) 106 | s = s_next 107 | 108 | traj_lenth += 1 109 | total_steps += 1 110 | 111 | '''Update if its time''' 112 | if traj_lenth % opt.T_horizon == 0: 113 | agent.train() 114 | traj_lenth = 0 115 | 116 | '''Record & log''' 117 | if total_steps % opt.eval_interval == 0: 118 | score = evaluate_policy(eval_env, agent, opt.max_action, turns=3) # evaluate the policy for 3 times, and get averaged result 119 | if opt.write: writer.add_scalar('ep_r', score, global_step=total_steps) 120 | print('EnvName:',EnvName[opt.EnvIdex],'seed:',opt.seed,'steps: {}k'.format(int(total_steps/1000)),'score:', score) 121 | 122 | '''Save model''' 123 | if total_steps % opt.save_interval==0: 124 | agent.save(BrifEnvName[opt.EnvIdex], int(total_steps/1000)) 125 | 126 | env.close() 127 | eval_env.close() 128 | 129 | if __name__ == '__main__': 130 | main() 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /model/PV1_actor100.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/model/PV1_actor100.pth -------------------------------------------------------------------------------- /model/PV1_q_critic100.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/model/PV1_q_critic100.pth -------------------------------------------------------------------------------- /ppo_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/ppo_result.jpg -------------------------------------------------------------------------------- /render_gif/PV1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/PV1.gif -------------------------------------------------------------------------------- /render_gif/lldcV2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/lldcV2.gif -------------------------------------------------------------------------------- /render_gif/lldc_ppoc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/lldc_ppoc.png -------------------------------------------------------------------------------- /render_gif/pv1_ppoc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/pv1_ppoc.png -------------------------------------------------------------------------------- /runs/BWv3 2021-11-04 20_14/events.out.tfevents.1636028042.localhost.localdomain.31736.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/BWv3 2021-11-04 20_14/events.out.tfevents.1636028042.localhost.localdomain.31736.0 -------------------------------------------------------------------------------- /runs/HCv2 2021-11-04 19_33/events.out.tfevents.1636025607.localhost.localdomain.23247.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv2 2021-11-04 19_33/events.out.tfevents.1636025607.localhost.localdomain.23247.0 -------------------------------------------------------------------------------- /runs/HCv4 2023-11-17 22_42/events.out.tfevents.1700232152.kim.6672.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-17 22_42/events.out.tfevents.1700232152.kim.6672.0 -------------------------------------------------------------------------------- /runs/HCv4 2023-11-18 01_58/events.out.tfevents.1700243918.kim.7102.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-18 01_58/events.out.tfevents.1700243918.kim.7102.0 
-------------------------------------------------------------------------------- /runs/HCv4 2023-11-18 05_14/events.out.tfevents.1700255698.kim.7476.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-18 05_14/events.out.tfevents.1700255698.kim.7476.0 -------------------------------------------------------------------------------- /runs/HCv4 2023-11-18 08_30/events.out.tfevents.1700267414.kim.8013.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-18 08_30/events.out.tfevents.1700267414.kim.8013.0 -------------------------------------------------------------------------------- /runs/Humanv2 2021-11-04 19_33/events.out.tfevents.1636025604.localhost.localdomain.23206.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv2 2021-11-04 19_33/events.out.tfevents.1636025604.localhost.localdomain.23206.0 -------------------------------------------------------------------------------- /runs/Humanv4 2023-11-17 20_13/events.out.tfevents.1700223188.kim.6379.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-17 20_13/events.out.tfevents.1700223188.kim.6379.0 -------------------------------------------------------------------------------- /runs/Humanv4 2023-11-17 23_37/events.out.tfevents.1700235436.kim.6843.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-17 23_37/events.out.tfevents.1700235436.kim.6843.0 -------------------------------------------------------------------------------- /runs/Humanv4 2023-11-18 02_53/events.out.tfevents.1700247217.kim.7267.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-18 02_53/events.out.tfevents.1700247217.kim.7267.0 -------------------------------------------------------------------------------- /runs/Humanv4 2023-11-18 06_09/events.out.tfevents.1700258999.kim.7648.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-18 06_09/events.out.tfevents.1700258999.kim.7648.0 -------------------------------------------------------------------------------- /runs/Humanv4 2023-11-18 09_25/events.out.tfevents.1700270700.kim.8175.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-18 09_25/events.out.tfevents.1700270700.kim.8175.0 -------------------------------------------------------------------------------- /runs/Lch_Cv2 2021-11-04 19_32/events.out.tfevents.1636025576.localhost.localdomain.23162.0: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Lch_Cv2 2021-11-04 19_32/events.out.tfevents.1636025576.localhost.localdomain.23162.0 -------------------------------------------------------------------------------- /runs/Lch_Cv2 2021-11-04 19_48/events.out.tfevents.1636026535.localhost.localdomain.26058.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Lch_Cv2 2021-11-04 19_48/events.out.tfevents.1636026535.localhost.localdomain.26058.0 -------------------------------------------------------------------------------- /runs/PV0 2021-11-04 19_33/events.out.tfevents.1636025600.localhost.localdomain.23193.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/PV0 2021-11-04 19_33/events.out.tfevents.1636025600.localhost.localdomain.23193.0 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Beta,Normal 5 | 6 | 7 | class BetaActor(nn.Module): 8 | def __init__(self, state_dim, action_dim, net_width): 9 | super(BetaActor, self).__init__() 10 | 11 | self.l1 = nn.Linear(state_dim, net_width) 12 | self.l2 = nn.Linear(net_width, net_width) 13 | self.alpha_head = nn.Linear(net_width, action_dim) 14 | self.beta_head = nn.Linear(net_width, action_dim) 15 | 16 | def forward(self, state): 17 | a = torch.tanh(self.l1(state)) 18 | a = torch.tanh(self.l2(a)) 19 | 20 | alpha = F.softplus(self.alpha_head(a)) + 1.0 21 | beta = F.softplus(self.beta_head(a)) + 1.0 22 | 23 | return alpha,beta 24 | 25 | def get_dist(self,state): 26 | alpha,beta = self.forward(state) 27 | dist = Beta(alpha, beta) 28 | return dist 29 | 30 | def deterministic_act(self, state): 31 | alpha, beta = self.forward(state) 32 | mode = (alpha) / (alpha + beta) 33 | return mode 34 | 35 | class GaussianActor_musigma(nn.Module): 36 | def __init__(self, state_dim, action_dim, net_width): 37 | super(GaussianActor_musigma, self).__init__() 38 | 39 | self.l1 = nn.Linear(state_dim, net_width) 40 | self.l2 = nn.Linear(net_width, net_width) 41 | self.mu_head = nn.Linear(net_width, action_dim) 42 | self.sigma_head = nn.Linear(net_width, action_dim) 43 | 44 | def forward(self, state): 45 | a = torch.tanh(self.l1(state)) 46 | a = torch.tanh(self.l2(a)) 47 | mu = torch.sigmoid(self.mu_head(a)) 48 | sigma = F.softplus( self.sigma_head(a) ) 49 | return mu,sigma 50 | 51 | def get_dist(self, state): 52 | mu,sigma = self.forward(state) 53 | dist = Normal(mu,sigma) 54 | return dist 55 | 56 | def deterministic_act(self, state): 57 | mu, sigma = self.forward(state) 58 | return mu 59 | 60 | 61 | class GaussianActor_mu(nn.Module): 62 | def __init__(self, state_dim, action_dim, net_width, log_std=0): 63 | super(GaussianActor_mu, self).__init__() 64 | 65 | self.l1 = nn.Linear(state_dim, net_width) 66 | self.l2 = nn.Linear(net_width, net_width) 67 | self.mu_head = nn.Linear(net_width, action_dim) 68 | self.mu_head.weight.data.mul_(0.1) 69 | self.mu_head.bias.data.mul_(0.0) 70 | 71 | self.action_log_std = 
nn.Parameter(torch.ones(1, action_dim) * log_std) 72 | 73 | def forward(self, state): 74 | a = torch.relu(self.l1(state)) 75 | a = torch.relu(self.l2(a)) 76 | mu = torch.sigmoid(self.mu_head(a)) 77 | return mu 78 | 79 | def get_dist(self,state): 80 | mu = self.forward(state) 81 | action_log_std = self.action_log_std.expand_as(mu) 82 | action_std = torch.exp(action_log_std) 83 | 84 | dist = Normal(mu, action_std) 85 | return dist 86 | 87 | def deterministic_act(self, state): 88 | return self.forward(state) 89 | 90 | 91 | class Critic(nn.Module): 92 | def __init__(self, state_dim,net_width): 93 | super(Critic, self).__init__() 94 | 95 | self.C1 = nn.Linear(state_dim, net_width) 96 | self.C2 = nn.Linear(net_width, net_width) 97 | self.C3 = nn.Linear(net_width, 1) 98 | 99 | def forward(self, state): 100 | v = torch.tanh(self.C1(state)) 101 | v = torch.tanh(self.C2(v)) 102 | v = self.C3(v) 103 | return v 104 | 105 | def str2bool(v): 106 | '''Convert a string to a bool for argparse.''' 107 | if isinstance(v, bool): 108 | return v 109 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 110 | return True 111 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 112 | return False 113 | else: 114 | print('Wrong Input.') 115 | raise ValueError('str2bool expects a boolean-like string, got: {}'.format(v)) 116 | 117 | 118 | def Action_adapter(a,max_action): 119 | # from [0,1] to [-max,max] 120 | return 2*(a-0.5)*max_action 121 | 122 | def Reward_adapter(r, EnvIdex): 123 | # EnvIdex 0/1 (Pendulum-v1, LunarLanderContinuous-v2): clip large negative rewards (e.g. the -100 crash penalty) 124 | if EnvIdex == 0 or EnvIdex == 1: 125 | if r <= -100: r = -1 126 | # EnvIdex 3 (HalfCheetah-v4): shift and rescale the reward 127 | elif EnvIdex == 3: 128 | r = (r + 8) / 8 129 | return r 130 | 131 | def evaluate_policy(env, agent, max_action, turns): 132 | total_scores = 0 133 | for j in range(turns): 134 | s, info = env.reset() 135 | done = False 136 | while not done: 137 | a, logprob_a = agent.select_action(s, deterministic=True) # Take deterministic actions during evaluation 138 | act = Action_adapter(a, max_action) # [0,1] to [-max,max] 139 | s_next, r, dw, tr, info = env.step(act) 140 | done = (dw or tr) 141 | 142 | total_scores += r 143 | s = s_next 144 | 145 | return total_scores/turns --------------------------------------------------------------------------------