├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── gif
│   ├── CartPole-v0.gif
│   ├── CartPole-v1.gif
│   └── LunarLander-v2.gif
├── parallel_PPO.py
├── print_custom.py
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | runs/
3 | *.pyc
4 | __pycache__
5 | archive/
6 | *.pth
7 | *.csv
8 | *.TXT
9 | .ipynb_checkpoints
10 | original/
11 | bug_test/
12 | gif/*.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Nikhil Barhate
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parallel PPO-PyTorch
2 | 
3 | A parallel-agent training implementation of Proximal Policy Optimization (PPO) with the clipped surrogate objective.
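For reference, the training loss in `parallel_PPO.py` is built around PPO's clipped surrogate objective. Below is a minimal PyTorch sketch of that term; the function name and the assumption that advantages are already computed are illustrative, not part of this repo:

```python
import torch

def clipped_surrogate(logprobs, old_logprobs, advantages, eps_clip=0.2):
    # probability ratio r_t = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t), computed in log space
    ratios = torch.exp(logprobs - old_logprobs)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
    # pessimistic bound: element-wise minimum of the two surrogates, negated for minimization
    return -torch.min(surr1, surr2).mean()
```

The full loss in `parallel_PPO.py` additionally adds a value-function MSE term and subtracts a small entropy bonus.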
4 | 
5 | ## Usage
6 | 
7 | - To test a pre-trained network: run `test.py`
8 | - To train a new network: run `parallel_PPO.py`
9 | - All hyperparameters are set in the `main()` function of `parallel_PPO.py`
10 | 
11 | ## Results
12 | 
13 | | CartPole-v1 | LunarLander-v2 |
14 | | :--------------------------------: | :---------------------------------: |
15 | | ![cartpole](./gif/CartPole-v1.gif) | ![lander](./gif/LunarLander-v2.gif) |
16 | 
17 | ## Dependencies
18 | 
19 | Trained and tested on:
20 | 
21 | ```
22 | Python 3.6
23 | PyTorch 1.3
24 | NumPy 1.15.3
25 | gym 0.10.8
26 | Pillow 5.3.0
27 | ```
28 | 
29 | ## TODO
30 | 
31 | - [ ] Implement ConvNet-based training
32 | 
33 | ## Setting up the Conda Environment
34 | 
35 | - `conda env export | grep -v "^prefix: " > environment.yml` exports the environment to `environment.yml`
36 | - `conda env create -f environment.yml` recreates the conda environment used for training
37 | 
38 | ## References
39 | 
40 | - PPO [paper](https://arxiv.org/abs/1707.06347)
41 | - [PPO-PyTorch github](https://github.com/nikhilbarhate99/PPO-PyTorch)
42 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: ppo_drl
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - blas=1.0=mkl
8 | - ca-certificates=2020.1.1=0
9 | - certifi=2019.11.28=py37_0
10 | - cudatoolkit=10.1.243=h6bb024c_0
11 | - entrypoints=0.3=py37_0
12 | - flake8=3.7.9=py37_0
13 | - freetype=2.9.1=h8a8886c_1
14 | - intel-openmp=2020.0=166
15 | - jpeg=9b=h024ee3a_2
16 | - ld_impl_linux-64=2.33.1=h53a641e_7
17 | - libedit=3.1.20181209=hc058e9b_0
18 | - libffi=3.2.1=hd88cf55_4
19 | - libgcc-ng=9.1.0=hdf63c60_0
20 | - libgfortran-ng=7.3.0=hdf63c60_0
21 | - libpng=1.6.37=hbc83047_0
22 | - libstdcxx-ng=9.1.0=hdf63c60_0
23 | - libtiff=4.1.0=h2733197_0
24 | - mccabe=0.6.1=py37_1
25 | - mkl=2020.0=166
26 | - mkl-service=2.3.0=py37he904b0f_0
27 | - mkl_fft=1.0.15=py37ha843d7b_0
28 | - mkl_random=1.1.0=py37hd6b4f25_0
29 | - ncurses=6.1=he6710b0_1
30 | - ninja=1.9.0=py37hfd86e86_0
31 | - numpy=1.18.1=py37h4f9e942_0
32 | - numpy-base=1.18.1=py37hde5b4d6_1
33 | - olefile=0.46=py37_0
34 | - openssl=1.1.1d=h7b6447c_3
35 | - pillow=7.0.0=py37hb39fc2d_0
36 | - pip=20.0.2=py37_1
37 | - pycodestyle=2.5.0=py37_0
38 | - pyflakes=2.1.1=py37_0
39 | - python=3.7.6=h0371630_2
40 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0
41 | - readline=7.0=h7b6447c_5
42 | - setuptools=45.1.0=py37_0
43 | - six=1.14.0=py37_0
44 | - sqlite=3.31.1=h7b6447c_0
45 | - tk=8.6.8=hbc83047_0
46 | - torchvision=0.5.0=py37_cu101
47 | - wheel=0.34.2=py37_0
48 | - xz=5.2.4=h14c3975_4
49 | - zlib=1.2.11=h7b6447c_3
50 | - zstd=1.3.7=h0b5b093_0
51 | - pip:
52 |   - absl-py==0.9.0
53 |   - atari-py==0.2.6
54 |   - box2d-py==2.3.8
55 |   - cachetools==4.0.0
56 |   - cffi==1.13.2
57 |   - chardet==3.0.4
58 |   - cloudpickle==1.2.2
59 |   - cython==0.29.14
60 |   - future==0.18.2
61 |   - glfw==1.10.1
62 |   - google-auth==1.11.0
63 |   - google-auth-oauthlib==0.4.1
64 |   - grpcio==1.26.0
65 |   - idna==2.8
66 |   - imageio==2.6.1
67 |   - lockfile==0.12.2
68 |   - markdown==3.1.1
69 |   - oauthlib==3.1.0
70 |   - opencv-python==4.2.0.32
71 |   - protobuf==3.11.3
72 |   - pyasn1==0.4.8
73 |   - pyasn1-modules==0.2.8
74 |   - pycparser==2.19
75 |   - pyglet==1.3.2
76 |   - requests==2.22.0
77 |   - requests-oauthlib==1.3.0
78 |   - rsa==4.0
79 |   - scipy==1.4.1
80 |   - tensorboard==2.1.0
81 |   - urllib3==1.25.8
82 |   - werkzeug==1.0.0
83 | 
84 | 
-------------------------------------------------------------------------------- /gif/CartPole-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhklite/Parallel-PPO-PyTorch/e6a9949ba02f25b6b5e0bb78b98abef15e6ff165/gif/CartPole-v0.gif -------------------------------------------------------------------------------- /gif/CartPole-v1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhklite/Parallel-PPO-PyTorch/e6a9949ba02f25b6b5e0bb78b98abef15e6ff165/gif/CartPole-v1.gif -------------------------------------------------------------------------------- /gif/LunarLander-v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhklite/Parallel-PPO-PyTorch/e6a9949ba02f25b6b5e0bb78b98abef15e6ff165/gif/LunarLander-v2.gif -------------------------------------------------------------------------------- /parallel_PPO.py: -------------------------------------------------------------------------------- 1 | # TODO: implement batching 2 | # TODO: implement GAE 3 | # TODO: implement value clipping (check openAI baseline) 4 | # TODO: see if i need to do value averaging 5 | # FIXME: subprocess hangs when terminate due to max steps 6 | 7 | import os 8 | import gym 9 | import time 10 | import print_custom as db 11 | from datetime import date 12 | from datetime import datetime 13 | from collections import namedtuple 14 | import csv 15 | import torch 16 | import torch.nn as nn 17 | import torch.multiprocessing as mp 18 | from torch.distributions import Categorical 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | # device = "cpu" 23 | tb_writer = SummaryWriter() 24 | mp.set_start_method('spawn', True) 25 | 26 | # creating msgs for communication between subprocess and main process. 
27 | # for when agent reached logging episode 28 | MsgRewardInfo = namedtuple('MsgRewardInfo', ['agent', 'episode', 'reward']) 29 | # for when agent reached update timestep 30 | MsgUpdateRequest = namedtuple('MsgUpdateRequest', ['agent', 'update']) 31 | # for when agent reached max episodes 32 | MsgMaxReached = namedtuple('MsgMaxReached', ['agent', 'reached']) 33 | 34 | 35 | class Memory: 36 | def __init__(self, num_agents, update_timestep, state_dim, agent_policy): 37 | """a preallocated, shared memory space for each agents to pool the 38 | collected experience 39 | 40 | Args: 41 | num_agents (int): the number of agents that are running in parallel 42 | used for calculating size of allocated memory 43 | update_timestep (int): number of timesteps until update, also used 44 | for calculating size of allocated memory 45 | state_dim (int) : the size of the state observation 46 | agent_policy (object): a network that contains the policy that the 47 | agents will be acting on 48 | """ 49 | 50 | self.states = torch.zeros( 51 | (update_timestep*num_agents, state_dim)).to(device).share_memory_() 52 | self.actions = torch.zeros( 53 | update_timestep*num_agents).to(device).share_memory_() 54 | self.logprobs = torch.zeros( 55 | update_timestep*num_agents).to(device).share_memory_() 56 | self.disReturn = torch.zeros( 57 | update_timestep*num_agents).to(device).share_memory_() 58 | 59 | self.agent_policy = agent_policy 60 | 61 | 62 | class ActorCritic(nn.Module): 63 | def __init__(self, state_dim, action_dim, n_latent_var): 64 | super(ActorCritic, self).__init__() 65 | 66 | self.action_layer = nn.Sequential( 67 | nn.Linear(state_dim, n_latent_var), 68 | nn.Tanh(), 69 | nn.Linear(n_latent_var, n_latent_var), 70 | nn.Tanh(), 71 | nn.Linear(n_latent_var, action_dim), 72 | nn.Softmax(dim=-1) 73 | ) 74 | 75 | self.value_layer = nn.Sequential( 76 | nn.Linear(state_dim, n_latent_var), 77 | nn.Tanh(), 78 | nn.Linear(n_latent_var, n_latent_var), 79 | nn.Tanh(), 80 | nn.Linear(n_latent_var, 1) 81 | ) 82 | 83 | def forward(self): 84 | raise NotImplementedError 85 | 86 | def act(self, state, evaluate): 87 | """pass the state observed into action_layer network to determine the action 88 | that the agent should take. 89 | 90 | Args: 91 | state (list): a list contatining the state observations 92 | 93 | Return: action (int): a number that indicates the action to be taken 94 | for gym environment 95 | log_prob (tensor): a tensor that contains the log probability 96 | of the action taken. 
require_grad is true 97 | """ 98 | 99 | state = torch.from_numpy(state).float().to(device) 100 | action_probs = self.action_layer(state) 101 | dist = Categorical(action_probs) 102 | 103 | if evaluate: 104 | _, action = action_probs.max(0) 105 | else: 106 | action = dist.sample() 107 | 108 | return action.item(), dist.log_prob(action) 109 | 110 | def evaluate(self, state, action): 111 | action_probs = self.action_layer(state) 112 | dist = Categorical(action_probs) 113 | action_logprobs = dist.log_prob(action) 114 | dist_entropy = dist.entropy() 115 | state_value = self.value_layer(state) 116 | 117 | return action_logprobs, torch.squeeze(state_value), dist_entropy 118 | 119 | 120 | class PPO: 121 | def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, 122 | K_epochs, eps_clip): 123 | self.lr = lr 124 | self.betas = betas 125 | self.gamma = gamma 126 | self.eps_clip = eps_clip 127 | self.K_epochs = K_epochs 128 | 129 | self.policy = ActorCritic( 130 | state_dim, 131 | action_dim, 132 | n_latent_var 133 | ).to(device) 134 | 135 | self.optimizer = torch.optim.Adam( 136 | self.policy.parameters(), 137 | lr=lr, 138 | betas=betas 139 | ) 140 | 141 | self.policy_old = ActorCritic( 142 | state_dim, 143 | action_dim, 144 | n_latent_var 145 | ).to(device).share_memory() 146 | 147 | self.policy_old.load_state_dict(self.policy.state_dict()) 148 | 149 | self.MseLoss = nn.MSELoss() 150 | 151 | def update(self, memory): 152 | old_states = memory.states.detach() 153 | old_actions = memory.actions.detach() 154 | old_logprobs = memory.logprobs.detach() 155 | old_disReturn = memory.disReturn.detach() 156 | 157 | if old_disReturn.std() == 0: 158 | old_disReturn = (old_disReturn - old_disReturn.mean()) / 1e-5 159 | else: 160 | old_disReturn = (old_disReturn - old_disReturn.mean()) / \ 161 | (old_disReturn.std()) 162 | 163 | # old_disReturn = (old_disReturn - old_disReturn.mean()) / \ 164 | # (old_disReturn.std()+1e-5) 165 | 166 | for epoch in range(self.K_epochs): 167 | # Evaluating old actions and values: 168 | logprobs, state_values, dist_entropy = self.policy.evaluate( 169 | old_states, old_actions) 170 | 171 | # Finding the ratio (pi_theta/ pi_theta_old): 172 | # using exponential returns the log back to non-log version 173 | ratios = torch.exp(logprobs - old_logprobs.detach()) 174 | 175 | # Finding the surrogate loss: 176 | 177 | advantages = old_disReturn - state_values.detach() 178 | surr1 = ratios * advantages 179 | surr2 = torch.clamp(ratios, 1-self.eps_clip, 180 | 1+self.eps_clip)*advantages 181 | 182 | # see paper for this loss formulation; this loss function 183 | # need to be used if the policy and value network shares 184 | # parameters, however, i think the author of this code just used 185 | # this, even though the two network are not sharing parameters 186 | loss = -torch.min(surr1, surr2) + 0.5 * \ 187 | self.MseLoss(state_values, old_disReturn) - 0.005*dist_entropy 188 | 189 | tb_writer.add_scalar("Loss/train", loss.mean(), epoch, time.time()) 190 | # take gradient step 191 | self.optimizer.zero_grad() 192 | loss.mean().backward() 193 | self.optimizer.step() 194 | 195 | # copy new weights into old policy: 196 | self.policy_old.load_state_dict(self.policy.state_dict()) 197 | 198 | 199 | class Agent(mp.Process): 200 | """creating a single agent, which contains the agent's gym environment 201 | and relevant information, such as its ID 202 | """ 203 | 204 | def __init__(self, name, memory, pipe, env_name, max_episode, max_timestep, 205 | update_timestep, log_interval, gamma, 
seed=None, render=False): 206 | """initialization 207 | 208 | Args: 209 | memory (object): shared memory object 210 | pipe (object): connection used to talk to the main process 211 | name (str): a number that represent the ith agent. Also used 212 | to determine the memory index for this agent to pool 213 | max_timestep (int): limit steps to this for each episode. Used 214 | for environment that does not have step limit 215 | update_timestep (int): step to take in the env before update policy 216 | """ 217 | mp.Process.__init__(self, name=name) 218 | 219 | # variables usef for multiprocessing 220 | self.proc_id = name 221 | self.memory = memory 222 | self.pipe = pipe 223 | 224 | # variables for training 225 | self.max_episode = max_episode 226 | self.max_timestep = max_timestep 227 | self.update_timestep = update_timestep 228 | self.log_interval = log_interval 229 | 230 | self.gamma = gamma 231 | self.render = render 232 | self.env = gym.make(env_name) 233 | self.env.reset() 234 | self.env.seed(seed) 235 | 236 | def run(self): 237 | print("Agent {} started, Process ID {}".format(self.name, os.getpid())) 238 | actions = [] 239 | rewards = [] 240 | states = [] 241 | logprobs = [] 242 | is_terminal = [] 243 | timestep = 0 244 | # lists to collect agent experience 245 | # variables for logging 246 | running_reward = 0 247 | 248 | for i_episodes in range(1, self.max_episode+2): 249 | state = self.env.reset() 250 | 251 | if i_episodes == self.max_episode+1: 252 | db.printInfo("Max episodes reached") 253 | msg = MsgMaxReached(self.proc_id, True) 254 | self.pipe.send(msg) 255 | break 256 | 257 | for i in range(self.max_timestep): 258 | 259 | timestep += 1 260 | 261 | states.append(state) 262 | 263 | with torch.no_grad(): 264 | action, logprob = self.memory.agent_policy.act(state, False) 265 | state, reward, done, _ = self.env.step(action) 266 | 267 | actions.append(action) 268 | logprobs.append(logprob) 269 | rewards.append(reward) 270 | is_terminal.append(done) 271 | 272 | running_reward += reward 273 | 274 | if timestep % self.update_timestep == 0: 275 | stateT, actionT, logprobT, disReturn = \ 276 | self.experience_to_tensor( 277 | states, actions, rewards, logprobs, is_terminal) 278 | 279 | self.add_experience_to_pool(stateT, actionT, 280 | logprobT, disReturn) 281 | 282 | msg = MsgUpdateRequest(int(self.proc_id), True) 283 | self.pipe.send(msg) 284 | msg = self.pipe.recv() 285 | if msg == "RENDER": 286 | self.render = True 287 | timestep = 0 288 | actions = [] 289 | rewards = [] 290 | states = [] 291 | logprobs = [] 292 | is_terminal = [] 293 | 294 | if done: 295 | break 296 | 297 | if self.render: 298 | time.sleep(0.005) 299 | self.env.render() 300 | 301 | if i_episodes % self.log_interval == 0: 302 | running_reward = running_reward/self.log_interval 303 | # db.printInfo("sending reward msg") 304 | msg = MsgRewardInfo(self.proc_id, i_episodes, running_reward) 305 | self.pipe.send(msg) 306 | running_reward = 0 307 | 308 | def experience_to_tensor(self, states, actions, rewards, 309 | logprobs, is_terminal): 310 | """converts the experience collected by the agent into tensors 311 | 312 | Args: 313 | states (list): a list of states visited by the agent 314 | actions (list): a list of actions that the agent took 315 | rewards (list): a list of reward that the agent recieved 316 | logprobs (list): a list of log probabiliy of the action happening 317 | is_terminal (list): for each step, indicate if that the agent is in 318 | the terminal state 319 | 320 | Return: 321 | stateTensor (tensor): the 
states converted to a 1D tensor 322 | actionTensor (tensor): the actions converted to a 1D tensor 323 | disReturnTensor (tensor): discounted return as a 1D tensor 324 | logprobTensor (tensor): the logprobs converted to a 1D tensor 325 | """ 326 | 327 | # convert state, action and log prob into tensor 328 | stateTensor = torch.tensor(states).float() 329 | actionTensor = torch.tensor(actions).float() 330 | logprobTensor = torch.tensor(logprobs).float().detach() 331 | 332 | # convert reward into discounted return 333 | discounted_reward = 0 334 | disReturnTensor = [] 335 | for reward, done in zip(reversed(rewards), 336 | reversed(is_terminal)): 337 | if done: 338 | discounted_reward = 0 339 | discounted_reward = reward + (self.gamma*discounted_reward) 340 | disReturnTensor.insert(0, discounted_reward) 341 | 342 | disReturnTensor = torch.tensor(disReturnTensor).float() 343 | 344 | return stateTensor, actionTensor, logprobTensor, disReturnTensor 345 | 346 | def add_experience_to_pool(self, stateTensor, actionTensor, 347 | logprobTensor, disReturnTensor): 348 | 349 | start_idx = int(self.name)*self.update_timestep 350 | end_idx = start_idx + self.update_timestep 351 | self.memory.states[start_idx:end_idx] = stateTensor 352 | self.memory.actions[start_idx:end_idx] = actionTensor 353 | self.memory.logprobs[start_idx:end_idx] = logprobTensor 354 | self.memory.disReturn[start_idx:end_idx] = disReturnTensor 355 | 356 | 357 | def main(): 358 | 359 | ###################################### 360 | # Training Environment configuration 361 | # env_name = "Reacher-v2" 362 | env_name = "LunarLander-v2" 363 | # env_name = "CartPole-v0" 364 | num_agents = 2 365 | max_timestep = 300 # per episode the agent is allowed to take 366 | update_timestep = 2000 # total number of steps to take before update 367 | max_episode = 50000 368 | seed = None # seeding the environment 369 | render = False 370 | solved_reward = 230 371 | log_interval = 100 372 | save_log_to_csv = True 373 | 374 | # gets the parameter about the environment 375 | sample_env = gym.make(env_name) 376 | state_dim = sample_env.observation_space.shape[0] 377 | action_dim = 4 378 | # action_dim = sample_env.action_space.n 379 | print("#################################") 380 | print(env_name) 381 | print("Number of Agents: {}".format(num_agents)) 382 | print("#################################\n") 383 | del sample_env 384 | 385 | # PPO & Network Parameters 386 | n_latent_var = 64 387 | lr = 0.002 388 | betas = (0.9, 0.999) 389 | gamma = 0.99 390 | K_epochs = 4 391 | eps_clip = 0.2 392 | ###################################### 393 | 394 | ppo = PPO(state_dim, action_dim, n_latent_var, 395 | lr, betas, gamma, K_epochs, eps_clip) 396 | 397 | # TODO verify if i should pass in ppo.policy_old 398 | memory = Memory(num_agents, update_timestep, state_dim, ppo.policy_old) 399 | 400 | # starting agents and pipes 401 | agents = [] 402 | pipes = [] 403 | 404 | # tracking subprocess request status 405 | update_request = [False]*num_agents 406 | agent_completed = [False]*num_agents 407 | 408 | # tracking training status 409 | update_iteration = 0 410 | log_iteration = 0 411 | average_eps_reward = 0 412 | reward_record = [[None]*num_agents] 413 | 414 | # initialize subproceses experience 415 | for agent_id in range(num_agents): 416 | p_start, p_end = mp.Pipe() 417 | agent = Agent(str(agent_id), memory, p_end, env_name, max_episode, 418 | max_timestep, update_timestep, log_interval, gamma) 419 | agent.start() 420 | agents.append(agent) 421 | pipes.append(p_start) 422 | 423 
| # starting training loop 424 | while True: 425 | for i, conn in enumerate(pipes): 426 | if conn.poll(): 427 | msg = conn.recv() 428 | 429 | # parsing information recieved from subprocess 430 | 431 | # if agent reached maximum training episode limit 432 | if type(msg).__name__ == "MsgMaxReached": 433 | agent_completed[i] = True 434 | # if agent is waiting for network update 435 | elif type(msg).__name__ == "MsgUpdateRequest": 436 | update_request[i] = True 437 | if False not in update_request: 438 | ppo.update(memory) 439 | update_iteration += 1 440 | update_request = [False]*num_agents 441 | msg = update_iteration 442 | # send to signal subprocesses to continue 443 | for pipe in pipes: 444 | pipe.send(msg) 445 | # if agent is sending over reward stats 446 | elif type(msg).__name__ == "MsgRewardInfo": 447 | idx = int(msg.episode/log_interval) 448 | if len(reward_record) < idx: 449 | reward_record.append([None]*num_agents) 450 | reward_record[idx-1][i] = msg.reward 451 | 452 | # if all agents has sent msg for this logging iteration 453 | if (None not in reward_record[log_iteration]): 454 | eps_reward = reward_record[log_iteration] 455 | average_eps_reward = 0 456 | for i in range(len(eps_reward)): 457 | print("Agent {} Episode {}, Avg Reward/Episode {:.2f}" 458 | .format(i, (log_iteration+1)*log_interval, 459 | eps_reward[i])) 460 | average_eps_reward += eps_reward[i] 461 | 462 | tb_writer.add_scalar("Agent_{}_Episodic_Reward".format(i), eps_reward[i], (log_iteration+1)*log_interval, time.time()) 463 | print("Main: Update Iteration: {}, Avg Reward Amongst Agents: {:.2f}\n" 464 | .format(update_iteration, 465 | average_eps_reward/num_agents)) 466 | tb_writer.add_scalar("Avg_Agent_reward", average_eps_reward/num_agents, update_iteration, time.time()) 467 | log_iteration += 1 468 | 469 | if False not in agent_completed: 470 | print("=Training ended with Max Episodes=") 471 | break 472 | if solved_reward <= average_eps_reward/num_agents: 473 | print("==============SOLVED==============") 474 | break 475 | 476 | for agent in agents: 477 | agent.terminate() 478 | 479 | # saving training results 480 | today = date.today() 481 | file_name = './Parallel_PPO_{}_{}_{:.2f}_{}_{}' \ 482 | .format(env_name, num_agents, average_eps_reward/num_agents, 483 | (log_iteration+1)*log_interval, today) 484 | 485 | # # saving trained model weights 486 | torch.save(ppo.policy.state_dict(), file_name+'.pth') 487 | 488 | # # saving reward log to csv 489 | if save_log_to_csv: 490 | heading = [] 491 | for i in range(num_agents): 492 | heading.append("Agent {}".format(i)) 493 | reward_record.insert(0, heading) 494 | 495 | with open(file_name+'.csv', 'w', newline='') as myfile: 496 | wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) 497 | for entry in reward_record: 498 | wr.writerow(entry) 499 | 500 | 501 | if __name__ == "__main__": 502 | 503 | start = time.perf_counter() 504 | main() 505 | end = time.perf_counter() 506 | print("Training Completed, {:.2f} sec elapsed".format(end-start)) 507 | -------------------------------------------------------------------------------- /print_custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | File Description: custom print utilities for debug 3 | Project: python 4 | Author: Daniel Dworakowski 5 | Date: Nov-18-2019 6 | """ 7 | 8 | import gc 9 | import re 10 | import torch 11 | import inspect 12 | # 13 | # Color terminal (https://stackoverflow.com/questions/287871/print-in-terminal-with-colors-using-python). 
14 | 15 | 16 | class Colours: 17 | HEADER = '\033[95m' 18 | OKBLUE = '\033[94m' 19 | OKGREEN = '\033[92m' 20 | WARNING = '\033[93m' 21 | FAIL = '\033[91m' 22 | ENDC = '\033[0m' 23 | BOLD = '\033[1m' 24 | UNDERLINE = '\033[4m' 25 | # 26 | # Error information. 27 | 28 | 29 | def lineInfo(): 30 | callerframerecord = inspect.stack()[2] 31 | frame = callerframerecord[0] 32 | info = inspect.getframeinfo(frame) 33 | file = info.filename 34 | file = file[file.rfind('/') + 1:] 35 | return '%s::%s:%d' % (file, info.function, info.lineno) 36 | # 37 | # Line information. 38 | 39 | 40 | def getLineInfo(leveloffset=0): 41 | level = 2 + leveloffset 42 | callerframerecord = inspect.stack()[level] 43 | frame = callerframerecord[0] 44 | info = inspect.getframeinfo(frame) 45 | file = info.filename 46 | file = file[file.rfind('/') + 1:] 47 | return '%s: %d' % (file, info.lineno) 48 | # 49 | # Colours a string. 50 | 51 | 52 | def colourString(msg, ctype): 53 | return ctype + msg + Colours.ENDC 54 | # 55 | # Print something in color. 56 | 57 | 58 | def printColour(msg, ctype): 59 | print(colourString(msg, ctype)) 60 | # 61 | # Print information. 62 | 63 | 64 | def printInfo(*umsg): 65 | msg = '%s: ' % (lineInfo()) 66 | lst = '' 67 | for mstr in umsg: 68 | if isinstance(mstr, torch.Tensor): 69 | vname = varname(mstr, 'printInfo') 70 | lst += '[' + str(vname) + ']\n' 71 | elif not isinstance(mstr, str): 72 | vname = varname(mstr, 'printInfo') 73 | lst += '[' + str(vname) + '] ' 74 | lst += str(mstr) + ' ' 75 | msg = colourString(msg, Colours.OKGREEN) + lst 76 | print(msg) 77 | # 78 | # Print error information. 79 | 80 | 81 | def printFrame(): 82 | print(lineInfo(), Colours.WARNING) 83 | # 84 | # Print an error. 85 | 86 | 87 | def printError(*errstr): 88 | msg = '%s: ' % (lineInfo()) 89 | lst = '' 90 | for mstr in errstr: 91 | lst += str(mstr) + ' ' 92 | msg = colourString(msg, Colours.FAIL) + lst 93 | print(msg) 94 | # 95 | # Print a warning. 
96 | 97 | 98 | def printWarn(*warnstr): 99 | msg = '%s: ' % (lineInfo()) 100 | lst = '' 101 | for mstr in warnstr: 102 | lst += str(mstr) + ' ' 103 | msg = colourString(msg, Colours.WARNING) + lst 104 | print(msg) 105 | 106 | # 107 | # Get name of variable passed to the function 108 | 109 | 110 | def varname(p, ss='printTensor'): 111 | level = 2 + 0 112 | frame = inspect.stack()[level][0] 113 | for line in inspect.getframeinfo(frame).code_context: 114 | m = re.search(r'\b%s\s*\(\s*(.*)\s*\)' % ss, line) 115 | if m: 116 | return m.group(1) 117 | # 118 | # 119 | 120 | 121 | def printList(is_different, dlist): 122 | ret = '' 123 | if is_different: 124 | ret = dlist 125 | else: 126 | ret = [str(dlist[0])] 127 | return ret 128 | # 129 | # 130 | 131 | 132 | def getDevice(t): 133 | ret = None 134 | if isinstance(t, torch.Tensor): 135 | ret = t.device 136 | else: 137 | ret = type(t) 138 | return ret 139 | # 140 | # Get the s 141 | 142 | 143 | def tensorListInfo(tensor_list, vname, usrmsg, leveloffset): 144 | assert isinstance(tensor_list, list) or isinstance(tensor_list, tuple) 145 | str_ret = '' 146 | dtypes = [tensor_list[0].dtype] 147 | devices = [tensor_list[0].device] 148 | shapes = [tensor_list[0].shape] 149 | dtype_different = False 150 | devices_different = False 151 | shapes_different = False 152 | for t_idx in range(1, len(tensor_list)): 153 | t = tensor_list[t_idx] 154 | dtypes.append(t.dtype) 155 | devices.append(getDevice(t)) 156 | shapes.append(t.shape) 157 | dtype_different |= (t.dtype != dtypes[0]) 158 | devices_different |= (t.device != devices[0]) 159 | shapes_different |= (t.shape != shapes[0]) 160 | dtypes = printList(dtype_different or devices_different, dtypes) 161 | devices = printList(dtype_different or devices_different, devices) 162 | shapes = str(printList(shapes_different, shapes)) 163 | devices_dtypes = ' '.join(map(str, *zip(dtypes, devices))) 164 | msg = colourString(colourString(getLineInfo(leveloffset + 1), Colours.UNDERLINE), Colours.OKBLUE) + ': [' + str(vname) + '] ' + ('' if isinstance(tensor_list, list) else '')+' len: %d' % len( 165 | tensor_list) + ' (' + colourString(devices_dtypes, Colours.WARNING) + ') -- ' + colourString('%s' % shapes, Colours.OKGREEN) + (' ' if isinstance(tensor_list, list) else ' ') + usrmsg 166 | return msg 167 | # 168 | # Print information about a tensor. 169 | 170 | 171 | def printTensor(tensor, usrmsg='', leveloffset=0): 172 | vname = varname(tensor) 173 | if isinstance(tensor, list) or isinstance(tensor, tuple): 174 | msg = tensorListInfo(tensor, vname, usrmsg, leveloffset) 175 | elif isinstance(tensor, torch.Tensor): 176 | msg = colourString(colourString(getLineInfo(leveloffset), Colours.UNDERLINE), Colours.OKBLUE) + ': [' + str(vname) + '] (' + colourString( 177 | str(tensor.dtype) + ' ' + str(tensor.device), Colours.WARNING) + ') -- ' + colourString('%s' % str(tensor.shape), Colours.OKGREEN) + ' ' + colourString('%s' % str(tensor.grad_fn), Colours.OKGREEN)+' ' + usrmsg 178 | else: 179 | msg = colourString(colourString(getLineInfo(leveloffset), Colours.UNDERLINE), Colours.OKBLUE) + ': [' + str(vname) + '] (' + colourString(str( 180 | tensor.dtype) + ' ' + str(getDevice(tensor)), Colours.WARNING) + ') -- ' + colourString('%s' % str(tensor.shape), Colours.OKGREEN) + ' ' + usrmsg 181 | print(msg) 182 | # 183 | # Print debugging information. 
184 | 185 | 186 | def dprint(usrmsg, leveloffset=0): 187 | msg = colourString(colourString(getLineInfo(leveloffset), 188 | Colours.UNDERLINE), Colours.OKBLUE) + ': ' + str(usrmsg) 189 | print(msg) 190 | 191 | 192 | def hasNAN(t): 193 | msg = colourString(colourString(getLineInfo(), Colours.UNDERLINE), Colours.OKBLUE) + \ 194 | ': ' + \ 195 | colourString(str('Tensor has %s NaNs' % 196 | str((t != t).sum().item())), Colours.FAIL) 197 | print(msg) 198 | 199 | 200 | def torch_mem(): 201 | dprint('Torch report: Allocated: %.2f MBytes Cached: %.2f' % ( 202 | torch.cuda.memory_allocated() / (1024 ** 2), torch.cuda.memory_cached() / (1024 ** 2)), 1) 203 | # MEM utils 204 | 205 | 206 | def mem_report(): 207 | '''Report the memory usage of the tensor.storage in pytorch 208 | Both on CPUs and GPUs are reported 209 | https://gist.github.com/Stonesjtu/368ddf5d9eb56669269ecdf9b0d21cbe''' 210 | def _mem_report(tensors, mem_type): 211 | '''Print the selected tensors of type 212 | There are two major storage types in our major concern: 213 | - GPU: tensors transferred to CUDA devices 214 | - CPU: tensors remaining on the system memory (usually unimportant) 215 | Args: 216 | - tensors: the tensors of specified type 217 | - mem_type: 'CPU' or 'GPU' in current implementation ''' 218 | print('Storage on %s' % (mem_type)) 219 | print('-' * LEN) 220 | total_numel = 0 221 | total_mem = 0 222 | visited_data = [] 223 | for idx, tensor in enumerate(tensors): 224 | if tensor.is_sparse: 225 | continue 226 | # a data_ptr indicates a memory block allocated 227 | data_ptr = tensor.storage().data_ptr() 228 | if data_ptr in visited_data: 229 | continue 230 | visited_data.append(data_ptr) 231 | numel = tensor.storage().size() 232 | total_numel += numel 233 | element_size = tensor.storage().element_size() 234 | mem = numel * element_size / 1024 / 1024 # 32bit=4Byte, MByte 235 | total_mem += mem 236 | element_type = type(tensor).__name__ 237 | size = tuple(tensor.size()) 238 | print('{:3} {}\t\t{}\t\t{:.2f}\t\t{}'.format( 239 | idx, element_type, size, mem, tensor.grad_fn)) 240 | print('-' * LEN) 241 | print('Total Tensors: %d \tUsed Memory Space: %.5f MBytes' % 242 | (total_numel, total_mem)) 243 | print('Torch report: %.2f MBytes' % 244 | (torch.cuda.memory_allocated() / (1024 ** 2))) 245 | print('-' * LEN) 246 | LEN = 65 247 | print('=' * LEN) 248 | gc.collect() 249 | objects = gc.get_objects() 250 | print('%s\t%s\t\t\t%s' % ('Element type', 'Size', 'Used MEM(MBytes)')) 251 | tensors = [obj for obj in objects if torch.is_tensor(obj)] 252 | cuda_tensors = [t for t in tensors if t.is_cuda] 253 | host_tensors = [t for t in tensors if not t.is_cuda] 254 | _mem_report(cuda_tensors, 'GPU') 255 | _mem_report(host_tensors, 'CPU') 256 | print('=' * LEN) 257 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from parallel_PPO import PPO, Memory 3 | from PIL import Image 4 | import torch 5 | import numpy as np 6 | import time 7 | import os 8 | 9 | 10 | def test(): 11 | ############## Hyperparameters ############## 12 | # env_name = "Reacher-v2" 13 | # env_name = "Acrobot-v1" 14 | env_name = "LunarLander-v2" 15 | # env_name = "CartPole-v0" 16 | # creating environment 17 | env = gym.make(env_name) 18 | state_dim = env.observation_space.shape[0] 19 | action_dim = 4 20 | render = False 21 | max_timesteps = 500 22 | n_latent_var = 64 # number of variables in hidden layer 23 | lr = 0.0007 24 | betas 
= (0.9, 0.999) 25 | gamma = 0.99 # discount factor 26 | K_epochs = 4 # update policy for K epochs 27 | eps_clip = 0.2 # clip parameter for PPO 28 | ############################################# 29 | 30 | n_episodes = 10 31 | max_timesteps = 500 32 | render = True 33 | save_gif = False 34 | 35 | # filename = "parallel_v3_PPO_CartPole-v0.pth" 36 | filename = "Parallel_PPO_LunarLander-v2_4_241.92_900_2019-12-01.pth" 37 | directory = "./" 38 | 39 | # filename = "v3_ReLU_PPO_LunarLander-v2_1_232.93_2019-11-20.pth" 40 | # directory = "./bug_test/test/ReLU/" 41 | 42 | ppo = PPO(state_dim, action_dim, n_latent_var, 43 | lr, betas, gamma, K_epochs, eps_clip) 44 | 45 | ppo.policy_old.load_state_dict(torch.load(directory+filename)) 46 | average_reward = [] 47 | for ep in range(1, n_episodes+1): 48 | ep_reward = 0 49 | state = env.reset() 50 | for t in range(max_timesteps): 51 | action, _ = ppo.policy_old.act(state, True) 52 | state, reward, done, _ = env.step(action) 53 | ep_reward += reward 54 | if render: 55 | env.render() 56 | if save_gif: 57 | img = env.render(mode='rgb_array') 58 | img = Image.fromarray(img) 59 | img.save('./gif/{}.jpg'.format(t)) 60 | if done: 61 | break 62 | if render: 63 | print('Episode: {}\tReward: {}'.format(ep, int(ep_reward))) 64 | average_reward.append(ep_reward) 65 | ep_reward = 0 66 | env.close() 67 | 68 | print("Tested {} Episode, Average Reward {:.2f}, Std {:.2f}".format( 69 | n_episodes, np.average(average_reward), np.std(average_reward))) 70 | 71 | if save_gif: 72 | os.system( 73 | "ffmpeg -f image2 -i ./gif/%d.jpg -r 300 ./gif/{}.gif -y".format(env_name)) 74 | os.system("rm ./gif/*.jpg") 75 | 76 | 77 | if __name__ == '__main__': 78 | test() 79 | --------------------------------------------------------------------------------
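A closing note: `parallel_PPO.py` lists GAE as a TODO. Below is a minimal sketch of Generalized Advantage Estimation for a single rollout; the function name, signature, and the zero value bootstrap at the end of the rollout are illustrative assumptions, not part of this repository:

```python
import torch

def gae_advantages(rewards, values, dones, gamma=0.99, lam=0.95):
    """GAE over one rollout.

    rewards, values, dones: 1-D tensors of equal length; values[t] is V(s_t).
    The value after a terminal step (and at the end of the rollout) is assumed to be zero.
    """
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros(())
    next_value = torch.zeros(())
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t].float()
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # GAE recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages
```

In this codebase, such advantages could stand in for `old_disReturn - state_values.detach()` in `PPO.update`, with the value-function targets adjusted accordingly (e.g. advantages plus the predicted values).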