├── .gitignore ├── pyproject.toml ├── README.md ├── ppo.py ├── ppo_atari.py └── ppo_continuous_action.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs 2 | videos 3 | wandb 4 | .DS_Store -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ppo_implementation_deep_dive" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Costa Huang "] 6 | 7 | [tool.poetry.dependencies] 8 | python = ">=3.7.1,<3.10" 9 | torch = "^1.7.1" 10 | stable-baselines3 = "^1.1.0" 11 | tensorboard = "^2.5.0" 12 | wandb = "0.12.1" 13 | pyglet = "^1.5.19" 14 | opencv-python = "^4.5.3" 15 | gym = "^0.21.0" 16 | 17 | # Atari-related dependencies 18 | atari-py = {version = "0.2.6", optional = true} 19 | ale-py = {version = "^0.7", optional = true} 20 | AutoROM = {version = "^0.4.2", optional = true, extras = ["accept-rom-license"]} 21 | 22 | # Robotics env dependencies 23 | pybullet = {version = "^3.1.8", optional = true} 24 | free-mujoco-py = {version = "^2.1.6", optional = true} 25 | 26 | [tool.poetry.dev-dependencies] 27 | spyder = {version = "^5.1.1", optional = true} 28 | 29 | [build-system] 30 | requires = ["poetry-core>=1.0.0"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | [tool.poetry.extras] 34 | spyder = ["spyder"] 35 | atari = ["ale-py", "AutoROM"] 36 | pybullet = ["pybullet"] 37 | mujoco = ["free-mujoco-py"] 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deprecation Notice 2 | 3 | This repo is deprecated. Please visit our new repo https://github.com/vwxyzjn/ppo-implementation-details and the improved ICLR 2022 blog post on PPO https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ 4 | 5 | ## PPO-Implementation-Deep-Dive 6 | 7 | This repo contains the source code for the PPO Implementation Deep Dive tutorial series. 8 | 9 | 1. Proximal Policy Optimization Implementation Deep Dive | 11 Core Implementation Details ([youtu.be/MEt6rrxH8W4](https://youtu.be/MEt6rrxH8W4)) 10 | 2. Proximal Policy Optimization Implementation Deep Dive | 9 Atari-specific Details ([youtu.be/05RMTj-2K_Y](https://youtu.be/05RMTj-2K_Y)) 11 | 3. Proximal Policy Optimization Implementation Deep Dive | 8 Details for Continuous Actions ([youtu.be/BvZvx7ENZBw](https://youtu.be/BvZvx7ENZBw)) 12 | 13 | ![image](https://user-images.githubusercontent.com/5555347/144305162-435cf10f-780a-4681-bb7e-95b84f4e0146.png) 14 | 15 | 16 | 17 | You can find out where these implementation details come from by visiting 18 | my [blog post](https://costa.sh/blog-the-32-implementation-details-of-ppo.html), which contains 19 | GitHub permanent links from each detail to the original implementation. 20 | 21 | If you like this repo, consider also checking out [CleanRL](https://github.com/vwxyzjn/cleanrl), my RL library based on single-file implementations.
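All three scripts share the same core PPO update. As a quick orientation, here is a minimal, self-contained sketch of the clipped surrogate objective they implement (the tensor values below are made-up stand-ins, not data from the repo):

```python
import torch

# Hypothetical minibatch values standing in for the rollout buffers in ppo.py.
advantages = torch.tensor([0.5, -1.2, 0.3])   # advantage estimates
logratio = torch.tensor([0.05, -0.10, 0.20])  # new_logprob - old_logprob
ratio = logratio.exp()
clip_coef = 0.2

# PPO clipped surrogate objective, written as a loss to minimize (as in the scripts).
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
print(pg_loss)
```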
22 | 23 | 24 | ## Get started 25 | 26 | Prerequisites: 27 | * Python 3.8+ 28 | * [Poetry](https://python-poetry.org) 29 | 30 | Install dependencies: 31 | ``` 32 | poetry install 33 | ``` 34 | 35 | Train agents: 36 | ``` 37 | poetry run python ppo.py 38 | ``` 39 | 40 | Train agents with experiment tracking: 41 | ``` 42 | poetry run python ppo.py --track --capture-video 43 | ``` 44 | 45 | ### Atari 46 | Install dependencies: 47 | ``` 48 | poetry install -E atari 49 | ``` 50 | Train agents: 51 | ``` 52 | poetry run python ppo_atari.py 53 | ``` 54 | Train agents with experiment tracking: 55 | ``` 56 | poetry run python ppo_atari.py --track --capture-video 57 | ``` 58 | 59 | 60 | ### Pybullet 61 | Install dependencies: 62 | ``` 63 | poetry install -E pybullet 64 | ``` 65 | Train agents: 66 | ``` 67 | poetry run python ppo_continuous_action.py 68 | ``` 69 | Train agents with experiment tracking: 70 | ``` 71 | poetry run python ppo_continuous_action.py --track --capture-video 72 | ``` 73 | 74 | ### MuJoCo 75 | 76 | !! Note this installation method only works in Linux 77 | 78 | Install dependencies: 79 | ``` 80 | poetry install -E mujoco 81 | poetry run python -c "import mujoco_py" 82 | ``` 83 | Train agents: 84 | ``` 85 | poetry run python ppo_continuous_action.py --gym-id Hopper-v2 86 | ``` 87 | Train agents with experiment tracking: 88 | ``` 89 | poetry run python ppo_continuous_action.py --gym-id Hopper-v2 --track --capture-video 90 | ``` 91 | -------------------------------------------------------------------------------- /ppo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from distutils.util import strtobool 6 | 7 | import gym 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.distributions.categorical import Categorical 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | def parse_args(): 17 | # fmt: off 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"), 20 | help='the name of this experiment') 21 | parser.add_argument('--gym-id', type=str, default="CartPole-v1", 22 | help='the id of the gym environment') 23 | parser.add_argument('--learning-rate', type=float, default=2.5e-4, 24 | help='the learning rate of the optimizer') 25 | parser.add_argument('--seed', type=int, default=1, 26 | help='seed of the experiment') 27 | parser.add_argument('--total-timesteps', type=int, default=25000, 28 | help='total timesteps of the experiments') 29 | parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 30 | help='if toggled, `torch.backends.cudnn.deterministic=False`') 31 | parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 32 | help='if toggled, cuda will be enabled by default') 33 | parser.add_argument('--track', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True, 34 | help='if toggled, this experiment will be tracked with Weights and Biases') 35 | parser.add_argument('--wandb-project-name', type=str, default="cleanRL", 36 | help="the wandb's project name") 37 | parser.add_argument('--wandb-entity', type=str, default=None, 38 | help="the entity (team) of wandb's project") 39 | parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True, 40 | 
help='weather to capture videos of the agent performances (check out `videos` folder)') 41 | 42 | # Algorithm specific arguments 43 | parser.add_argument('--num-envs', type=int, default=4, 44 | help='the number of parallel game environments') 45 | parser.add_argument('--num-steps', type=int, default=128, 46 | help='the number of steps to run in each environment per policy rollout') 47 | parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 48 | help="Toggle learning rate annealing for policy and value networks") 49 | parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 50 | help='Use GAE for advantage computation') 51 | parser.add_argument('--gamma', type=float, default=0.99, 52 | help='the discount factor gamma') 53 | parser.add_argument('--gae-lambda', type=float, default=0.95, 54 | help='the lambda for the general advantage estimation') 55 | parser.add_argument('--num-minibatches', type=int, default=4, 56 | help='the number of mini-batches') 57 | parser.add_argument('--update-epochs', type=int, default=4, 58 | help="the K epochs to update the policy") 59 | parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 60 | help="Toggles advantages normalization") 61 | parser.add_argument('--clip-coef', type=float, default=0.2, 62 | help="the surrogate clipping coefficient") 63 | parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 64 | help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.') 65 | parser.add_argument('--ent-coef', type=float, default=0.01, 66 | help="coefficient of the entropy") 67 | parser.add_argument('--vf-coef', type=float, default=0.5, 68 | help="coefficient of the value function") 69 | parser.add_argument('--max-grad-norm', type=float, default=0.5, 70 | help='the maximum norm for the gradient clipping') 71 | parser.add_argument('--target-kl', type=float, default=None, 72 | help='the target KL divergence threshold') 73 | args = parser.parse_args() 74 | args.batch_size = int(args.num_envs * args.num_steps) 75 | args.minibatch_size = int(args.batch_size // args.num_minibatches) 76 | # fmt: on 77 | return args 78 | 79 | 80 | def make_env(gym_id, seed, idx, capture_video, run_name): 81 | def thunk(): 82 | env = gym.make(gym_id) 83 | env = gym.wrappers.RecordEpisodeStatistics(env) 84 | if capture_video: 85 | if idx == 0: 86 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 87 | env.seed(seed) 88 | env.action_space.seed(seed) 89 | env.observation_space.seed(seed) 90 | return env 91 | 92 | return thunk 93 | 94 | 95 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 96 | torch.nn.init.orthogonal_(layer.weight, std) 97 | torch.nn.init.constant_(layer.bias, bias_const) 98 | return layer 99 | 100 | 101 | class Agent(nn.Module): 102 | def __init__(self, envs): 103 | super(Agent, self).__init__() 104 | self.critic = nn.Sequential( 105 | layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), 106 | nn.Tanh(), 107 | layer_init(nn.Linear(64, 64)), 108 | nn.Tanh(), 109 | layer_init(nn.Linear(64, 1), std=1.0), 110 | ) 111 | self.actor = nn.Sequential( 112 | layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), 113 | nn.Tanh(), 114 | layer_init(nn.Linear(64, 64)), 115 | nn.Tanh(), 116 | layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01), 117 | ) 118 | 119 | def get_value(self, 
x): 120 | return self.critic(x) 121 | 122 | def get_action_and_value(self, x, action=None): 123 | logits = self.actor(x) 124 | probs = Categorical(logits=logits) 125 | if action is None: 126 | action = probs.sample() 127 | return action, probs.log_prob(action), probs.entropy(), self.critic(x) 128 | 129 | 130 | if __name__ == "__main__": 131 | args = parse_args() 132 | run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}" 133 | if args.track: 134 | import wandb 135 | 136 | wandb.init( 137 | project=args.wandb_project_name, 138 | entity=args.wandb_entity, 139 | sync_tensorboard=True, 140 | config=vars(args), 141 | name=run_name, 142 | monitor_gym=True, 143 | save_code=True, 144 | ) 145 | writer = SummaryWriter(f"runs/{run_name}") 146 | writer.add_text( 147 | "hyperparameters", 148 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 149 | ) 150 | 151 | # TRY NOT TO MODIFY: seeding 152 | random.seed(args.seed) 153 | np.random.seed(args.seed) 154 | torch.manual_seed(args.seed) 155 | torch.backends.cudnn.deterministic = args.torch_deterministic 156 | 157 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 158 | 159 | # env setup 160 | envs = gym.vector.SyncVectorEnv( 161 | [make_env(args.gym_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] 162 | ) 163 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" 164 | 165 | agent = Agent(envs).to(device) 166 | optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) 167 | 168 | # ALGO Logic: Storage setup 169 | obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) 170 | actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) 171 | logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) 172 | rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) 173 | dones = torch.zeros((args.num_steps, args.num_envs)).to(device) 174 | values = torch.zeros((args.num_steps, args.num_envs)).to(device) 175 | 176 | # TRY NOT TO MODIFY: start the game 177 | global_step = 0 178 | start_time = time.time() 179 | next_obs = torch.Tensor(envs.reset()).to(device) 180 | next_done = torch.zeros(args.num_envs).to(device) 181 | num_updates = args.total_timesteps // args.batch_size 182 | 183 | for update in range(1, num_updates + 1): 184 | # Annealing the rate if instructed to do so. 185 | if args.anneal_lr: 186 | frac = 1.0 - (update - 1.0) / num_updates 187 | lrnow = frac * args.learning_rate 188 | optimizer.param_groups[0]["lr"] = lrnow 189 | 190 | for step in range(0, args.num_steps): 191 | global_step += 1 * args.num_envs 192 | obs[step] = next_obs 193 | dones[step] = next_done 194 | 195 | # ALGO LOGIC: action logic 196 | with torch.no_grad(): 197 | action, logprob, _, value = agent.get_action_and_value(next_obs) 198 | values[step] = value.flatten() 199 | actions[step] = action 200 | logprobs[step] = logprob 201 | 202 | # TRY NOT TO MODIFY: execute the game and log data. 
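# note: `envs.step` returns batched arrays (one entry per sub-environment), and `info` is a list of per-env dicts;
# `RecordEpisodeStatistics` puts an 'episode' key into a dict when that env finishes an episode, which the loop below looks for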
203 | next_obs, reward, done, info = envs.step(action.cpu().numpy()) 204 | rewards[step] = torch.tensor(reward).to(device).view(-1) 205 | next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) 206 | 207 | for item in info: 208 | if "episode" in item.keys(): 209 | print(f"global_step={global_step}, episodic_return={item['episode']['r']}") 210 | writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) 211 | writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) 212 | break 213 | 214 | # bootstrap value if not done 215 | with torch.no_grad(): 216 | next_value = agent.get_value(next_obs).reshape(1, -1) 217 | if args.gae: 218 | advantages = torch.zeros_like(rewards).to(device) 219 | lastgaelam = 0 220 | for t in reversed(range(args.num_steps)): 221 | if t == args.num_steps - 1: 222 | nextnonterminal = 1.0 - next_done 223 | nextvalues = next_value 224 | else: 225 | nextnonterminal = 1.0 - dones[t + 1] 226 | nextvalues = values[t + 1] 227 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] 228 | advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam 229 | returns = advantages + values 230 | else: 231 | returns = torch.zeros_like(rewards).to(device) 232 | for t in reversed(range(args.num_steps)): 233 | if t == args.num_steps - 1: 234 | nextnonterminal = 1.0 - next_done 235 | next_return = next_value 236 | else: 237 | nextnonterminal = 1.0 - dones[t + 1] 238 | next_return = returns[t + 1] 239 | returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return 240 | advantages = returns - values 241 | 242 | # flatten the batch 243 | b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) 244 | b_logprobs = logprobs.reshape(-1) 245 | b_actions = actions.reshape((-1,) + envs.single_action_space.shape) 246 | b_advantages = advantages.reshape(-1) 247 | b_returns = returns.reshape(-1) 248 | b_values = values.reshape(-1) 249 | 250 | # Optimizaing the policy and value network 251 | b_inds = np.arange(args.batch_size) 252 | clipfracs = [] 253 | for epoch in range(args.update_epochs): 254 | np.random.shuffle(b_inds) 255 | for start in range(0, args.batch_size, args.minibatch_size): 256 | end = start + args.minibatch_size 257 | mb_inds = b_inds[start:end] 258 | 259 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) 260 | logratio = newlogprob - b_logprobs[mb_inds] 261 | ratio = logratio.exp() 262 | 263 | with torch.no_grad(): 264 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 265 | # old_approx_kl = (-logratio).mean() 266 | approx_kl = ((ratio - 1) - logratio).mean() 267 | clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] 268 | 269 | mb_advantages = b_advantages[mb_inds] 270 | if args.norm_adv: 271 | mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) 272 | 273 | # Policy loss 274 | pg_loss1 = -mb_advantages * ratio 275 | pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 276 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 277 | 278 | # Value loss 279 | newvalue = newvalue.view(-1) 280 | if args.clip_vloss: 281 | v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 282 | v_clipped = b_values[mb_inds] + torch.clamp( 283 | newvalue - b_values[mb_inds], 284 | -args.clip_coef, 285 | args.clip_coef, 286 | ) 287 | v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 288 | v_loss_max = 
torch.max(v_loss_unclipped, v_loss_clipped) 289 | v_loss = 0.5 * v_loss_max.mean() 290 | else: 291 | v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() 292 | 293 | entropy_loss = entropy.mean() 294 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 295 | 296 | optimizer.zero_grad() 297 | loss.backward() 298 | nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 299 | optimizer.step() 300 | 301 | if args.target_kl is not None: 302 | if approx_kl > args.target_kl: 303 | break 304 | 305 | y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() 306 | var_y = np.var(y_true) 307 | explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 308 | 309 | # TRY NOT TO MODIFY: record rewards for plotting purposes 310 | writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) 311 | writer.add_scalar("losses/value_loss", v_loss.item(), global_step) 312 | writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) 313 | writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) 314 | writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) 315 | writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) 316 | writer.add_scalar("losses/explained_variance", explained_var, global_step) 317 | print("SPS:", int(global_step / (time.time() - start_time))) 318 | writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) 319 | 320 | envs.close() 321 | writer.close() 322 | -------------------------------------------------------------------------------- /ppo_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from distutils.util import strtobool 6 | 7 | import gym 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from stable_baselines3.common.atari_wrappers import ( 13 | ClipRewardEnv, 14 | EpisodicLifeEnv, 15 | FireResetEnv, 16 | MaxAndSkipEnv, 17 | NoopResetEnv, 18 | ) 19 | from torch.distributions.categorical import Categorical 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | 23 | def parse_args(): 24 | # fmt: off 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"), 27 | help='the name of this experiment') 28 | parser.add_argument('--gym-id', type=str, default="BreakoutNoFrameskip-v4", 29 | help='the id of the gym environment') 30 | parser.add_argument('--learning-rate', type=float, default=2.5e-4, 31 | help='the learning rate of the optimizer') 32 | parser.add_argument('--seed', type=int, default=1, 33 | help='seed of the experiment') 34 | parser.add_argument('--total-timesteps', type=int, default=10000000, 35 | help='total timesteps of the experiments') 36 | parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 37 | help='if toggled, `torch.backends.cudnn.deterministic=False`') 38 | parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 39 | help='if toggled, cuda will be enabled by default') 40 | parser.add_argument('--track', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True, 41 | help='if toggled, this experiment will be tracked with Weights and Biases') 42 | parser.add_argument('--wandb-project-name', type=str, default="cleanRL", 43 | 
help="the wandb's project name") 44 | parser.add_argument('--wandb-entity', type=str, default=None, 45 | help="the entity (team) of wandb's project") 46 | parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True, 47 | help='weather to capture videos of the agent performances (check out `videos` folder)') 48 | 49 | # Algorithm specific arguments 50 | parser.add_argument('--num-envs', type=int, default=8, 51 | help='the number of parallel game environments') 52 | parser.add_argument('--num-steps', type=int, default=128, 53 | help='the number of steps to run in each environment per policy rollout') 54 | parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 55 | help="Toggle learning rate annealing for policy and value networks") 56 | parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 57 | help='Use GAE for advantage computation') 58 | parser.add_argument('--gamma', type=float, default=0.99, 59 | help='the discount factor gamma') 60 | parser.add_argument('--gae-lambda', type=float, default=0.95, 61 | help='the lambda for the general advantage estimation') 62 | parser.add_argument('--num-minibatches', type=int, default=4, 63 | help='the number of mini-batches') 64 | parser.add_argument('--update-epochs', type=int, default=4, 65 | help="the K epochs to update the policy") 66 | parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 67 | help="Toggles advantages normalization") 68 | parser.add_argument('--clip-coef', type=float, default=0.1, 69 | help="the surrogate clipping coefficient") 70 | parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 71 | help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.') 72 | parser.add_argument('--ent-coef', type=float, default=0.01, 73 | help="coefficient of the entropy") 74 | parser.add_argument('--vf-coef', type=float, default=0.5, 75 | help="coefficient of the value function") 76 | parser.add_argument('--max-grad-norm', type=float, default=0.5, 77 | help='the maximum norm for the gradient clipping') 78 | parser.add_argument('--target-kl', type=float, default=None, 79 | help='the target KL divergence threshold') 80 | args = parser.parse_args() 81 | args.batch_size = int(args.num_envs * args.num_steps) 82 | args.minibatch_size = int(args.batch_size // args.num_minibatches) 83 | # fmt: on 84 | return args 85 | 86 | 87 | def make_env(gym_id, seed, idx, capture_video, run_name): 88 | def thunk(): 89 | env = gym.make(gym_id) 90 | env = gym.wrappers.RecordEpisodeStatistics(env) 91 | if capture_video: 92 | if idx == 0: 93 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 94 | env = NoopResetEnv(env, noop_max=30) 95 | env = MaxAndSkipEnv(env, skip=4) 96 | env = EpisodicLifeEnv(env) 97 | if "FIRE" in env.unwrapped.get_action_meanings(): 98 | env = FireResetEnv(env) 99 | env = ClipRewardEnv(env) 100 | env = gym.wrappers.ResizeObservation(env, (84, 84)) 101 | env = gym.wrappers.GrayScaleObservation(env) 102 | env = gym.wrappers.FrameStack(env, 4) 103 | env.seed(seed) 104 | env.action_space.seed(seed) 105 | env.observation_space.seed(seed) 106 | return env 107 | 108 | return thunk 109 | 110 | 111 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 112 | torch.nn.init.orthogonal_(layer.weight, std) 113 | torch.nn.init.constant_(layer.bias, bias_const) 114 | return 
layer 115 | 116 | 117 | class Agent(nn.Module): 118 | def __init__(self, envs): 119 | super(Agent, self).__init__() 120 | self.network = nn.Sequential( 121 | layer_init(nn.Conv2d(4, 32, 8, stride=4)), 122 | nn.ReLU(), 123 | layer_init(nn.Conv2d(32, 64, 4, stride=2)), 124 | nn.ReLU(), 125 | layer_init(nn.Conv2d(64, 64, 3, stride=1)), 126 | nn.ReLU(), 127 | nn.Flatten(), 128 | layer_init(nn.Linear(64 * 7 * 7, 512)), 129 | nn.ReLU(), 130 | ) 131 | self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01) 132 | self.critic = layer_init(nn.Linear(512, 1), std=1) 133 | 134 | def get_value(self, x): 135 | return self.critic(self.network(x / 255.0)) 136 | 137 | def get_action_and_value(self, x, action=None): 138 | hidden = self.network(x / 255.0) 139 | logits = self.actor(hidden) 140 | probs = Categorical(logits=logits) 141 | if action is None: 142 | action = probs.sample() 143 | return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) 144 | 145 | 146 | if __name__ == "__main__": 147 | args = parse_args() 148 | run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}" 149 | if args.track: 150 | import wandb 151 | 152 | wandb.init( 153 | project=args.wandb_project_name, 154 | entity=args.wandb_entity, 155 | sync_tensorboard=True, 156 | config=vars(args), 157 | name=run_name, 158 | monitor_gym=True, 159 | save_code=True, 160 | ) 161 | writer = SummaryWriter(f"runs/{run_name}") 162 | writer.add_text( 163 | "hyperparameters", 164 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 165 | ) 166 | 167 | # TRY NOT TO MODIFY: seeding 168 | random.seed(args.seed) 169 | np.random.seed(args.seed) 170 | torch.manual_seed(args.seed) 171 | torch.backends.cudnn.deterministic = args.torch_deterministic 172 | 173 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 174 | 175 | # env setup 176 | envs = gym.vector.SyncVectorEnv( 177 | [make_env(args.gym_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] 178 | ) 179 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" 180 | 181 | agent = Agent(envs).to(device) 182 | optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) 183 | 184 | # ALGO Logic: Storage setup 185 | obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) 186 | actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) 187 | logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) 188 | rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) 189 | dones = torch.zeros((args.num_steps, args.num_envs)).to(device) 190 | values = torch.zeros((args.num_steps, args.num_envs)).to(device) 191 | 192 | # TRY NOT TO MODIFY: start the game 193 | global_step = 0 194 | start_time = time.time() 195 | next_obs = torch.Tensor(envs.reset()).to(device) 196 | next_done = torch.zeros(args.num_envs).to(device) 197 | num_updates = args.total_timesteps // args.batch_size 198 | 199 | for update in range(1, num_updates + 1): 200 | # Annealing the rate if instructed to do so. 
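# note: `frac` decays linearly from 1.0 at the first update to 1/num_updates at the last, scaling the learning rate toward zero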
201 | if args.anneal_lr: 202 | frac = 1.0 - (update - 1.0) / num_updates 203 | lrnow = frac * args.learning_rate 204 | optimizer.param_groups[0]["lr"] = lrnow 205 | 206 | for step in range(0, args.num_steps): 207 | global_step += 1 * args.num_envs 208 | obs[step] = next_obs 209 | dones[step] = next_done 210 | 211 | # ALGO LOGIC: action logic 212 | with torch.no_grad(): 213 | action, logprob, _, value = agent.get_action_and_value(next_obs) 214 | values[step] = value.flatten() 215 | actions[step] = action 216 | logprobs[step] = logprob 217 | 218 | # TRY NOT TO MODIFY: execute the game and log data. 219 | next_obs, reward, done, info = envs.step(action.cpu().numpy()) 220 | rewards[step] = torch.tensor(reward).to(device).view(-1) 221 | next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) 222 | 223 | for item in info: 224 | if "episode" in item.keys(): 225 | print(f"global_step={global_step}, episodic_return={item['episode']['r']}") 226 | writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) 227 | writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) 228 | break 229 | 230 | # bootstrap value if not done 231 | with torch.no_grad(): 232 | next_value = agent.get_value(next_obs).reshape(1, -1) 233 | if args.gae: 234 | advantages = torch.zeros_like(rewards).to(device) 235 | lastgaelam = 0 236 | for t in reversed(range(args.num_steps)): 237 | if t == args.num_steps - 1: 238 | nextnonterminal = 1.0 - next_done 239 | nextvalues = next_value 240 | else: 241 | nextnonterminal = 1.0 - dones[t + 1] 242 | nextvalues = values[t + 1] 243 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] 244 | advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam 245 | returns = advantages + values 246 | else: 247 | returns = torch.zeros_like(rewards).to(device) 248 | for t in reversed(range(args.num_steps)): 249 | if t == args.num_steps - 1: 250 | nextnonterminal = 1.0 - next_done 251 | next_return = next_value 252 | else: 253 | nextnonterminal = 1.0 - dones[t + 1] 254 | next_return = returns[t + 1] 255 | returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return 256 | advantages = returns - values 257 | 258 | # flatten the batch 259 | b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) 260 | b_logprobs = logprobs.reshape(-1) 261 | b_actions = actions.reshape((-1,) + envs.single_action_space.shape) 262 | b_advantages = advantages.reshape(-1) 263 | b_returns = returns.reshape(-1) 264 | b_values = values.reshape(-1) 265 | 266 | # Optimizaing the policy and value network 267 | b_inds = np.arange(args.batch_size) 268 | clipfracs = [] 269 | for epoch in range(args.update_epochs): 270 | np.random.shuffle(b_inds) 271 | for start in range(0, args.batch_size, args.minibatch_size): 272 | end = start + args.minibatch_size 273 | mb_inds = b_inds[start:end] 274 | 275 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) 276 | logratio = newlogprob - b_logprobs[mb_inds] 277 | ratio = logratio.exp() 278 | 279 | with torch.no_grad(): 280 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 281 | # old_approx_kl = (-logratio).mean() 282 | approx_kl = ((ratio - 1) - logratio).mean() 283 | clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] 284 | 285 | mb_advantages = b_advantages[mb_inds] 286 | if args.norm_adv: 287 | mb_advantages = (mb_advantages - mb_advantages.mean()) / 
(mb_advantages.std() + 1e-8) 288 | 289 | # Policy loss 290 | pg_loss1 = -mb_advantages * ratio 291 | pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 292 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 293 | 294 | # Value loss 295 | newvalue = newvalue.view(-1) 296 | if args.clip_vloss: 297 | v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 298 | v_clipped = b_values[mb_inds] + torch.clamp( 299 | newvalue - b_values[mb_inds], 300 | -args.clip_coef, 301 | args.clip_coef, 302 | ) 303 | v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 304 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 305 | v_loss = 0.5 * v_loss_max.mean() 306 | else: 307 | v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() 308 | 309 | entropy_loss = entropy.mean() 310 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 311 | 312 | optimizer.zero_grad() 313 | loss.backward() 314 | nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 315 | optimizer.step() 316 | 317 | if args.target_kl is not None: 318 | if approx_kl > args.target_kl: 319 | break 320 | 321 | y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() 322 | var_y = np.var(y_true) 323 | explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 324 | 325 | # TRY NOT TO MODIFY: record rewards for plotting purposes 326 | writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) 327 | writer.add_scalar("losses/value_loss", v_loss.item(), global_step) 328 | writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) 329 | writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) 330 | writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) 331 | writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) 332 | writer.add_scalar("losses/explained_variance", explained_var, global_step) 333 | print("SPS:", int(global_step / (time.time() - start_time))) 334 | writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) 335 | 336 | envs.close() 337 | writer.close() 338 | -------------------------------------------------------------------------------- /ppo_continuous_action.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from distutils.util import strtobool 6 | 7 | import gym 8 | import pybullet_envs 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.distributions.categorical import Categorical 14 | from torch.distributions.normal import Normal 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | 18 | def parse_args(): 19 | # fmt: off 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"), 22 | help='the name of this experiment') 23 | parser.add_argument('--gym-id', type=str, default="HalfCheetahBulletEnv-v0", 24 | help='the id of the gym environment') 25 | parser.add_argument('--learning-rate', type=float, default=3e-4, 26 | help='the learning rate of the optimizer') 27 | parser.add_argument('--seed', type=int, default=1, 28 | help='seed of the experiment') 29 | parser.add_argument('--total-timesteps', type=int, default=2000000, 30 | help='total timesteps of the experiments') 31 | parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, 
nargs='?', const=True, 32 | help='if toggled, `torch.backends.cudnn.deterministic=False`') 33 | parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 34 | help='if toggled, cuda will be enabled by default') 35 | parser.add_argument('--track', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True, 36 | help='if toggled, this experiment will be tracked with Weights and Biases') 37 | parser.add_argument('--wandb-project-name', type=str, default="cleanRL", 38 | help="the wandb's project name") 39 | parser.add_argument('--wandb-entity', type=str, default=None, 40 | help="the entity (team) of wandb's project") 41 | parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True, 42 | help='weather to capture videos of the agent performances (check out `videos` folder)') 43 | 44 | # Algorithm specific arguments 45 | parser.add_argument('--num-envs', type=int, default=1, 46 | help='the number of parallel game environments') 47 | parser.add_argument('--num-steps', type=int, default=2048, 48 | help='the number of steps to run in each environment per policy rollout') 49 | parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 50 | help="Toggle learning rate annealing for policy and value networks") 51 | parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 52 | help='Use GAE for advantage computation') 53 | parser.add_argument('--gamma', type=float, default=0.99, 54 | help='the discount factor gamma') 55 | parser.add_argument('--gae-lambda', type=float, default=0.95, 56 | help='the lambda for the general advantage estimation') 57 | parser.add_argument('--num-minibatches', type=int, default=32, 58 | help='the number of mini-batches') 59 | parser.add_argument('--update-epochs', type=int, default=10, 60 | help="the K epochs to update the policy") 61 | parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 62 | help="Toggles advantages normalization") 63 | parser.add_argument('--clip-coef', type=float, default=0.2, 64 | help="the surrogate clipping coefficient") 65 | parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True, 66 | help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.') 67 | parser.add_argument('--ent-coef', type=float, default=0.0, 68 | help="coefficient of the entropy") 69 | parser.add_argument('--vf-coef', type=float, default=0.5, 70 | help="coefficient of the value function") 71 | parser.add_argument('--max-grad-norm', type=float, default=0.5, 72 | help='the maximum norm for the gradient clipping') 73 | parser.add_argument('--target-kl', type=float, default=None, 74 | help='the target KL divergence threshold') 75 | args = parser.parse_args() 76 | args.batch_size = int(args.num_envs * args.num_steps) 77 | args.minibatch_size = int(args.batch_size // args.num_minibatches) 78 | # fmt: on 79 | return args 80 | 81 | 82 | def make_env(gym_id, seed, idx, capture_video, run_name): 83 | def thunk(): 84 | env = gym.make(gym_id) 85 | env = gym.wrappers.RecordEpisodeStatistics(env) 86 | if capture_video: 87 | if idx == 0: 88 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 89 | env = gym.wrappers.ClipAction(env) 90 | env = gym.wrappers.NormalizeObservation(env) 91 | env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) 92 | 
env = gym.wrappers.NormalizeReward(env) 93 | env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) 94 | env.seed(seed) 95 | env.action_space.seed(seed) 96 | env.observation_space.seed(seed) 97 | return env 98 | 99 | return thunk 100 | 101 | 102 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 103 | torch.nn.init.orthogonal_(layer.weight, std) 104 | torch.nn.init.constant_(layer.bias, bias_const) 105 | return layer 106 | 107 | 108 | class Agent(nn.Module): 109 | def __init__(self, envs): 110 | super(Agent, self).__init__() 111 | self.critic = nn.Sequential( 112 | layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), 113 | nn.Tanh(), 114 | layer_init(nn.Linear(64, 64)), 115 | nn.Tanh(), 116 | layer_init(nn.Linear(64, 1), std=1.0), 117 | ) 118 | self.actor_mean = nn.Sequential( 119 | layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), 120 | nn.Tanh(), 121 | layer_init(nn.Linear(64, 64)), 122 | nn.Tanh(), 123 | layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), 124 | ) 125 | self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) 126 | 127 | def get_value(self, x): 128 | return self.critic(x) 129 | 130 | def get_action_and_value(self, x, action=None): 131 | action_mean = self.actor_mean(x) 132 | action_logstd = self.actor_logstd.expand_as(action_mean) 133 | action_std = torch.exp(action_logstd) 134 | probs = Normal(action_mean, action_std) 135 | if action is None: 136 | action = probs.sample() 137 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) 138 | 139 | 140 | if __name__ == "__main__": 141 | args = parse_args() 142 | run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}" 143 | if args.track: 144 | import wandb 145 | 146 | wandb.init( 147 | project=args.wandb_project_name, 148 | entity=args.wandb_entity, 149 | sync_tensorboard=True, 150 | config=vars(args), 151 | name=run_name, 152 | monitor_gym=True, 153 | save_code=True, 154 | ) 155 | writer = SummaryWriter(f"runs/{run_name}") 156 | writer.add_text( 157 | "hyperparameters", 158 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 159 | ) 160 | 161 | # TRY NOT TO MODIFY: seeding 162 | random.seed(args.seed) 163 | np.random.seed(args.seed) 164 | torch.manual_seed(args.seed) 165 | torch.backends.cudnn.deterministic = args.torch_deterministic 166 | 167 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 168 | 169 | # env setup 170 | envs = gym.vector.SyncVectorEnv( 171 | [make_env(args.gym_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] 172 | ) 173 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 174 | 175 | agent = Agent(envs).to(device) 176 | optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) 177 | 178 | # ALGO Logic: Storage setup 179 | obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) 180 | actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) 181 | logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) 182 | rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) 183 | dones = torch.zeros((args.num_steps, args.num_envs)).to(device) 184 | values = torch.zeros((args.num_steps, args.num_envs)).to(device) 185 | 186 | # 
TRY NOT TO MODIFY: start the game 187 | global_step = 0 188 | start_time = time.time() 189 | next_obs = torch.Tensor(envs.reset()).to(device) 190 | next_done = torch.zeros(args.num_envs).to(device) 191 | num_updates = args.total_timesteps // args.batch_size 192 | 193 | for update in range(1, num_updates + 1): 194 | # Annealing the rate if instructed to do so. 195 | if args.anneal_lr: 196 | frac = 1.0 - (update - 1.0) / num_updates 197 | lrnow = frac * args.learning_rate 198 | optimizer.param_groups[0]["lr"] = lrnow 199 | 200 | for step in range(0, args.num_steps): 201 | global_step += 1 * args.num_envs 202 | obs[step] = next_obs 203 | dones[step] = next_done 204 | 205 | # ALGO LOGIC: action logic 206 | with torch.no_grad(): 207 | action, logprob, _, value = agent.get_action_and_value(next_obs) 208 | values[step] = value.flatten() 209 | actions[step] = action 210 | logprobs[step] = logprob 211 | 212 | # TRY NOT TO MODIFY: execute the game and log data. 213 | next_obs, reward, done, info = envs.step(action.cpu().numpy()) 214 | rewards[step] = torch.tensor(reward).to(device).view(-1) 215 | next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) 216 | 217 | for item in info: 218 | if "episode" in item.keys(): 219 | print(f"global_step={global_step}, episodic_return={item['episode']['r']}") 220 | writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) 221 | writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) 222 | break 223 | 224 | # bootstrap value if not done 225 | with torch.no_grad(): 226 | next_value = agent.get_value(next_obs).reshape(1, -1) 227 | if args.gae: 228 | advantages = torch.zeros_like(rewards).to(device) 229 | lastgaelam = 0 230 | for t in reversed(range(args.num_steps)): 231 | if t == args.num_steps - 1: 232 | nextnonterminal = 1.0 - next_done 233 | nextvalues = next_value 234 | else: 235 | nextnonterminal = 1.0 - dones[t + 1] 236 | nextvalues = values[t + 1] 237 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] 238 | advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam 239 | returns = advantages + values 240 | else: 241 | returns = torch.zeros_like(rewards).to(device) 242 | for t in reversed(range(args.num_steps)): 243 | if t == args.num_steps - 1: 244 | nextnonterminal = 1.0 - next_done 245 | next_return = next_value 246 | else: 247 | nextnonterminal = 1.0 - dones[t + 1] 248 | next_return = returns[t + 1] 249 | returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return 250 | advantages = returns - values 251 | 252 | # flatten the batch 253 | b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) 254 | b_logprobs = logprobs.reshape(-1) 255 | b_actions = actions.reshape((-1,) + envs.single_action_space.shape) 256 | b_advantages = advantages.reshape(-1) 257 | b_returns = returns.reshape(-1) 258 | b_values = values.reshape(-1) 259 | 260 | # Optimizaing the policy and value network 261 | b_inds = np.arange(args.batch_size) 262 | clipfracs = [] 263 | for epoch in range(args.update_epochs): 264 | np.random.shuffle(b_inds) 265 | for start in range(0, args.batch_size, args.minibatch_size): 266 | end = start + args.minibatch_size 267 | mb_inds = b_inds[start:end] 268 | 269 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) 270 | logratio = newlogprob - b_logprobs[mb_inds] 271 | ratio = logratio.exp() 272 | 273 | with torch.no_grad(): 274 | # calculate approx_kl 
http://joschu.net/blog/kl-approx.html 275 | # old_approx_kl = (-logratio).mean() 276 | approx_kl = ((ratio - 1) - logratio).mean() 277 | clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] 278 | 279 | mb_advantages = b_advantages[mb_inds] 280 | if args.norm_adv: 281 | mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) 282 | 283 | # Policy loss 284 | pg_loss1 = -mb_advantages * ratio 285 | pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 286 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 287 | 288 | # Value loss 289 | newvalue = newvalue.view(-1) 290 | if args.clip_vloss: 291 | v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 292 | v_clipped = b_values[mb_inds] + torch.clamp( 293 | newvalue - b_values[mb_inds], 294 | -args.clip_coef, 295 | args.clip_coef, 296 | ) 297 | v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 298 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 299 | v_loss = 0.5 * v_loss_max.mean() 300 | else: 301 | v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() 302 | 303 | entropy_loss = entropy.mean() 304 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 305 | 306 | optimizer.zero_grad() 307 | loss.backward() 308 | nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 309 | optimizer.step() 310 | 311 | if args.target_kl is not None: 312 | if approx_kl > args.target_kl: 313 | break 314 | 315 | y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() 316 | var_y = np.var(y_true) 317 | explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 318 | 319 | # TRY NOT TO MODIFY: record rewards for plotting purposes 320 | writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) 321 | writer.add_scalar("losses/value_loss", v_loss.item(), global_step) 322 | writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) 323 | writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) 324 | writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) 325 | writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) 326 | writer.add_scalar("losses/explained_variance", explained_var, global_step) 327 | print("SPS:", int(global_step / (time.time() - start_time))) 328 | writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) 329 | 330 | envs.close() 331 | writer.close() 332 | --------------------------------------------------------------------------------