├── .gitignore ├── LICENSE ├── README.rst ├── ddpg.py ├── evaluator.py ├── main.py ├── memory.py ├── model.py ├── normalized_env.py ├── output ├── MountainCarContinuous-v0-run0 │ ├── actor.pkl │ ├── critic.pkl │ ├── validate_reward.mat │ └── validate_reward.png └── Pendulum-v0-run0 │ ├── actor.pkl │ ├── critic.pkl │ ├── validate_reward.mat │ └── validate_reward.png ├── random_process.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/* 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ====== 2 | Deep Deterministic Policy Gradient on PyTorch 3 | ====== 4 | 5 | Overview 6 | ====== 7 | This is an implementation of `Deep Deterministic Policy Gradient `_ (DDPG) using `PyTorch `_. Some of the utility functions, such as the replay buffer and the random process, are adapted from the `keras-rl `_ repo. Contributions are very welcome. 8 | 9 | Dependencies 10 | ====== 11 | * Python 3.4 12 | * PyTorch 0.1.9 13 | * `OpenAI Gym `_ 14 | 15 | Run 16 | ====== 17 | * Training : results for two environments and their training curves: 18 | 19 | * Pendulum-v0 20 | 21 | .. code-block:: console 22 | 23 | $ ./main.py --debug 24 | 25 | .. image:: output/Pendulum-v0-run0/validate_reward.png 26 | :width: 800px 27 | :align: left 28 | :height: 600px 29 | :alt: alternate text 30 | 31 | * MountainCarContinuous-v0 32 | 33 | .. code-block:: console 34 | 35 | $ ./main.py --env MountainCarContinuous-v0 --validate_episodes 100 --max_episode_length 2500 --ou_sigma 0.5 --debug 36 | 37 | .. image:: output/MountainCarContinuous-v0-run0/validate_reward.png 38 | :width: 800px 39 | :align: left 40 | :height: 600px 41 | :alt: alternate text 42 | 43 | * Testing : 44 | 45 | .. code-block:: console 46 | 47 | $ ./main.py --mode test --debug 48 | 49 | TODO 50 | ====== 51 | 52 | -------------------------------------------------------------------------------- /ddpg.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim import Adam 7 | 8 | from model import (Actor, Critic) 9 | from memory import SequentialMemory 10 | from random_process import OrnsteinUhlenbeckProcess 11 | from util import * 12 | 13 | # from ipdb import set_trace as debug 14 | 15 | criterion = nn.MSELoss() 16 | 17 | class DDPG(object): 18 | def __init__(self, nb_states, nb_actions, args): 19 | 20 | if args.seed > 0: 21 | self.seed(args.seed) 22 | 23 | self.nb_states = nb_states 24 | self.nb_actions = nb_actions 25 | 26 | # Create Actor and Critic Network 27 | net_cfg = { 28 | 'hidden1':args.hidden1, 29 | 'hidden2':args.hidden2, 30 | 'init_w':args.init_w 31 | } 32 | self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg) 33 | self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg) 34 | self.actor_optim = Adam(self.actor.parameters(), lr=args.prate) 35 | 36 | self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg) 37 | self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg) 38 | self.critic_optim = Adam(self.critic.parameters(), lr=args.rate) 39 | 40 | hard_update(self.actor_target, self.actor) # Make sure the target starts with the same weights 41 | hard_update(self.critic_target, self.critic) 42 | 43 | # Create replay buffer 44 | self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length) 45 | self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=args.ou_theta, mu=args.ou_mu, sigma=args.ou_sigma) 46 | 47 | # Hyper-parameters 48 | self.batch_size = args.bsize 49 | self.tau = args.tau 50 | self.discount = args.discount 51 | self.depsilon = 1.0 / args.epsilon 52 | 53 | # 54 | self.epsilon = 1.0 55 | self.s_t = None # Most recent state 56 | self.a_t = None # Most recent action 57 | self.is_training = True 58 | 59 | # 60 | if USE_CUDA: self.cuda() 61 | 62 | def update_policy(self): 63 | # Sample batch 64
| state_batch, action_batch, reward_batch, \ 65 | next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size) 66 | 67 | # Prepare for the target q batch 68 | next_q_values = self.critic_target([ 69 | to_tensor(next_state_batch, volatile=True), 70 | self.actor_target(to_tensor(next_state_batch, volatile=True)), 71 | ]) 72 | next_q_values.volatile=False 73 | 74 | target_q_batch = to_tensor(reward_batch) + \ 75 | self.discount*to_tensor(terminal_batch.astype(np.float))*next_q_values 76 | 77 | # Critic update 78 | self.critic.zero_grad() 79 | 80 | q_batch = self.critic([ to_tensor(state_batch), to_tensor(action_batch) ]) 81 | 82 | value_loss = criterion(q_batch, target_q_batch) 83 | value_loss.backward() 84 | self.critic_optim.step() 85 | 86 | # Actor update 87 | self.actor.zero_grad() 88 | 89 | policy_loss = -self.critic([ 90 | to_tensor(state_batch), 91 | self.actor(to_tensor(state_batch)) 92 | ]) 93 | 94 | policy_loss = policy_loss.mean() 95 | policy_loss.backward() 96 | self.actor_optim.step() 97 | 98 | # Target update 99 | soft_update(self.actor_target, self.actor, self.tau) 100 | soft_update(self.critic_target, self.critic, self.tau) 101 | 102 | def eval(self): 103 | self.actor.eval() 104 | self.actor_target.eval() 105 | self.critic.eval() 106 | self.critic_target.eval() 107 | 108 | def cuda(self): 109 | self.actor.cuda() 110 | self.actor_target.cuda() 111 | self.critic.cuda() 112 | self.critic_target.cuda() 113 | 114 | def observe(self, r_t, s_t1, done): 115 | if self.is_training: 116 | self.memory.append(self.s_t, self.a_t, r_t, done) 117 | self.s_t = s_t1 118 | 119 | def random_action(self): 120 | action = np.random.uniform(-1.,1.,self.nb_actions) 121 | self.a_t = action 122 | return action 123 | 124 | def select_action(self, s_t, decay_epsilon=True): 125 | action = to_numpy( 126 | self.actor(to_tensor(np.array([s_t]))) 127 | ).squeeze(0) 128 | action += self.is_training*max(self.epsilon, 0)*self.random_process.sample() 129 | action = np.clip(action, -1., 1.) 
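# The exploration term above is Ornstein-Uhlenbeck noise scaled by a linearly decaying
# epsilon (reduced by depsilon = 1.0 / args.epsilon per call below and floored at 0 by the max());
# when the agent is switched out of training mode (e.g. in test()), is_training is False
# and the noise term vanishes, leaving the deterministic policy output.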
130 | 131 | if decay_epsilon: 132 | self.epsilon -= self.depsilon 133 | 134 | self.a_t = action 135 | return action 136 | 137 | def reset(self, obs): 138 | self.s_t = obs 139 | self.random_process.reset_states() 140 | 141 | def load_weights(self, output): 142 | if output is None: return 143 | 144 | self.actor.load_state_dict( 145 | torch.load('{}/actor.pkl'.format(output)) 146 | ) 147 | 148 | self.critic.load_state_dict( 149 | torch.load('{}/critic.pkl'.format(output)) 150 | ) 151 | 152 | 153 | def save_model(self, output): 154 | torch.save( 155 | self.actor.state_dict(), 156 | '{}/actor.pkl'.format(output) 157 | ) 158 | torch.save( 159 | self.critic.state_dict(), 160 | '{}/critic.pkl'.format(output) 161 | ) 162 | 163 | def seed(self, s): 164 | torch.manual_seed(s) 165 | if USE_CUDA: 166 | torch.cuda.manual_seed(s) 167 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from scipy.io import savemat 5 | 6 | from util import * 7 | 8 | class Evaluator(object): 9 | 10 | def __init__(self, num_episodes, interval, save_path='', max_episode_length=None): 11 | self.num_episodes = num_episodes 12 | self.max_episode_length = max_episode_length 13 | self.interval = interval 14 | self.save_path = save_path 15 | self.results = np.array([]).reshape(num_episodes,0) 16 | 17 | def __call__(self, env, policy, debug=False, visualize=False, save=True): 18 | 19 | self.is_training = False 20 | observation = None 21 | result = [] 22 | 23 | for episode in range(self.num_episodes): 24 | 25 | # reset at the start of episode 26 | observation = env.reset() 27 | episode_steps = 0 28 | episode_reward = 0. 29 | 30 | assert observation is not None 31 | 32 | # start episode 33 | done = False 34 | while not done: 35 | # basic loop: select an action, step the environment, and accumulate the reward
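# `policy` here is any callable mapping an observation to an action; in main.py it is a
# lambda wrapping agent.select_action(..., decay_epsilon=False), so that evaluation rollouts
# do not advance the agent's exploration-decay schedule.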
36 | action = policy(observation) 37 | 38 | observation, reward, done, info = env.step(action) 39 | if self.max_episode_length and episode_steps >= self.max_episode_length -1: 40 | done = True 41 | 42 | if visualize: 43 | env.render(mode='human') 44 | 45 | # update 46 | episode_reward += reward 47 | episode_steps += 1 48 | 49 | if debug: prYellow('[Evaluate] #Episode{}: episode_reward:{}'.format(episode,episode_reward)) 50 | result.append(episode_reward) 51 | 52 | result = np.array(result).reshape(-1,1) 53 | self.results = np.hstack([self.results, result]) 54 | 55 | if save: 56 | self.save_results('{}/validate_reward'.format(self.save_path)) 57 | return np.mean(result) 58 | 59 | def save_results(self, fn): 60 | 61 | y = np.mean(self.results, axis=0) 62 | error=np.std(self.results, axis=0) 63 | 64 | x = range(0,self.results.shape[1]*self.interval,self.interval) 65 | fig, ax = plt.subplots(1, 1, figsize=(6, 5)) 66 | plt.xlabel('Timestep') 67 | plt.ylabel('Average Reward') 68 | ax.errorbar(x, y, yerr=error, fmt='-o') 69 | plt.savefig(fn+'.png') 70 | savemat(fn+'.mat', {'reward':self.results}) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import argparse 5 | from copy import deepcopy 6 | import torch 7 | import gym 8 | 9 | from normalized_env import NormalizedEnv 10 | from evaluator import Evaluator 11 | from ddpg import DDPG 12 | from util import * 13 | 14 | gym.undo_logger_setup() 15 | 16 | def train(num_iterations, agent, env, evaluate, validate_steps, output, max_episode_length=None, debug=False): 17 | 18 | agent.is_training = True 19 | step = episode = episode_steps = 0 20 | episode_reward = 0. 21 | observation = None 22 | while step < num_iterations: 23 | # reset if it is the start of episode 24 | if observation is None: 25 | observation = deepcopy(env.reset()) 26 | agent.reset(observation) 27 | 28 | # agent pick action ... 29 | if step <= args.warmup: 30 | action = agent.random_action() 31 | else: 32 | action = agent.select_action(observation) 33 | 34 | # env response with next_observation, reward, terminate_info 35 | observation2, reward, done, info = env.step(action) 36 | observation2 = deepcopy(observation2) 37 | if max_episode_length and episode_steps >= max_episode_length -1: 38 | done = True 39 | 40 | # agent observe and update policy 41 | agent.observe(reward, observation2, done) 42 | if step > args.warmup : 43 | agent.update_policy() 44 | 45 | # [optional] evaluate 46 | if evaluate is not None and validate_steps > 0 and step % validate_steps == 0: 47 | policy = lambda x: agent.select_action(x, decay_epsilon=False) 48 | validate_reward = evaluate(env, policy, debug=False, visualize=False) 49 | if debug: prYellow('[Evaluate] Step_{:07d}: mean_reward:{}'.format(step, validate_reward)) 50 | 51 | # [optional] save intermideate model 52 | if step % int(num_iterations/3) == 0: 53 | agent.save_model(output) 54 | 55 | # update 56 | step += 1 57 | episode_steps += 1 58 | episode_reward += reward 59 | observation = deepcopy(observation2) 60 | 61 | if done: # end of episode 62 | if debug: prGreen('#{}: episode_reward:{} steps:{}'.format(episode,episode_reward,step)) 63 | 64 | agent.memory.append( 65 | observation, 66 | agent.select_action(observation), 67 | 0., False 68 | ) 69 | 70 | # reset 71 | observation = None 72 | episode_steps = 0 73 | episode_reward = 0. 
74 | episode += 1 75 | 76 | def test(num_episodes, agent, env, evaluate, model_path, visualize=True, debug=False): 77 | 78 | agent.load_weights(model_path) 79 | agent.is_training = False 80 | agent.eval() 81 | policy = lambda x: agent.select_action(x, decay_epsilon=False) 82 | 83 | for i in range(num_episodes): 84 | validate_reward = evaluate(env, policy, debug=debug, visualize=visualize, save=False) 85 | if debug: prYellow('[Evaluate] #{}: mean_reward:{}'.format(i, validate_reward)) 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | parser = argparse.ArgumentParser(description='PyTorch on TORCS with Multi-modal') 91 | 92 | parser.add_argument('--mode', default='train', type=str, help='support option: train/test') 93 | parser.add_argument('--env', default='Pendulum-v0', type=str, help='open-ai gym environment') 94 | parser.add_argument('--hidden1', default=400, type=int, help='hidden num of first fully connect layer') 95 | parser.add_argument('--hidden2', default=300, type=int, help='hidden num of second fully connect layer') 96 | parser.add_argument('--rate', default=0.001, type=float, help='learning rate') 97 | parser.add_argument('--prate', default=0.0001, type=float, help='policy net learning rate (only for DDPG)') 98 | parser.add_argument('--warmup', default=100, type=int, help='time without training but only filling the replay memory') 99 | parser.add_argument('--discount', default=0.99, type=float, help='') 100 | parser.add_argument('--bsize', default=64, type=int, help='minibatch size') 101 | parser.add_argument('--rmsize', default=6000000, type=int, help='memory size') 102 | parser.add_argument('--window_length', default=1, type=int, help='') 103 | parser.add_argument('--tau', default=0.001, type=float, help='moving average for target network') 104 | parser.add_argument('--ou_theta', default=0.15, type=float, help='noise theta') 105 | parser.add_argument('--ou_sigma', default=0.2, type=float, help='noise sigma') 106 | parser.add_argument('--ou_mu', default=0.0, type=float, help='noise mu') 107 | parser.add_argument('--validate_episodes', default=20, type=int, help='how many episode to perform during validate experiment') 108 | parser.add_argument('--max_episode_length', default=500, type=int, help='') 109 | parser.add_argument('--validate_steps', default=2000, type=int, help='how many steps to perform a validate experiment') 110 | parser.add_argument('--output', default='output', type=str, help='') 111 | parser.add_argument('--debug', dest='debug', action='store_true') 112 | parser.add_argument('--init_w', default=0.003, type=float, help='') 113 | parser.add_argument('--train_iter', default=200000, type=int, help='train iters each timestep') 114 | parser.add_argument('--epsilon', default=50000, type=int, help='linear decay of exploration policy') 115 | parser.add_argument('--seed', default=-1, type=int, help='') 116 | parser.add_argument('--resume', default='default', type=str, help='Resuming model path for testing') 117 | # parser.add_argument('--l2norm', default=0.01, type=float, help='l2 weight decay') # TODO 118 | # parser.add_argument('--cuda', dest='cuda', action='store_true') # TODO 119 | 120 | args = parser.parse_args() 121 | args.output = get_output_folder(args.output, args.env) 122 | if args.resume == 'default': 123 | args.resume = 'output/{}-run0'.format(args.env) 124 | 125 | env = NormalizedEnv(gym.make(args.env)) 126 | 127 | if args.seed > 0: 128 | np.random.seed(args.seed) 129 | env.seed(args.seed) 130 | 131 | nb_states = env.observation_space.shape[0] 132 | 
nb_actions = env.action_space.shape[0] 133 | 134 | 135 | agent = DDPG(nb_states, nb_actions, args) 136 | evaluate = Evaluator(args.validate_episodes, 137 | args.validate_steps, args.output, max_episode_length=args.max_episode_length) 138 | 139 | if args.mode == 'train': 140 | train(args.train_iter, agent, env, evaluate, 141 | args.validate_steps, args.output, max_episode_length=args.max_episode_length, debug=args.debug) 142 | 143 | elif args.mode == 'test': 144 | test(args.validate_episodes, agent, env, evaluate, args.resume, 145 | visualize=True, debug=args.debug) 146 | 147 | else: 148 | raise RuntimeError('undefined mode {}'.format(args.mode)) 149 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import deque, namedtuple 3 | import warnings 4 | import random 5 | 6 | import numpy as np 7 | 8 | # [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py 9 | 10 | # This is to be understood as a transition: Given `state0`, performing `action` 11 | # yields `reward` and results in `state1`, which might be `terminal`. 12 | Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1') 13 | 14 | 15 | def sample_batch_indexes(low, high, size): 16 | if high - low >= size: 17 | # We have enough data. Draw without replacement, that is each index is unique in the 18 | # batch. We cannot use `np.random.choice` here because it is horribly inefficient as 19 | # the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion. 20 | # `random.sample` does the same thing (drawing without replacement) and is way faster. 21 | try: 22 | r = xrange(low, high) 23 | except NameError: 24 | r = range(low, high) 25 | batch_idxs = random.sample(r, size) 26 | else: 27 | # Not enough data. Help ourselves with sampling from the range, but the same index 28 | # can occur multiple times. This is not good and should be avoided by picking a 29 | # large enough warm-up phase. 30 | warnings.warn('Not enough entries to sample without replacement. Consider increasing your warm-up phase to avoid oversampling!') 31 | batch_idxs = np.random.random_integers(low, high - 1, size=size) 32 | assert len(batch_idxs) == size 33 | return batch_idxs 34 | 35 | 36 | class RingBuffer(object): 37 | def __init__(self, maxlen): 38 | self.maxlen = maxlen 39 | self.start = 0 40 | self.length = 0 41 | self.data = [None for _ in range(maxlen)] 42 | 43 | def __len__(self): 44 | return self.length 45 | 46 | def __getitem__(self, idx): 47 | if idx < 0 or idx >= self.length: 48 | raise KeyError() 49 | return self.data[(self.start + idx) % self.maxlen] 50 | 51 | def append(self, v): 52 | if self.length < self.maxlen: 53 | # We have space, simply increase the length. 54 | self.length += 1 55 | elif self.length == self.maxlen: 56 | # No space, "remove" the first item. 57 | self.start = (self.start + 1) % self.maxlen 58 | else: 59 | # This should never happen. 60 | raise RuntimeError() 61 | self.data[(self.start + self.length - 1) % self.maxlen] = v 62 | 63 | 64 | def zeroed_observation(observation): 65 | if hasattr(observation, 'shape'): 66 | return np.zeros(observation.shape) 67 | elif hasattr(observation, '__iter__'): 68 | out = [] 69 | for x in observation: 70 | out.append(zeroed_observation(x)) 71 | return out 72 | else: 73 | return 0. 
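# Illustrative sketch of how the RingBuffer above behaves once full (variable names below
# are only for this example): appending beyond maxlen silently evicts the oldest entry.
#
#   buf = RingBuffer(maxlen=3)
#   for v in ['a', 'b', 'c', 'd']:
#       buf.append(v)
#   [buf[i] for i in range(len(buf))]   # -> ['b', 'c', 'd']  ('a' has been overwritten)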
74 | 75 | 76 | class Memory(object): 77 | def __init__(self, window_length, ignore_episode_boundaries=False): 78 | self.window_length = window_length 79 | self.ignore_episode_boundaries = ignore_episode_boundaries 80 | 81 | self.recent_observations = deque(maxlen=window_length) 82 | self.recent_terminals = deque(maxlen=window_length) 83 | 84 | def sample(self, batch_size, batch_idxs=None): 85 | raise NotImplementedError() 86 | 87 | def append(self, observation, action, reward, terminal, training=True): 88 | self.recent_observations.append(observation) 89 | self.recent_terminals.append(terminal) 90 | 91 | def get_recent_state(self, current_observation): 92 | # This code is slightly complicated by the fact that subsequent observations might be 93 | # from different episodes. We ensure that an experience never spans multiple episodes. 94 | # This is probably not that important in practice but it seems cleaner. 95 | state = [current_observation] 96 | idx = len(self.recent_observations) - 1 97 | for offset in range(0, self.window_length - 1): 98 | current_idx = idx - offset 99 | current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False 100 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 101 | # The previously handled observation was terminal, don't add the current one. 102 | # Otherwise we would leak into a different episode. 103 | break 104 | state.insert(0, self.recent_observations[current_idx]) 105 | while len(state) < self.window_length: 106 | state.insert(0, zeroed_observation(state[0])) 107 | return state 108 | 109 | def get_config(self): 110 | config = { 111 | 'window_length': self.window_length, 112 | 'ignore_episode_boundaries': self.ignore_episode_boundaries, 113 | } 114 | return config 115 | 116 | class SequentialMemory(Memory): 117 | def __init__(self, limit, **kwargs): 118 | super(SequentialMemory, self).__init__(**kwargs) 119 | 120 | self.limit = limit 121 | 122 | # Do not use deque to implement the memory. This data structure may seem convenient but 123 | # it is way too slow on random access. Instead, we use our own ring buffer implementation. 124 | self.actions = RingBuffer(limit) 125 | self.rewards = RingBuffer(limit) 126 | self.terminals = RingBuffer(limit) 127 | self.observations = RingBuffer(limit) 128 | 129 | def sample(self, batch_size, batch_idxs=None): 130 | if batch_idxs is None: 131 | # Draw random indexes such that we have at least a single entry before each 132 | # index. 133 | batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size) 134 | batch_idxs = np.array(batch_idxs) + 1 135 | assert np.min(batch_idxs) >= 1 136 | assert np.max(batch_idxs) < self.nb_entries 137 | assert len(batch_idxs) == batch_size 138 | 139 | # Create experiences 140 | experiences = [] 141 | for idx in batch_idxs: 142 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 143 | while terminal0: 144 | # Skip this transition because the environment was reset here. Select a new, random 145 | # transition and use this instead. This may cause the batch to contain the same 146 | # transition twice. 147 | idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] 148 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 149 | assert 1 <= idx < self.nb_entries 150 | 151 | # This code is slightly complicated by the fact that subsequent observations might be 152 | # from different episodes. We ensure that an experience never spans multiple episodes. 
153 | # This is probably not that important in practice but it seems cleaner. 154 | state0 = [self.observations[idx - 1]] 155 | for offset in range(0, self.window_length - 1): 156 | current_idx = idx - 2 - offset 157 | current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False 158 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 159 | # The previously handled observation was terminal, don't add the current one. 160 | # Otherwise we would leak into a different episode. 161 | break 162 | state0.insert(0, self.observations[current_idx]) 163 | while len(state0) < self.window_length: 164 | state0.insert(0, zeroed_observation(state0[0])) 165 | action = self.actions[idx - 1] 166 | reward = self.rewards[idx - 1] 167 | terminal1 = self.terminals[idx - 1] 168 | 169 | # Okay, now we need to create the follow-up state. This is state0 shifted on timestep 170 | # to the right. Again, we need to be careful to not include an observation from the next 171 | # episode if the last state is terminal. 172 | state1 = [np.copy(x) for x in state0[1:]] 173 | state1.append(self.observations[idx]) 174 | 175 | assert len(state0) == self.window_length 176 | assert len(state1) == len(state0) 177 | experiences.append(Experience(state0=state0, action=action, reward=reward, 178 | state1=state1, terminal1=terminal1)) 179 | assert len(experiences) == batch_size 180 | return experiences 181 | 182 | def sample_and_split(self, batch_size, batch_idxs=None): 183 | experiences = self.sample(batch_size, batch_idxs) 184 | 185 | state0_batch = [] 186 | reward_batch = [] 187 | action_batch = [] 188 | terminal1_batch = [] 189 | state1_batch = [] 190 | for e in experiences: 191 | state0_batch.append(e.state0) 192 | state1_batch.append(e.state1) 193 | reward_batch.append(e.reward) 194 | action_batch.append(e.action) 195 | terminal1_batch.append(0. if e.terminal1 else 1.) 196 | 197 | # Prepare and validate parameters. 198 | state0_batch = np.array(state0_batch).reshape(batch_size,-1) 199 | state1_batch = np.array(state1_batch).reshape(batch_size,-1) 200 | terminal1_batch = np.array(terminal1_batch).reshape(batch_size,-1) 201 | reward_batch = np.array(reward_batch).reshape(batch_size,-1) 202 | action_batch = np.array(action_batch).reshape(batch_size,-1) 203 | 204 | return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch 205 | 206 | 207 | def append(self, observation, action, reward, terminal, training=True): 208 | super(SequentialMemory, self).append(observation, action, reward, terminal, training=training) 209 | 210 | # This needs to be understood as follows: in `observation`, take `action`, obtain `reward` 211 | # and weather the next state is `terminal` or not. 
212 | if training: 213 | self.observations.append(observation) 214 | self.actions.append(action) 215 | self.rewards.append(reward) 216 | self.terminals.append(terminal) 217 | 218 | @property 219 | def nb_entries(self): 220 | return len(self.observations) 221 | 222 | def get_config(self): 223 | config = super(SequentialMemory, self).get_config() 224 | config['limit'] = self.limit 225 | return config 226 | 227 | 228 | class EpisodeParameterMemory(Memory): 229 | def __init__(self, limit, **kwargs): 230 | super(EpisodeParameterMemory, self).__init__(**kwargs) 231 | self.limit = limit 232 | 233 | self.params = RingBuffer(limit) 234 | self.intermediate_rewards = [] 235 | self.total_rewards = RingBuffer(limit) 236 | 237 | def sample(self, batch_size, batch_idxs=None): 238 | if batch_idxs is None: 239 | batch_idxs = sample_batch_indexes(0, self.nb_entries, size=batch_size) 240 | assert len(batch_idxs) == batch_size 241 | 242 | batch_params = [] 243 | batch_total_rewards = [] 244 | for idx in batch_idxs: 245 | batch_params.append(self.params[idx]) 246 | batch_total_rewards.append(self.total_rewards[idx]) 247 | return batch_params, batch_total_rewards 248 | 249 | def append(self, observation, action, reward, terminal, training=True): 250 | super(EpisodeParameterMemory, self).append(observation, action, reward, terminal, training=training) 251 | if training: 252 | self.intermediate_rewards.append(reward) 253 | 254 | def finalize_episode(self, params): 255 | total_reward = sum(self.intermediate_rewards) 256 | self.total_rewards.append(total_reward) 257 | self.params.append(params) 258 | self.intermediate_rewards = [] 259 | 260 | @property 261 | def nb_entries(self): 262 | return len(self.total_rewards) 263 | 264 | def get_config(self): 265 | config = super(SequentialMemory, self).get_config() 266 | config['limit'] = self.limit 267 | return config 268 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ipdb import set_trace as debug 9 | 10 | def fanin_init(size, fanin=None): 11 | fanin = fanin or size[0] 12 | v = 1. 
/ np.sqrt(fanin) 13 | return torch.Tensor(size).uniform_(-v, v) 14 | 15 | class Actor(nn.Module): 16 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300, init_w=3e-3): 17 | super(Actor, self).__init__() 18 | self.fc1 = nn.Linear(nb_states, hidden1) 19 | self.fc2 = nn.Linear(hidden1, hidden2) 20 | self.fc3 = nn.Linear(hidden2, nb_actions) 21 | self.relu = nn.ReLU() 22 | self.tanh = nn.Tanh() 23 | self.init_weights(init_w) 24 | 25 | def init_weights(self, init_w): 26 | self.fc1.weight.data = fanin_init(self.fc1.weight.data.size()) 27 | self.fc2.weight.data = fanin_init(self.fc2.weight.data.size()) 28 | self.fc3.weight.data.uniform_(-init_w, init_w) 29 | 30 | def forward(self, x): 31 | out = self.fc1(x) 32 | out = self.relu(out) 33 | out = self.fc2(out) 34 | out = self.relu(out) 35 | out = self.fc3(out) 36 | out = self.tanh(out) 37 | return out 38 | 39 | class Critic(nn.Module): 40 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300, init_w=3e-3): 41 | super(Critic, self).__init__() 42 | self.fc1 = nn.Linear(nb_states, hidden1) 43 | self.fc2 = nn.Linear(hidden1+nb_actions, hidden2) 44 | self.fc3 = nn.Linear(hidden2, 1) 45 | self.relu = nn.ReLU() 46 | self.init_weights(init_w) 47 | 48 | def init_weights(self, init_w): 49 | self.fc1.weight.data = fanin_init(self.fc1.weight.data.size()) 50 | self.fc2.weight.data = fanin_init(self.fc2.weight.data.size()) 51 | self.fc3.weight.data.uniform_(-init_w, init_w) 52 | 53 | def forward(self, xs): 54 | x, a = xs 55 | out = self.fc1(x) 56 | out = self.relu(out) 57 | # debug() 58 | out = self.fc2(torch.cat([out,a],1)) 59 | out = self.relu(out) 60 | out = self.fc3(out) 61 | return out -------------------------------------------------------------------------------- /normalized_env.py: -------------------------------------------------------------------------------- 1 | 2 | import gym 3 | 4 | # https://github.com/openai/gym/blob/master/gym/core.py 5 | class NormalizedEnv(gym.ActionWrapper): 6 | """ Wrap action """ 7 | 8 | def _action(self, action): 9 | act_k = (self.action_space.high - self.action_space.low)/ 2. 10 | act_b = (self.action_space.high + self.action_space.low)/ 2. 11 | return act_k * action + act_b 12 | 13 | def _reverse_action(self, action): 14 | act_k_inv = 2./(self.action_space.high - self.action_space.low) 15 | act_b = (self.action_space.high + self.action_space.low)/ 2. 
16 | return act_k_inv * (action - act_b) 17 | -------------------------------------------------------------------------------- /output/MountainCarContinuous-v0-run0/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/MountainCarContinuous-v0-run0/actor.pkl -------------------------------------------------------------------------------- /output/MountainCarContinuous-v0-run0/critic.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/MountainCarContinuous-v0-run0/critic.pkl -------------------------------------------------------------------------------- /output/MountainCarContinuous-v0-run0/validate_reward.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/MountainCarContinuous-v0-run0/validate_reward.mat -------------------------------------------------------------------------------- /output/MountainCarContinuous-v0-run0/validate_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/MountainCarContinuous-v0-run0/validate_reward.png -------------------------------------------------------------------------------- /output/Pendulum-v0-run0/actor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/Pendulum-v0-run0/actor.pkl -------------------------------------------------------------------------------- /output/Pendulum-v0-run0/critic.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/Pendulum-v0-run0/critic.pkl -------------------------------------------------------------------------------- /output/Pendulum-v0-run0/validate_reward.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/Pendulum-v0-run0/validate_reward.mat -------------------------------------------------------------------------------- /output/Pendulum-v0-run0/validate_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghliu/pytorch-ddpg/e9db328ca70ef9daf7ab3d4b44975076ceddf088/output/Pendulum-v0-run0/validate_reward.png -------------------------------------------------------------------------------- /random_process.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | # [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/random.py 5 | 6 | class RandomProcess(object): 7 | def reset_states(self): 8 | pass 9 | 10 | class AnnealedGaussianProcess(RandomProcess): 11 | def __init__(self, mu, sigma, sigma_min, n_steps_annealing): 12 | self.mu = mu 13 | self.sigma = sigma 14 | self.n_steps = 0 15 | 16 | if sigma_min is not None: 17 | self.m = -float(sigma - sigma_min) / float(n_steps_annealing) 18 | self.c = sigma 19 | self.sigma_min = 
sigma_min 20 | else: 21 | self.m = 0. 22 | self.c = sigma 23 | self.sigma_min = sigma 24 | 25 | @property 26 | def current_sigma(self): 27 | sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c) 28 | return sigma 29 | 30 | 31 | # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 32 | class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess): 33 | def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000): 34 | super(OrnsteinUhlenbeckProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing) 35 | self.theta = theta 36 | self.mu = mu 37 | self.dt = dt 38 | self.x0 = x0 39 | self.size = size 40 | self.reset_states() 41 | 42 | def sample(self): 43 | x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size) 44 | self.x_prev = x 45 | self.n_steps += 1 46 | return x 47 | 48 | def reset_states(self): 49 | self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size) -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | USE_CUDA = torch.cuda.is_available() 7 | FLOAT = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor 8 | 9 | def prRed(prt): print("\033[91m {}\033[00m" .format(prt)) 10 | def prGreen(prt): print("\033[92m {}\033[00m" .format(prt)) 11 | def prYellow(prt): print("\033[93m {}\033[00m" .format(prt)) 12 | def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt)) 13 | def prPurple(prt): print("\033[95m {}\033[00m" .format(prt)) 14 | def prCyan(prt): print("\033[96m {}\033[00m" .format(prt)) 15 | def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt)) 16 | def prBlack(prt): print("\033[98m {}\033[00m" .format(prt)) 17 | 18 | def to_numpy(var): 19 | return var.cpu().data.numpy() if USE_CUDA else var.data.numpy() 20 | 21 | def to_tensor(ndarray, volatile=False, requires_grad=False, dtype=FLOAT): 22 | return Variable( 23 | torch.from_numpy(ndarray), volatile=volatile, requires_grad=requires_grad 24 | ).type(dtype) 25 | 26 | def soft_update(target, source, tau): 27 | for target_param, param in zip(target.parameters(), source.parameters()): 28 | target_param.data.copy_( 29 | target_param.data * (1.0 - tau) + param.data * tau 30 | ) 31 | 32 | def hard_update(target, source): 33 | for target_param, param in zip(target.parameters(), source.parameters()): 34 | target_param.data.copy_(param.data) 35 | 36 | def get_output_folder(parent_dir, env_name): 37 | """Return save folder. 38 | 39 | Assumes folders in the parent_dir have suffix -run{run 40 | number}. Finds the highest run number and sets the output folder 41 | to that number + 1. This is just convenient so that if you run the 42 | same script multiple times tensorboard can plot all of the results 43 | on the same plots with different names. 44 | 45 | Parameters 46 | ---------- 47 | parent_dir: str 48 | Path of the directory containing all experiment runs. 49 | 50 | Returns 51 | ------- 52 | parent_dir/run_dir 53 | Path to this run's save directory. 
54 | """ 55 | os.makedirs(parent_dir, exist_ok=True) 56 | experiment_id = 0 57 | for folder_name in os.listdir(parent_dir): 58 | if not os.path.isdir(os.path.join(parent_dir, folder_name)): 59 | continue 60 | try: 61 | folder_name = int(folder_name.split('-run')[-1]) 62 | if folder_name > experiment_id: 63 | experiment_id = folder_name 64 | except: 65 | pass 66 | experiment_id += 1 67 | 68 | parent_dir = os.path.join(parent_dir, env_name) 69 | parent_dir = parent_dir + '-run{}'.format(experiment_id) 70 | os.makedirs(parent_dir, exist_ok=True) 71 | return parent_dir 72 | 73 | 74 | --------------------------------------------------------------------------------