├── PPO_continuous.py
├── PPO_discrete.py
└── README.md

/PPO_continuous.py:
--------------------------------------------------------------------------------
# This is a PPO algorithm for multi-dimensional continuous action spaces

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np
import os
import argparse
from tensorboardX import SummaryWriter


parser = argparse.ArgumentParser(description='PyTorch PPO for continuous control')
parser.add_argument('--gpus', default=1, type=int, help='number of GPUs')
parser.add_argument('--env', type=str, default='BipedalWalker-v2', help='continuous env')
parser.add_argument('--render', default=False, action='store_true', help='Render?')
parser.add_argument('--solved_reward', type=float, default=300, help='stop training if avg_reward > solved_reward')
parser.add_argument('--print_interval', type=int, default=10, help='how many episodes between printing results')
parser.add_argument('--save_interval', type=int, default=100, help='how many episodes between checkpoint saves')
parser.add_argument('--max_episodes', type=int, default=100000)
parser.add_argument('--max_timesteps', type=int, default=1500)
parser.add_argument('--update_timesteps', type=int, default=4000, help='how many timesteps between policy updates')
parser.add_argument('--action_std', type=float, default=0.5, help='constant std for action distribution (Multivariate Normal)')
parser.add_argument('--K_epochs', type=int, default=80, help='number of epochs to optimize the policy at each update')
parser.add_argument('--eps_clip', type=float, default=0.2, help='clipping epsilon for the importance ratio p/q')
parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
parser.add_argument('--lr', type=float, default=0.0003)
parser.add_argument('--seed', type=int, default=123, help='random seed to use')
parser.add_argument('--ckpt_folder', default='./checkpoints', help='Location to save checkpoint models')
parser.add_argument('--tb', default=False, action='store_true', help='Use tensorboardX?')
parser.add_argument('--log_folder', default='./logs', help='Location to save logs')
parser.add_argument('--mode', default='train', help='choose train or test')
parser.add_argument('--restore', default=False, action='store_true', help='Restore and go on training?')
opt = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Memory:  # collected from old policy
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.is_terminals = []
        self.logprobs = []

    def clear_memory(self):
        del self.states[:]
        del self.actions[:]
        del self.rewards[:]
        del self.is_terminals[:]
        del self.logprobs[:]


class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()

        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, action_dim),
            nn.Tanh()
        )

        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

        self.action_var = torch.full((action_dim, ), action_std * action_std).to(device)  # (4, )

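    # Note: the policy is a diagonal Gaussian. The actor network outputs the mean
    # (squashed to [-1, 1] by the final Tanh), while the variance comes from the
    # fixed --action_std value and is not a learned parameter.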
    def act(self, state, memory):  # state (1,24)
        action_mean = self.actor(state)  # (1,4)
        cov_mat = torch.diag(self.action_var).to(device)  # (4,4)
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()  # (1,4)
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()

    def evaluate(self, state, action):  # state (4000, 24); action (4000, 4)
        state_value = self.critic(state)  # (4000, 1)

        # to calculate action score (logprobs) and distribution entropy
        action_mean = self.actor(state)  # (4000,4)
        action_var = self.action_var.expand_as(action_mean)  # (4000,4)
        cov_mat = torch.diag_embed(action_var).to(device)  # (4000,4,4)
        dist = MultivariateNormal(action_mean, cov_mat)
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()

        return action_logprobs, torch.squeeze(state_value), dist_entropy


class PPO:
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip, restore=False, ckpt=None):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        # current policy
        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        if restore:
            pretrained_model = torch.load(ckpt, map_location=lambda storage, loc: storage)
            self.policy.load_state_dict(pretrained_model)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        # old policy: initialize old policy with current policy's parameters
        self.old_policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.old_policy.load_state_dict(self.policy.state_dict())

        self.MSE_loss = nn.MSELoss()

    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)  # flatten the state
        return self.old_policy.act(state, memory).cpu().numpy().flatten()

    def update(self, memory):
        # Monte Carlo estimation of rewards
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            rewards.insert(0, discounted_reward)

        # Normalize rewards
        rewards = torch.tensor(rewards).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()

        # Train policy for K epochs: sampling and updating
        for _ in range(self.K_epochs):
            # Evaluate old actions and values using current policy
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Importance ratio: p/q
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Advantages
            advantages = rewards - state_values.detach()

            # Actor loss using Surrogate loss
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
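            # The two surrogates implement the PPO clipped objective:
            # L_CLIP = E[ min(r_t * A_t, clip(r_t, 1 - eps_clip, 1 + eps_clip) * A_t) ].
            # Taking the elementwise minimum keeps the update conservative when the
            # new policy drifts too far from the policy that collected the samples.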
            actor_loss = - torch.min(surr1, surr2)

            # Critic loss: value loss minus an entropy bonus
            critic_loss = 0.5 * self.MSE_loss(rewards, state_values) - 0.01 * dist_entropy

            # Total loss
            loss = actor_loss + critic_loss

            # Backward gradients
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights to old_policy
        self.old_policy.load_state_dict(self.policy.state_dict())


def train(env_name, env, state_dim, action_dim, render, solved_reward,
          max_episodes, max_timesteps, update_timestep, action_std, K_epochs, eps_clip,
          gamma, lr, betas, ckpt_folder, restore, tb=False, print_interval=10, save_interval=100):

    ckpt = ckpt_folder+'/PPO_continuous_'+env_name+'.pth'
    if restore:
        print('Load checkpoint from {}'.format(ckpt))

    memory = Memory()

    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip, restore=restore, ckpt=ckpt)

    running_reward, avg_length, time_step = 0, 0, 0

    # training loop
    for i_episode in range(1, max_episodes+1):
        state = env.reset()
        for t in range(max_timesteps):
            time_step += 1

            # Run old policy
            action = ppo.select_action(state, memory)

            state, reward, done, _ = env.step(action)

            memory.rewards.append(reward)
            memory.is_terminals.append(done)

            if time_step % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                time_step = 0

            running_reward += reward
            if render:
                env.render()

            if done:
                break
        avg_length += t

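        # running_reward accumulates rewards over the last print_interval episodes,
        # so comparing it against print_interval * solved_reward approximates
        # "average episode reward exceeds solved_reward".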
        if running_reward > (print_interval * solved_reward):
            print("########## Solved! ##########")
            torch.save(ppo.policy.state_dict(), ckpt_folder + '/PPO_continuous_{}.pth'.format(env_name))
            print('Save a checkpoint!')
            break

        if i_episode % save_interval == 0:
            torch.save(ppo.policy.state_dict(), ckpt_folder + '/PPO_continuous_{}.pth'.format(env_name))
            print('Save a checkpoint!')

        if i_episode % print_interval == 0:
            avg_length = int(avg_length / print_interval)
            running_reward = int((running_reward / print_interval))

            print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))

            if tb:
                writer.add_scalar('scalar/reward', running_reward, i_episode)
                writer.add_scalar('scalar/length', avg_length, i_episode)

            running_reward, avg_length = 0, 0


def test(env_name, env, state_dim, action_dim, render, action_std, K_epochs, eps_clip, gamma, lr, betas, ckpt_folder, test_episodes):

    ckpt = ckpt_folder+'/PPO_continuous_'+env_name+'.pth'
    print('Load checkpoint from {}'.format(ckpt))

    memory = Memory()

    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip, restore=True, ckpt=ckpt)

    episode_reward, time_step = 0, 0
    avg_episode_reward, avg_length = 0, 0

    # test
    for i_episode in range(1, test_episodes+1):
        state = env.reset()
        while True:
            time_step += 1

            # Run old policy
            action = ppo.select_action(state, memory)

            state, reward, done, _ = env.step(action)

            episode_reward += reward

            if render:
                env.render()

            if done:
                print('Episode {} \t Length: {} \t Reward: {}'.format(i_episode, time_step, episode_reward))
                avg_episode_reward += episode_reward
                avg_length += time_step
                memory.clear_memory()
                time_step, episode_reward = 0, 0
                break

    print('Test {} episodes DONE!'.format(test_episodes))
    print('Avg episode reward: {} | Avg length: {}'.format(avg_episode_reward/test_episodes, avg_length/test_episodes))


if __name__ == '__main__':
    if opt.tb:
        writer = SummaryWriter()

    if not os.path.exists(opt.ckpt_folder):
        os.mkdir(opt.ckpt_folder)

    print("Random Seed: {}".format(opt.seed))
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)

    env_name = opt.env
    env = gym.make(env_name)
    env.seed(opt.seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    print('Environment: {}\nState Size: {}\nAction Size: {}\n'.format(env_name, state_dim, action_dim))

    if opt.mode == 'train':
        train(env_name, env, state_dim, action_dim,
              render=opt.render, solved_reward=opt.solved_reward,
              max_episodes=opt.max_episodes, max_timesteps=opt.max_timesteps, update_timestep=opt.update_timesteps,
              action_std=opt.action_std, K_epochs=opt.K_epochs, eps_clip=opt.eps_clip,
              gamma=opt.gamma, lr=opt.lr, betas=[0.9, 0.990], ckpt_folder=opt.ckpt_folder,
              restore=opt.restore, tb=opt.tb, print_interval=opt.print_interval, save_interval=opt.save_interval)
    elif opt.mode == 'test':
        test(env_name, env, state_dim, action_dim,
             render=opt.render, action_std=opt.action_std, K_epochs=opt.K_epochs, eps_clip=opt.eps_clip,
             gamma=opt.gamma, lr=opt.lr, betas=[0.9, 0.990], ckpt_folder=opt.ckpt_folder, test_episodes=100)
    else:
        raise Exception("Wrong Mode!")
    if opt.tb:
        writer.close()
--------------------------------------------------------------------------------
/PPO_discrete.py:
--------------------------------------------------------------------------------
# This is a PPO algorithm for discrete action spaces

import torch
import torch.nn as nn
from torch.distributions import Categorical
import gym
import numpy as np
import os
import argparse
from tensorboardX import SummaryWriter


parser = argparse.ArgumentParser(description='PyTorch PPO for discrete control')
parser.add_argument('--gpus', default=1, type=int, help='number of GPUs')
parser.add_argument('--env', type=str, default='LunarLander-v2', help='discrete env')
parser.add_argument('--render', default=False, action='store_true', help='Render?')
parser.add_argument('--solved_reward', type=float, default=200, help='stop training if avg_reward > solved_reward')
parser.add_argument('--print_interval', type=int, default=10, help='how many episodes between printing results')
parser.add_argument('--save_interval', type=int, default=100, help='how many episodes between checkpoint saves')
parser.add_argument('--max_episodes', type=int, default=100000)
parser.add_argument('--max_timesteps', type=int, default=300, help='maximum timesteps in one episode')
parser.add_argument('--update_timesteps', type=int, default=2000, help='how many timesteps between policy updates')
parser.add_argument('--K_epochs', type=int, default=4, help='number of epochs to optimize the policy at each update')
parser.add_argument('--eps_clip', type=float, default=0.2, help='clipping epsilon for the importance ratio p/q')
parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
parser.add_argument('--lr', type=float, default=0.002)
parser.add_argument('--seed', type=int, default=123, help='random seed to use')
parser.add_argument('--ckpt_folder', default='./checkpoints', help='Location to save checkpoint models')
parser.add_argument('--tb', default=False, action='store_true', help='Use tensorboardX?')
parser.add_argument('--log_folder', default='./logs', help='Location to save logs')
parser.add_argument('--mode', default='train', help='choose train or test')
parser.add_argument('--restore', default=False, action='store_true', help='Restore and go on training?')
opt = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Memory:  # collected from old policy
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.is_terminals = []
        self.logprobs = []

    def clear_memory(self):
        del self.states[:]
        del self.actions[:]
        del self.rewards[:]
        del self.is_terminals[:]
        del self.logprobs[:]


class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()

        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, action_dim),
            nn.Softmax(dim=-1)  # For discrete actions, we use softmax policy
        )

        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

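    # Note: the actor head ends in a Softmax over the discrete actions, so act()
    # samples an action index from a Categorical distribution and returns it as a
    # plain Python int, which is what env.step() expects for Discrete action spaces.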
    def act(self, state, memory):  # state (1,8)
        action_probs = self.actor(state)  # (1,4)
        dist = Categorical(action_probs)  # categorical distribution: sample an action index according to the probabilities
        action = dist.sample()
        action_logprob = dist.log_prob(action)  # (1,)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)
        # print(action_probs.size(), action_logprob.size(), action.size())
        return action.item()  # convert to scalar

    def evaluate(self, state, action):  # state (2000, 8); action (2000,)
        state_value = self.critic(state)  # (2000, 1)

        # to calculate action score (logprobs) and distribution entropy
        action_probs = self.actor(state)  # (2000,4)
        dist = Categorical(action_probs)
        action_logprobs = dist.log_prob(action)  # (2000,)
        dist_entropy = dist.entropy()

        return action_logprobs, torch.squeeze(state_value), dist_entropy


class PPO:
    def __init__(self, state_dim, action_dim, lr, betas, gamma, K_epochs, eps_clip, restore=False, ckpt=None):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        # current policy
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        if restore:
            pretrained_model = torch.load(ckpt, map_location=lambda storage, loc: storage)
            self.policy.load_state_dict(pretrained_model)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        # old policy: initialize old policy with current policy's parameters
        self.old_policy = ActorCritic(state_dim, action_dim).to(device)
        self.old_policy.load_state_dict(self.policy.state_dict())

        self.MSE_loss = nn.MSELoss()  # to calculate critic loss

    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)  # flatten the state
        return self.old_policy.act(state, memory)

    def update(self, memory):
        # Monte Carlo estimation of rewards
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            rewards.insert(0, discounted_reward)

        # Normalize rewards
        rewards = torch.tensor(rewards).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()

        # Train policy for K epochs: sampling and updating
        for _ in range(self.K_epochs):
            # Evaluate old actions and values using current policy
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Importance ratio: p/q
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Advantages
            advantages = rewards - state_values.detach()  # old states' returns minus their values as estimated by the current critic

            # Actor loss using Surrogate loss
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            actor_loss = - torch.min(surr1, surr2)

            # Critic loss: value loss minus an entropy bonus
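            # 0.5 scales the value-function MSE against the normalized returns, and the
            # 0.01 * entropy term is subtracted as a bonus that encourages exploration.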
            critic_loss = 0.5 * self.MSE_loss(rewards, state_values) - 0.01 * dist_entropy

            # Total loss
            loss = actor_loss + critic_loss

            # Backward gradients
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights to old_policy
        self.old_policy.load_state_dict(self.policy.state_dict())


def train(env_name, env, state_dim, action_dim, render, solved_reward,
          max_episodes, max_timesteps, update_timestep, K_epochs, eps_clip,
          gamma, lr, betas, ckpt_folder, restore, tb=False, print_interval=10, save_interval=100):

    ckpt = ckpt_folder+'/PPO_discrete_'+env_name+'.pth'
    if restore:
        print('Load checkpoint from {}'.format(ckpt))

    memory = Memory()

    ppo = PPO(state_dim, action_dim, lr, betas, gamma, K_epochs, eps_clip, restore=restore, ckpt=ckpt)

    running_reward, avg_length, time_step = 0, 0, 0

    # training loop
    for i_episode in range(1, max_episodes+1):
        state = env.reset()
        for t in range(max_timesteps):
            time_step += 1

            # Run old policy
            action = ppo.select_action(state, memory)

            state, reward, done, _ = env.step(action)

            memory.rewards.append(reward)
            memory.is_terminals.append(done)

            if time_step % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                time_step = 0

            running_reward += reward
            if render:
                env.render()

            if done:
                break
        avg_length += t

        if running_reward > (print_interval * solved_reward):
            print("########## Solved! ##########")
            torch.save(ppo.policy.state_dict(), ckpt_folder + '/PPO_discrete_{}.pth'.format(env_name))
            print('Save a checkpoint!')
            break

        if i_episode % save_interval == 0:
            torch.save(ppo.policy.state_dict(), ckpt_folder + '/PPO_discrete_{}.pth'.format(env_name))
            print('Save a checkpoint!')

        if i_episode % print_interval == 0:
            avg_length = int(avg_length / print_interval)
            running_reward = int((running_reward / print_interval))

            print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))

            if tb:
                writer.add_scalar('scalar/reward', running_reward, i_episode)
                writer.add_scalar('scalar/length', avg_length, i_episode)

            running_reward, avg_length = 0, 0


def test(env_name, env, state_dim, action_dim, render, K_epochs, eps_clip, gamma, lr, betas, ckpt_folder, test_episodes):

    ckpt = ckpt_folder+'/PPO_discrete_'+env_name+'.pth'
    print('Load checkpoint from {}'.format(ckpt))

    memory = Memory()

    ppo = PPO(state_dim, action_dim, lr, betas, gamma, K_epochs, eps_clip, restore=True, ckpt=ckpt)

    episode_reward, time_step = 0, 0
    avg_episode_reward, avg_length = 0, 0

    # test
    for i_episode in range(1, test_episodes+1):
        state = env.reset()
        while True:
            time_step += 1

            # Run old policy
            action = ppo.select_action(state, memory)

            state, reward, done, _ = env.step(action)

            episode_reward += reward

            if render:
                env.render()

            if done:
                print('Episode {} \t Length: {} \t Reward: {}'.format(i_episode, time_step, episode_reward))
                avg_episode_reward += episode_reward
                avg_length += time_step
                memory.clear_memory()
                time_step, episode_reward = 0, 0
                break

    print('Test {} episodes DONE!'.format(test_episodes))
    print('Avg episode reward: {} | Avg length: {}'.format(avg_episode_reward/test_episodes, avg_length/test_episodes))


if __name__ == '__main__':
    if opt.tb:
        writer = SummaryWriter()

    if not os.path.exists(opt.ckpt_folder):
        os.mkdir(opt.ckpt_folder)

    print("Random Seed: {}".format(opt.seed))
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)

    env_name = opt.env
    env = gym.make(env_name)
    env.seed(opt.seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    print('Environment: {}\nState Size: {}\nAction Size: {}\n'.format(env_name, state_dim, action_dim))

    if opt.mode == 'train':
        train(env_name, env, state_dim, action_dim,
              render=opt.render, solved_reward=opt.solved_reward,
              max_episodes=opt.max_episodes, max_timesteps=opt.max_timesteps,
              update_timestep=opt.update_timesteps, K_epochs=opt.K_epochs,
              eps_clip=opt.eps_clip, gamma=opt.gamma, lr=opt.lr,
              betas=[0.9, 0.990], ckpt_folder=opt.ckpt_folder, restore=opt.restore,
              tb=opt.tb, print_interval=opt.print_interval, save_interval=opt.save_interval)
    elif opt.mode == 'test':
        test(env_name, env, state_dim, action_dim,
             render=opt.render, K_epochs=opt.K_epochs, eps_clip=opt.eps_clip,
             gamma=opt.gamma, lr=opt.lr, betas=[0.9, 0.990], ckpt_folder=opt.ckpt_folder, test_episodes=100)
    else:
        raise Exception("Wrong Mode!")

    if opt.tb:
        writer.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch-PPO

## Continuous Action Control
Train:
```
python PPO_continuous.py
```
Test:
```
python PPO_continuous.py --mode test
```

## Discrete Action Control
Train:
```
python PPO_discrete.py
```
Test:
```
python PPO_discrete.py --mode test
```
--------------------------------------------------------------------------------