├── LICENSE
├── Old_Version-PPO-Continuous-with-Gym0.18.3.zip
├── PPO.py
├── README.md
├── main.py
├── model
│   ├── PV1_actor100.pth
│   └── PV1_q_critic100.pth
├── ppo_result.jpg
├── render_gif
│   ├── PV1.gif
│   ├── lldcV2.gif
│   ├── lldc_ppoc.png
│   └── pv1_ppoc.png
├── runs
│   ├── BWv3 2021-11-04 20_14
│   │   └── events.out.tfevents.1636028042.localhost.localdomain.31736.0
│   ├── HCv2 2021-11-04 19_33
│   │   └── events.out.tfevents.1636025607.localhost.localdomain.23247.0
│   ├── HCv4 2023-11-17 22_42
│   │   └── events.out.tfevents.1700232152.kim.6672.0
│   ├── HCv4 2023-11-18 01_58
│   │   └── events.out.tfevents.1700243918.kim.7102.0
│   ├── HCv4 2023-11-18 05_14
│   │   └── events.out.tfevents.1700255698.kim.7476.0
│   ├── HCv4 2023-11-18 08_30
│   │   └── events.out.tfevents.1700267414.kim.8013.0
│   ├── Humanv2 2021-11-04 19_33
│   │   └── events.out.tfevents.1636025604.localhost.localdomain.23206.0
│   ├── Humanv4 2023-11-17 20_13
│   │   └── events.out.tfevents.1700223188.kim.6379.0
│   ├── Humanv4 2023-11-17 23_37
│   │   └── events.out.tfevents.1700235436.kim.6843.0
│   ├── Humanv4 2023-11-18 02_53
│   │   └── events.out.tfevents.1700247217.kim.7267.0
│   ├── Humanv4 2023-11-18 06_09
│   │   └── events.out.tfevents.1700258999.kim.7648.0
│   ├── Humanv4 2023-11-18 09_25
│   │   └── events.out.tfevents.1700270700.kim.8175.0
│   ├── Lch_Cv2 2021-11-04 19_32
│   │   └── events.out.tfevents.1636025576.localhost.localdomain.23162.0
│   ├── Lch_Cv2 2021-11-04 19_48
│   │   └── events.out.tfevents.1636026535.localhost.localdomain.26058.0
│   └── PV0 2021-11-04 19_33
│       └── events.out.tfevents.1636025600.localhost.localdomain.23193.0
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 XinJingHao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Old_Version-PPO-Continuous-with-Gym0.18.3.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/Old_Version-PPO-Continuous-with-Gym0.18.3.zip
--------------------------------------------------------------------------------
/PPO.py:
--------------------------------------------------------------------------------
1 | from utils import BetaActor, GaussianActor_musigma, GaussianActor_mu, Critic
2 | import numpy as np
3 | import copy
4 | import torch
5 | import math
6 |
7 |
8 | class PPO_agent(object):
9 | def __init__(self, **kwargs):
10 | # Init hyperparameters for the PPO agent, e.g. "self.gamma = opt.gamma, self.lambd = opt.lambd, ..."
11 | self.__dict__.update(kwargs)
12 |
13 | # Choose distribution for the actor
14 | if self.Distribution == 'Beta':
15 | self.actor = BetaActor(self.state_dim, self.action_dim, self.net_width).to(self.dvc)
16 | elif self.Distribution == 'GS_ms':
17 | self.actor = GaussianActor_musigma(self.state_dim, self.action_dim, self.net_width).to(self.dvc)
18 | elif self.Distribution == 'GS_m':
19 | self.actor = GaussianActor_mu(self.state_dim, self.action_dim, self.net_width).to(self.dvc)
20 | else: print('Dist Error')
21 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.a_lr)
22 |
23 | # Build Critic
24 | self.critic = Critic(self.state_dim, self.net_width).to(self.dvc)
25 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.c_lr)
26 |
27 | # Build Trajectory holder
28 | self.s_hoder = np.zeros((self.T_horizon, self.state_dim),dtype=np.float32)
29 | self.a_hoder = np.zeros((self.T_horizon, self.action_dim),dtype=np.float32)
30 | self.r_hoder = np.zeros((self.T_horizon, 1),dtype=np.float32)
31 | self.s_next_hoder = np.zeros((self.T_horizon, self.state_dim),dtype=np.float32)
32 | self.logprob_a_hoder = np.zeros((self.T_horizon, self.action_dim),dtype=np.float32)
33 | self.done_hoder = np.zeros((self.T_horizon, 1),dtype=np.bool_)
34 | self.dw_hoder = np.zeros((self.T_horizon, 1),dtype=np.bool_)
35 |
36 | def select_action(self, state, deterministic):
37 | with torch.no_grad():
38 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.dvc)
39 | if deterministic:
40 | # only used when evaluating the policy; makes the performance more stable
41 | a = self.actor.deterministic_act(state)
42 | return a.cpu().numpy()[0], None # action is in shape (action_dim,)
43 | else:
44 | # only used when interacting with the env
45 | dist = self.actor.get_dist(state)
46 | a = dist.sample()
47 | a = torch.clamp(a, 0, 1)
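# Actions are kept in [0,1] here (Beta support / clamped Gaussian); Action_adapter in utils.py maps them to [-max_action, max_action]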
48 | logprob_a = dist.log_prob(a).cpu().numpy().flatten()
49 | return a.cpu().numpy()[0], logprob_a # both are in shape (action_dim,)
50 |
51 |
52 | def train(self):
53 | self.entropy_coef*=self.entropy_coef_decay
54 |
55 | '''Prepare PyTorch data from Numpy data'''
56 | s = torch.from_numpy(self.s_hoder).to(self.dvc)
57 | a = torch.from_numpy(self.a_hoder).to(self.dvc)
58 | r = torch.from_numpy(self.r_hoder).to(self.dvc)
59 | s_next = torch.from_numpy(self.s_next_hoder).to(self.dvc)
60 | logprob_a = torch.from_numpy(self.logprob_a_hoder).to(self.dvc)
61 | done = torch.from_numpy(self.done_hoder).to(self.dvc)
62 | dw = torch.from_numpy(self.dw_hoder).to(self.dvc)
63 |
64 | ''' Use TD+GAE+LongTrajectory to compute Advantage and TD target'''
65 | with torch.no_grad():
66 | vs = self.critic(s)
67 | vs_ = self.critic(s_next)
68 |
69 | '''dw for TD_target and Adv'''
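# TD residual: delta_t = r_t + gamma * V(s_{t+1}) * (1 - dw_t) - V(s_t)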
70 | deltas = r + self.gamma * vs_ * (~dw) - vs
71 | deltas = deltas.cpu().flatten().numpy()
72 | adv = [0]
73 |
74 | '''done for GAE'''
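# GAE, computed backwards over the trajectory: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}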
75 | for dlt, mask in zip(deltas[::-1], done.cpu().flatten().numpy()[::-1]):
76 | advantage = dlt + self.gamma * self.lambd * adv[-1] * (~mask)
77 | adv.append(advantage)
78 | adv.reverse()
79 | adv = copy.deepcopy(adv[0:-1])
80 | adv = torch.tensor(adv).unsqueeze(1).float().to(self.dvc)
81 | td_target = adv + vs
82 | adv = (adv - adv.mean()) / ((adv.std()+1e-4)) #sometimes helps
83 |
84 |
85 | """Slice long trajectopy into short trajectory and perform mini-batch PPO update"""
86 | a_optim_iter_num = int(math.ceil(s.shape[0] / self.a_optim_batch_size))
87 | c_optim_iter_num = int(math.ceil(s.shape[0] / self.c_optim_batch_size))
88 | for i in range(self.K_epochs):
89 |
90 | # Shuffle the trajectory; good for training
91 | perm = np.arange(s.shape[0])
92 | np.random.shuffle(perm)
93 | perm = torch.LongTensor(perm).to(self.dvc)
94 | s, a, td_target, adv, logprob_a = \
95 | s[perm].clone(), a[perm].clone(), td_target[perm].clone(), adv[perm].clone(), logprob_a[perm].clone()
96 |
97 | '''update the actor'''
98 | for i in range(a_optim_iter_num):
99 | index = slice(i * self.a_optim_batch_size, min((i + 1) * self.a_optim_batch_size, s.shape[0]))
100 | distribution = self.actor.get_dist(s[index])
101 | dist_entropy = distribution.entropy().sum(1, keepdim=True)
102 | logprob_a_now = distribution.log_prob(a[index])
103 | ratio = torch.exp(logprob_a_now.sum(1,keepdim=True) - logprob_a[index].sum(1,keepdim=True)) # a/b == exp(log(a)-log(b))
104 |
105 | surr1 = ratio * adv[index]
106 | surr2 = torch.clamp(ratio, 1 - self.clip_rate, 1 + self.clip_rate) * adv[index]
107 | a_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy
108 |
109 | self.actor_optimizer.zero_grad()
110 | a_loss.mean().backward()
111 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 40)
112 | self.actor_optimizer.step()
113 |
114 | '''update the critic'''
115 | for i in range(c_optim_iter_num):
116 | index = slice(i * self.c_optim_batch_size, min((i + 1) * self.c_optim_batch_size, s.shape[0]))
117 | c_loss = (self.critic(s[index]) - td_target[index]).pow(2).mean()
118 | for name,param in self.critic.named_parameters():
119 | if 'weight' in name:
120 | c_loss += param.pow(2).sum() * self.l2_reg
121 |
122 | self.critic_optimizer.zero_grad()
123 | c_loss.backward()
124 | self.critic_optimizer.step()
125 |
126 | def put_data(self, s, a, r, s_next, logprob_a, done, dw, idx):
127 | self.s_hoder[idx] = s
128 | self.a_hoder[idx] = a
129 | self.r_hoder[idx] = r
130 | self.s_next_hoder[idx] = s_next
131 | self.logprob_a_hoder[idx] = logprob_a
132 | self.done_hoder[idx] = done
133 | self.dw_hoder[idx] = dw
134 |
135 | def save(self,EnvName, timestep):
136 | torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
137 | torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
138 |
139 | def load(self,EnvName, timestep):
140 | self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
141 | self.critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
142 |
143 |
144 |
145 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PPO-Continuous-Pytorch
2 | **A clean and robust Pytorch implementation of PPO on continuous action space**:
3 |
4 |
5 | Pendulum | LunarLanderContinuous
6 | :-----------------------:|:-----------------------:|
7 | ![Pendulum](render_gif/PV1.gif) | ![LunarLanderContinuous](render_gif/lldcV2.gif)
8 |
9 |
10 |
11 |
12 | **Other RL algorithms by Pytorch can be found [here](https://github.com/XinJingHao/RL-Algorithms-by-Pytorch).**
13 |
14 | ## Dependencies
15 | ```python
16 | gymnasium==0.29.1
17 | numpy==1.26.1
18 | pytorch==2.1.0
19 |
20 | python==3.11.5
21 | ```
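
These can be installed with pip; a minimal sketch (note that the PyTorch wheel is published as `torch` on PyPI):
```bash
pip install "gymnasium==0.29.1" "numpy==1.26.1" "torch==2.1.0"
```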
22 |
23 | ## How to use my code
24 | ### Train from scratch
25 | ```bash
26 | python main.py
27 | ```
28 | where the default environment is 'Pendulum-v1'.
29 |
30 | ### Play with trained model
31 | ```bash
32 | python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 100
33 | ```
34 | which will render 'Pendulum-v1' with the pretrained model.
35 |
36 | ### Change Environment
37 | If you want to train on a different environment, just run
38 | ```bash
39 | python main.py --EnvIdex 1
40 | ```
41 | The ```--EnvIdex``` can be set to 0~5, where
42 | ```bash
43 | '--EnvIdex 0' for 'Pendulum-v1'
44 | '--EnvIdex 1' for 'LunarLanderContinuous-v2'
45 | '--EnvIdex 2' for 'Humanoid-v4'
46 | '--EnvIdex 3' for 'HalfCheetah-v4'
47 | '--EnvIdex 4' for 'BipedalWalker-v3'
48 | '--EnvIdex 5' for 'BipedalWalkerHardcore-v3'
49 | ```
50 |
51 | Note: if you want to train on **BipedalWalker, BipedalWalkerHardcore, or LunarLanderContinuous**, you need to install [box2d-py](https://gymnasium.farama.org/environments/box2d/) first. You can install box2d-py via:
52 | ```bash
53 | pip install gymnasium[box2d]
54 | ```
55 |
56 | If you want to train on **Humanoid or HalfCheetah**, you need to install [MuJoCo](https://gymnasium.farama.org/environments/mujoco/) first. You can install MuJoCo via:
57 | ```bash
58 | pip install mujoco
59 | pip install gymnasium[mujoco]
60 | ```
61 |
62 | ### Visualize the training curve
63 | You can use [tensorboard](https://pytorch.org/docs/stable/tensorboard.html) to record and visualize the training curves.
64 |
65 | - Installation (please make sure PyTorch is installed already):
66 | ```bash
67 | pip install tensorboard
68 | pip install packaging
69 | ```
70 | - Record (the training curves will be saved at '**runs**'):
71 | ```bash
72 | python main.py --write True
73 | ```
74 |
75 | - Visualization:
76 | ```bash
77 | tensorboard --logdir runs
78 | ```
79 |
80 | ### Hyperparameter Setting
81 | For more details of the hyperparameter settings, please check 'main.py'. All of them can be overridden from the command line, as shown below.
82 |
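For example, a sketch using flags defined in main.py's argparse section:
```bash
python main.py --EnvIdex 3 --dvc cuda --seed 0 --T_horizon 2048 --K_epochs 10 --Distribution Beta --write True
```
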
83 | ### References
84 | - [Proximal Policy Optimization Algorithms](https://arxiv.org/pdf/1707.06347.pdf)
85 | - [Emergence of Locomotion Behaviours in Rich Environments](https://arxiv.org/pdf/1707.02286.pdf)
86 |
87 | ## Training Curves
88 | 
89 | All the experiments are trained with the same hyperparameters (see main.py).
90 |
91 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os, shutil
3 | import argparse
4 | import torch
5 | import gymnasium as gym
6 |
7 | from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy
8 | from PPO import PPO_agent
9 |
10 |
11 | '''Hyperparameter Setting'''
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--dvc', type=str, default='cuda', help='running device: cuda or cpu')
14 | parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, LLdV2, Humanv4, HCv4, BWv3, BWHv3')
15 | parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
16 | parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
17 | parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
18 | parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
19 |
20 | parser.add_argument('--seed', type=int, default=0, help='random seed')
21 | parser.add_argument('--T_horizon', type=int, default=2048, help='length of long trajectory')
22 | parser.add_argument('--Distribution', type=str, default='Beta', help='Should be one of Beta ; GS_ms ; GS_m')
23 | parser.add_argument('--Max_train_steps', type=int, default=int(5e7), help='Max training steps')
24 | parser.add_argument('--save_interval', type=int, default=int(5e5), help='Model saving interval, in steps.')
25 | parser.add_argument('--eval_interval', type=int, default=int(5e3), help='Model evaluating interval, in steps.')
26 |
27 | parser.add_argument('--gamma', type=float, default=0.99, help='Discount Factor')
28 | parser.add_argument('--lambd', type=float, default=0.95, help='GAE Factor')
29 | parser.add_argument('--clip_rate', type=float, default=0.2, help='PPO Clip rate')
30 | parser.add_argument('--K_epochs', type=int, default=10, help='PPO update times')
31 | parser.add_argument('--net_width', type=int, default=150, help='Hidden net width')
32 | parser.add_argument('--a_lr', type=float, default=2e-4, help='Learning rate of actor')
33 | parser.add_argument('--c_lr', type=float, default=2e-4, help='Learning rate of critic')
34 | parser.add_argument('--l2_reg', type=float, default=1e-3, help='L2 regularization coefficient for Critic')
35 | parser.add_argument('--a_optim_batch_size', type=int, default=64, help='length of sliced trajectory of actor')
36 | parser.add_argument('--c_optim_batch_size', type=int, default=64, help='length of sliced trajectory of critic')
37 | parser.add_argument('--entropy_coef', type=float, default=1e-3, help='Entropy coefficient of Actor')
38 | parser.add_argument('--entropy_coef_decay', type=float, default=0.99, help='Decay rate of entropy_coef')
39 | opt = parser.parse_args()
40 | opt.dvc = torch.device(opt.dvc) # from str to torch.device
41 | print(opt)
42 |
43 |
44 | def main():
45 | EnvName = ['Pendulum-v1','LunarLanderContinuous-v2','Humanoid-v4','HalfCheetah-v4','BipedalWalker-v3','BipedalWalkerHardcore-v3']
46 | BrifEnvName = ['PV1', 'LLdV2', 'Humanv4', 'HCv4','BWv3', 'BWHv3']
47 |
48 | # Build Env
49 | env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
50 | eval_env = gym.make(EnvName[opt.EnvIdex])
51 | opt.state_dim = env.observation_space.shape[0]
52 | opt.action_dim = env.action_space.shape[0]
53 | opt.max_action = float(env.action_space.high[0])
54 | opt.max_steps = env._max_episode_steps
55 | print('Env:',EnvName[opt.EnvIdex],' state_dim:',opt.state_dim,' action_dim:',opt.action_dim,
56 | ' max_a:',opt.max_action,' min_a:',env.action_space.low[0], 'max_steps', opt.max_steps)
57 |
58 | # Seed Everything
59 | env_seed = opt.seed
60 | torch.manual_seed(opt.seed)
61 | torch.cuda.manual_seed(opt.seed)
62 | torch.backends.cudnn.deterministic = True
63 | torch.backends.cudnn.benchmark = False
64 | print("Random Seed: {}".format(opt.seed))
65 |
66 | # Use tensorboard to record training curves
67 | if opt.write:
68 | from torch.utils.tensorboard import SummaryWriter
69 | timenow = str(datetime.now())[0:-10]
70 | timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
71 | writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
72 | if os.path.exists(writepath): shutil.rmtree(writepath)
73 | writer = SummaryWriter(log_dir=writepath)
74 |
75 | # Beta dist may need a larger learning rate; sometimes helps:
76 | # if opt.Distribution == 'Beta':
77 | #     opt.a_lr *= 2
78 | #     opt.c_lr *= 4
79 |
80 | if not os.path.exists('model'): os.mkdir('model')
81 | agent = PPO_agent(**vars(opt)) # transfer opt to dictionary, and use it to init PPO_agent
82 | if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
83 |
84 | if opt.render:
85 | while True:
86 | ep_r = evaluate_policy(env, agent, opt.max_action, 1)
87 | print(f'Env:{EnvName[opt.EnvIdex]}, Episode Reward:{ep_r}')
88 | else:
89 | traj_lenth, total_steps = 0, 0
90 | while total_steps < opt.Max_train_steps:
91 | s, info = env.reset(seed=env_seed) # Do not use opt.seed directly, or it can overfit to opt.seed
92 | env_seed += 1
93 | done = False
94 |
95 | '''Interact & train'''
96 | while not done:
97 | '''Interact with Env'''
98 | a, logprob_a = agent.select_action(s, deterministic=False) # use stochastic when training
99 | act = Action_adapter(a,opt.max_action) #[0,1] to [-max,max]
100 | s_next, r, dw, tr, info = env.step(act) # dw: dead&win; tr: truncated
101 | r = Reward_adapter(r, opt.EnvIdex)
102 | done = (dw or tr)
103 |
104 | '''Store the current transition'''
105 | agent.put_data(s, a, r, s_next, logprob_a, done, dw, idx = traj_lenth)
106 | s = s_next
107 |
108 | traj_lenth += 1
109 | total_steps += 1
110 |
111 | '''Update if its time'''
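# On-policy update: consume the T_horizon transitions collected since the last update, then refill the buffer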
112 | if traj_lenth % opt.T_horizon == 0:
113 | agent.train()
114 | traj_lenth = 0
115 |
116 | '''Record & log'''
117 | if total_steps % opt.eval_interval == 0:
118 | score = evaluate_policy(eval_env, agent, opt.max_action, turns=3) # evaluate the policy for 3 times, and get averaged result
119 | if opt.write: writer.add_scalar('ep_r', score, global_step=total_steps)
120 | print('EnvName:',EnvName[opt.EnvIdex],'seed:',opt.seed,'steps: {}k'.format(int(total_steps/1000)),'score:', score)
121 |
122 | '''Save model'''
123 | if total_steps % opt.save_interval==0:
124 | agent.save(BrifEnvName[opt.EnvIdex], int(total_steps/1000))
125 |
126 | env.close()
127 | eval_env.close()
128 |
129 | if __name__ == '__main__':
130 | main()
131 |
132 |
133 |
134 |
135 |
136 |
--------------------------------------------------------------------------------
/model/PV1_actor100.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/model/PV1_actor100.pth
--------------------------------------------------------------------------------
/model/PV1_q_critic100.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/model/PV1_q_critic100.pth
--------------------------------------------------------------------------------
/ppo_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/ppo_result.jpg
--------------------------------------------------------------------------------
/render_gif/PV1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/PV1.gif
--------------------------------------------------------------------------------
/render_gif/lldcV2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/lldcV2.gif
--------------------------------------------------------------------------------
/render_gif/lldc_ppoc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/lldc_ppoc.png
--------------------------------------------------------------------------------
/render_gif/pv1_ppoc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/render_gif/pv1_ppoc.png
--------------------------------------------------------------------------------
/runs/BWv3 2021-11-04 20_14/events.out.tfevents.1636028042.localhost.localdomain.31736.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/BWv3 2021-11-04 20_14/events.out.tfevents.1636028042.localhost.localdomain.31736.0
--------------------------------------------------------------------------------
/runs/HCv2 2021-11-04 19_33/events.out.tfevents.1636025607.localhost.localdomain.23247.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv2 2021-11-04 19_33/events.out.tfevents.1636025607.localhost.localdomain.23247.0
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-17 22_42/events.out.tfevents.1700232152.kim.6672.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-17 22_42/events.out.tfevents.1700232152.kim.6672.0
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-18 01_58/events.out.tfevents.1700243918.kim.7102.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-18 01_58/events.out.tfevents.1700243918.kim.7102.0
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-18 05_14/events.out.tfevents.1700255698.kim.7476.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-18 05_14/events.out.tfevents.1700255698.kim.7476.0
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-18 08_30/events.out.tfevents.1700267414.kim.8013.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/HCv4 2023-11-18 08_30/events.out.tfevents.1700267414.kim.8013.0
--------------------------------------------------------------------------------
/runs/Humanv2 2021-11-04 19_33/events.out.tfevents.1636025604.localhost.localdomain.23206.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv2 2021-11-04 19_33/events.out.tfevents.1636025604.localhost.localdomain.23206.0
--------------------------------------------------------------------------------
/runs/Humanv4 2023-11-17 20_13/events.out.tfevents.1700223188.kim.6379.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-17 20_13/events.out.tfevents.1700223188.kim.6379.0
--------------------------------------------------------------------------------
/runs/Humanv4 2023-11-17 23_37/events.out.tfevents.1700235436.kim.6843.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-17 23_37/events.out.tfevents.1700235436.kim.6843.0
--------------------------------------------------------------------------------
/runs/Humanv4 2023-11-18 02_53/events.out.tfevents.1700247217.kim.7267.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-18 02_53/events.out.tfevents.1700247217.kim.7267.0
--------------------------------------------------------------------------------
/runs/Humanv4 2023-11-18 06_09/events.out.tfevents.1700258999.kim.7648.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-18 06_09/events.out.tfevents.1700258999.kim.7648.0
--------------------------------------------------------------------------------
/runs/Humanv4 2023-11-18 09_25/events.out.tfevents.1700270700.kim.8175.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Humanv4 2023-11-18 09_25/events.out.tfevents.1700270700.kim.8175.0
--------------------------------------------------------------------------------
/runs/Lch_Cv2 2021-11-04 19_32/events.out.tfevents.1636025576.localhost.localdomain.23162.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Lch_Cv2 2021-11-04 19_32/events.out.tfevents.1636025576.localhost.localdomain.23162.0
--------------------------------------------------------------------------------
/runs/Lch_Cv2 2021-11-04 19_48/events.out.tfevents.1636026535.localhost.localdomain.26058.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/Lch_Cv2 2021-11-04 19_48/events.out.tfevents.1636026535.localhost.localdomain.26058.0
--------------------------------------------------------------------------------
/runs/PV0 2021-11-04 19_33/events.out.tfevents.1636025600.localhost.localdomain.23193.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/PPO-Continuous-Pytorch/388113545f1edf04dbd0e0f605c5ceddba3ab131/runs/PV0 2021-11-04 19_33/events.out.tfevents.1636025600.localhost.localdomain.23193.0
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Beta,Normal
5 |
6 |
7 | class BetaActor(nn.Module):
8 | def __init__(self, state_dim, action_dim, net_width):
9 | super(BetaActor, self).__init__()
10 |
11 | self.l1 = nn.Linear(state_dim, net_width)
12 | self.l2 = nn.Linear(net_width, net_width)
13 | self.alpha_head = nn.Linear(net_width, action_dim)
14 | self.beta_head = nn.Linear(net_width, action_dim)
15 |
16 | def forward(self, state):
17 | a = torch.tanh(self.l1(state))
18 | a = torch.tanh(self.l2(a))
19 |
20 | alpha = F.softplus(self.alpha_head(a)) + 1.0
21 | beta = F.softplus(self.beta_head(a)) + 1.0
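# The +1.0 offset keeps alpha, beta > 1, so the Beta density is unimodal on (0, 1)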
22 |
23 | return alpha,beta
24 |
25 | def get_dist(self,state):
26 | alpha,beta = self.forward(state)
27 | dist = Beta(alpha, beta)
28 | return dist
29 |
30 | def deterministic_act(self, state):
31 | alpha, beta = self.forward(state)
32 | mean = alpha / (alpha + beta)  # use the distribution mean as the deterministic action
33 | return mean
34 |
35 | class GaussianActor_musigma(nn.Module):
36 | def __init__(self, state_dim, action_dim, net_width):
37 | super(GaussianActor_musigma, self).__init__()
38 |
39 | self.l1 = nn.Linear(state_dim, net_width)
40 | self.l2 = nn.Linear(net_width, net_width)
41 | self.mu_head = nn.Linear(net_width, action_dim)
42 | self.sigma_head = nn.Linear(net_width, action_dim)
43 |
44 | def forward(self, state):
45 | a = torch.tanh(self.l1(state))
46 | a = torch.tanh(self.l2(a))
47 | mu = torch.sigmoid(self.mu_head(a))
48 | sigma = F.softplus( self.sigma_head(a) )
49 | return mu,sigma
50 |
51 | def get_dist(self, state):
52 | mu,sigma = self.forward(state)
53 | dist = Normal(mu,sigma)
54 | return dist
55 |
56 | def deterministic_act(self, state):
57 | mu, sigma = self.forward(state)
58 | return mu
59 |
60 |
61 | class GaussianActor_mu(nn.Module):
62 | def __init__(self, state_dim, action_dim, net_width, log_std=0):
63 | super(GaussianActor_mu, self).__init__()
64 |
65 | self.l1 = nn.Linear(state_dim, net_width)
66 | self.l2 = nn.Linear(net_width, net_width)
67 | self.mu_head = nn.Linear(net_width, action_dim)
68 | self.mu_head.weight.data.mul_(0.1)
69 | self.mu_head.bias.data.mul_(0.0)
70 |
71 | self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)
72 |
73 | def forward(self, state):
74 | a = torch.relu(self.l1(state))
75 | a = torch.relu(self.l2(a))
76 | mu = torch.sigmoid(self.mu_head(a))
77 | return mu
78 |
79 | def get_dist(self,state):
80 | mu = self.forward(state)
81 | action_log_std = self.action_log_std.expand_as(mu)
82 | action_std = torch.exp(action_log_std)
83 |
84 | dist = Normal(mu, action_std)
85 | return dist
86 |
87 | def deterministic_act(self, state):
88 | return self.forward(state)
89 |
90 |
91 | class Critic(nn.Module):
92 | def __init__(self, state_dim,net_width):
93 | super(Critic, self).__init__()
94 |
95 | self.C1 = nn.Linear(state_dim, net_width)
96 | self.C2 = nn.Linear(net_width, net_width)
97 | self.C3 = nn.Linear(net_width, 1)
98 |
99 | def forward(self, state):
100 | v = torch.tanh(self.C1(state))
101 | v = torch.tanh(self.C2(v))
102 | v = self.C3(v)
103 | return v
104 |
105 | def str2bool(v):
106 | '''transfer str to bool for argparse'''
107 | if isinstance(v, bool):
108 | return v
109 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
110 | return True
111 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
112 | return False
113 | else:
114 | print('Wrong Input.')
115 | raise ValueError('str2bool expected a boolean-like string, e.g. True/False, yes/no, 1/0')
116 |
117 |
118 | def Action_adapter(a,max_action):
119 | #from [0,1] to [-max,max]
120 | return 2*(a-0.5)*max_action
121 |
122 | def Reward_adapter(r, EnvIdex):
123 | # For BipedalWalker / BipedalWalkerHardcore ('--EnvIdex 4' and '--EnvIdex 5' in main.py)
124 | if EnvIdex == 4 or EnvIdex == 5:
125 | if r <= -100: r = -1
126 | # For Pendulum-v1 ('--EnvIdex 0' in main.py)
127 | elif EnvIdex == 0:
128 | r = (r + 8) / 8
129 | return r
130 |
131 | def evaluate_policy(env, agent, max_action, turns):
132 | total_scores = 0
133 | for j in range(turns):
134 | s, info = env.reset()
135 | done = False
136 | while not done:
137 | a, logprob_a = agent.select_action(s, deterministic=True) # Take deterministic actions when evaluation
138 | act = Action_adapter(a, max_action) # [0,1] to [-max,max]
139 | s_next, r, dw, tr, info = env.step(act)
140 | done = (dw or tr)
141 |
142 | total_scores += r
143 | s = s_next
144 |
145 | return total_scores/turns
--------------------------------------------------------------------------------