├── DDPG.py ├── LICENSE ├── OurDDPG.py ├── README.md ├── TD3.py ├── learning_curves ├── Ant │ ├── TD3_Ant-v1_0.npy │ ├── TD3_Ant-v1_1.npy │ ├── TD3_Ant-v1_2.npy │ ├── TD3_Ant-v1_3.npy │ ├── TD3_Ant-v1_4.npy │ ├── TD3_Ant-v1_5.npy │ ├── TD3_Ant-v1_6.npy │ ├── TD3_Ant-v1_7.npy │ ├── TD3_Ant-v1_8.npy │ └── TD3_Ant-v1_9.npy ├── HalfCheetah │ ├── TD3_HalfCheetah-v1_0.npy │ ├── TD3_HalfCheetah-v1_1.npy │ ├── TD3_HalfCheetah-v1_2.npy │ ├── TD3_HalfCheetah-v1_3.npy │ ├── TD3_HalfCheetah-v1_4.npy │ ├── TD3_HalfCheetah-v1_5.npy │ ├── TD3_HalfCheetah-v1_6.npy │ ├── TD3_HalfCheetah-v1_7.npy │ ├── TD3_HalfCheetah-v1_8.npy │ └── TD3_HalfCheetah-v1_9.npy ├── Hopper │ ├── TD3_Hopper-v1_0.npy │ ├── TD3_Hopper-v1_1.npy │ ├── TD3_Hopper-v1_2.npy │ ├── TD3_Hopper-v1_3.npy │ ├── TD3_Hopper-v1_4.npy │ ├── TD3_Hopper-v1_5.npy │ ├── TD3_Hopper-v1_6.npy │ ├── TD3_Hopper-v1_7.npy │ ├── TD3_Hopper-v1_8.npy │ └── TD3_Hopper-v1_9.npy ├── InvertedDoublePendulum │ ├── TD3_InvertedDoublePendulum-v1_0.npy │ ├── TD3_InvertedDoublePendulum-v1_1.npy │ ├── TD3_InvertedDoublePendulum-v1_2.npy │ ├── TD3_InvertedDoublePendulum-v1_3.npy │ ├── TD3_InvertedDoublePendulum-v1_4.npy │ ├── TD3_InvertedDoublePendulum-v1_5.npy │ ├── TD3_InvertedDoublePendulum-v1_6.npy │ ├── TD3_InvertedDoublePendulum-v1_7.npy │ ├── TD3_InvertedDoublePendulum-v1_8.npy │ └── TD3_InvertedDoublePendulum-v1_9.npy ├── InvertedPendulum │ ├── TD3_InvertedPendulum-v1_0.npy │ ├── TD3_InvertedPendulum-v1_1.npy │ ├── TD3_InvertedPendulum-v1_2.npy │ ├── TD3_InvertedPendulum-v1_3.npy │ ├── TD3_InvertedPendulum-v1_4.npy │ ├── TD3_InvertedPendulum-v1_5.npy │ ├── TD3_InvertedPendulum-v1_6.npy │ ├── TD3_InvertedPendulum-v1_7.npy │ ├── TD3_InvertedPendulum-v1_8.npy │ └── TD3_InvertedPendulum-v1_9.npy ├── Reacher │ ├── TD3_Reacher-v1_0.npy │ ├── TD3_Reacher-v1_1.npy │ ├── TD3_Reacher-v1_2.npy │ ├── TD3_Reacher-v1_3.npy │ ├── TD3_Reacher-v1_4.npy │ ├── TD3_Reacher-v1_5.npy │ ├── TD3_Reacher-v1_6.npy │ ├── TD3_Reacher-v1_7.npy │ ├── TD3_Reacher-v1_8.npy │ └── TD3_Reacher-v1_9.npy └── Walker │ ├── TD3_Walker2d-v1_0.npy │ ├── TD3_Walker2d-v1_1.npy │ ├── TD3_Walker2d-v1_2.npy │ ├── TD3_Walker2d-v1_3.npy │ ├── TD3_Walker2d-v1_4.npy │ ├── TD3_Walker2d-v1_5.npy │ ├── TD3_Walker2d-v1_6.npy │ ├── TD3_Walker2d-v1_7.npy │ ├── TD3_Walker2d-v1_8.npy │ └── TD3_Walker2d-v1_9.npy ├── main.py ├── run_experiments.sh └── utils.py /DDPG.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | # Implementation of Deep Deterministic Policy Gradients (DDPG) 11 | # Paper: https://arxiv.org/abs/1509.02971 12 | # [Not the implementation used in the TD3 paper] 13 | 14 | 15 | class Actor(nn.Module): 16 | def __init__(self, state_dim, action_dim, max_action): 17 | super(Actor, self).__init__() 18 | 19 | self.l1 = nn.Linear(state_dim, 400) 20 | self.l2 = nn.Linear(400, 300) 21 | self.l3 = nn.Linear(300, action_dim) 22 | 23 | self.max_action = max_action 24 | 25 | 26 | def forward(self, state): 27 | a = F.relu(self.l1(state)) 28 | a = F.relu(self.l2(a)) 29 | return self.max_action * torch.tanh(self.l3(a)) 30 | 31 | 32 | class Critic(nn.Module): 33 | def __init__(self, state_dim, action_dim): 34 | super(Critic, self).__init__() 35 | 36 | self.l1 = nn.Linear(state_dim, 400) 37 | self.l2 = nn.Linear(400 + action_dim, 300) 38 | self.l3 = nn.Linear(300, 1) 39 | 40 | 41 | def 
forward(self, state, action): 42 | q = F.relu(self.l1(state)) 43 | q = F.relu(self.l2(torch.cat([q, action], 1))) 44 | return self.l3(q) 45 | 46 | 47 | class DDPG(object): 48 | def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.001): 49 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 50 | self.actor_target = copy.deepcopy(self.actor) 51 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4) 52 | 53 | self.critic = Critic(state_dim, action_dim).to(device) 54 | self.critic_target = copy.deepcopy(self.critic) 55 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), weight_decay=1e-2) 56 | 57 | self.discount = discount 58 | self.tau = tau 59 | 60 | 61 | def select_action(self, state): 62 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 63 | return self.actor(state).cpu().data.numpy().flatten() 64 | 65 | 66 | def train(self, replay_buffer, batch_size=64): 67 | # Sample replay buffer 68 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 69 | 70 | # Compute the target Q value 71 | target_Q = self.critic_target(next_state, self.actor_target(next_state)) 72 | target_Q = reward + (not_done * self.discount * target_Q).detach() 73 | 74 | # Get current Q estimate 75 | current_Q = self.critic(state, action) 76 | 77 | # Compute critic loss 78 | critic_loss = F.mse_loss(current_Q, target_Q) 79 | 80 | # Optimize the critic 81 | self.critic_optimizer.zero_grad() 82 | critic_loss.backward() 83 | self.critic_optimizer.step() 84 | 85 | # Compute actor loss 86 | actor_loss = -self.critic(state, self.actor(state)).mean() 87 | 88 | # Optimize the actor 89 | self.actor_optimizer.zero_grad() 90 | actor_loss.backward() 91 | self.actor_optimizer.step() 92 | 93 | # Update the frozen target models 94 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 95 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 96 | 97 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 98 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 99 | 100 | 101 | def save(self, filename): 102 | torch.save(self.critic.state_dict(), filename + "_critic") 103 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 104 | 105 | torch.save(self.actor.state_dict(), filename + "_actor") 106 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 107 | 108 | 109 | def load(self, filename): 110 | self.critic.load_state_dict(torch.load(filename + "_critic")) 111 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 112 | self.critic_target = copy.deepcopy(self.critic) 113 | 114 | self.actor.load_state_dict(torch.load(filename + "_actor")) 115 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 116 | self.actor_target = copy.deepcopy(self.actor) 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Scott Fujimoto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /OurDDPG.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | # Re-tuned version of Deep Deterministic Policy Gradients (DDPG) 11 | # Paper: https://arxiv.org/abs/1509.02971 12 | 13 | 14 | class Actor(nn.Module): 15 | def __init__(self, state_dim, action_dim, max_action): 16 | super(Actor, self).__init__() 17 | 18 | self.l1 = nn.Linear(state_dim, 256) 19 | self.l2 = nn.Linear(256, 256) 20 | self.l3 = nn.Linear(256, action_dim) 21 | 22 | self.max_action = max_action 23 | 24 | 25 | def forward(self, state): 26 | a = F.relu(self.l1(state)) 27 | a = F.relu(self.l2(a)) 28 | return self.max_action * torch.tanh(self.l3(a)) 29 | 30 | 31 | class Critic(nn.Module): 32 | def __init__(self, state_dim, action_dim): 33 | super(Critic, self).__init__() 34 | 35 | self.l1 = nn.Linear(state_dim + action_dim, 256) 36 | self.l2 = nn.Linear(256, 256) 37 | self.l3 = nn.Linear(256, 1) 38 | 39 | 40 | def forward(self, state, action): 41 | q = F.relu(self.l1(torch.cat([state, action], 1))) 42 | q = F.relu(self.l2(q)) 43 | return self.l3(q) 44 | 45 | 46 | class DDPG(object): 47 | def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005): 48 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 49 | self.actor_target = copy.deepcopy(self.actor) 50 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 51 | 52 | self.critic = Critic(state_dim, action_dim).to(device) 53 | self.critic_target = copy.deepcopy(self.critic) 54 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 55 | 56 | self.discount = discount 57 | self.tau = tau 58 | 59 | 60 | def select_action(self, state): 61 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 62 | return self.actor(state).cpu().data.numpy().flatten() 63 | 64 | 65 | def train(self, replay_buffer, batch_size=256): 66 | # Sample replay buffer 67 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 68 | 69 | # Compute the target Q value 70 | target_Q = self.critic_target(next_state, self.actor_target(next_state)) 71 | target_Q = reward + (not_done * self.discount * target_Q).detach() 72 | 73 | # Get current Q estimate 74 | current_Q = self.critic(state, action) 75 | 76 | # Compute critic loss 77 | critic_loss = F.mse_loss(current_Q, target_Q) 78 | 79 | # Optimize the critic 80 | self.critic_optimizer.zero_grad() 81 | critic_loss.backward() 82 | self.critic_optimizer.step() 
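		# (Note: the actor update below is the deterministic policy gradient step;
		# the critic is differentiated through the actor's output, so minimizing
		# -Q(s, pi(s)) pushes the policy towards actions the current critic scores highly.)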
83 | 
84 | 		# Compute actor loss
85 | 		actor_loss = -self.critic(state, self.actor(state)).mean()
86 | 
87 | 		# Optimize the actor
88 | 		self.actor_optimizer.zero_grad()
89 | 		actor_loss.backward()
90 | 		self.actor_optimizer.step()
91 | 
92 | 		# Update the frozen target models
93 | 		for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
94 | 			target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
95 | 
96 | 		for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
97 | 			target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
98 | 
99 | 
100 | 	def save(self, filename):
101 | 		torch.save(self.critic.state_dict(), filename + "_critic")
102 | 		torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
103 | 
104 | 		torch.save(self.actor.state_dict(), filename + "_actor")
105 | 		torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
106 | 
107 | 
108 | 	def load(self, filename):
109 | 		self.critic.load_state_dict(torch.load(filename + "_critic"))
110 | 		self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
111 | 		self.critic_target = copy.deepcopy(self.critic)
112 | 
113 | 		self.actor.load_state_dict(torch.load(filename + "_actor"))
114 | 		self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
115 | 		self.actor_target = copy.deepcopy(self.actor)
116 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Addressing Function Approximation Error in Actor-Critic Methods
2 | 
3 | PyTorch implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3). If you use our code or data, please cite the [paper](https://arxiv.org/abs/1802.09477).
4 | 
5 | The method is tested on [MuJoCo](http://www.mujoco.org/) continuous control tasks in [OpenAI gym](https://github.com/openai/gym).
6 | Networks are trained using [PyTorch 1.2](https://github.com/pytorch/pytorch) and Python 3.7.
7 | 
8 | ### Usage
9 | The paper results can be reproduced by running:
10 | ```
11 | ./run_experiments.sh
12 | ```
13 | Experiments on single environments can be run by calling:
14 | ```
15 | python main.py --env HalfCheetah-v2
16 | ```
17 | 
18 | Hyper-parameters can be modified with different arguments to main.py. We include an implementation of DDPG (DDPG.py), which is not used in the paper, for easy comparison of hyper-parameters with TD3. This is not the implementation of "Our DDPG" as used in the paper (see OurDDPG.py).
19 | 
20 | Algorithms which TD3 compares against (PPO, TRPO, ACKTR, DDPG) can be found in the [OpenAI baselines repository](https://github.com/openai/baselines).
21 | 
22 | ### Results
23 | The code is no longer exactly representative of the code used in the paper: minor adjustments to hyperparameters, etc., were made to improve performance. The learning curves are still the original results found in the paper.
24 | 
25 | The learning curves from the paper are located under /learning_curves. Each learning curve is formatted as a NumPy array of 201 evaluations (shape (201,)), where each evaluation corresponds to the average total reward from running the policy for 10 episodes with no exploration. The first evaluation is of the randomly initialized policy network (unused in the paper). Evaluations are performed every 5000 time steps, over a total of 1 million time steps.
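For reference, here is a minimal sketch of how one of these learning-curve files could be loaded and plotted. It assumes only NumPy and Matplotlib are installed; the file path points at one of the arrays shipped in /learning_curves, and the 5000-step spacing follows the evaluation schedule described above.
```python
import numpy as np
import matplotlib.pyplot as plt

# Each file holds 201 evaluations, one every 5000 time steps;
# the first entry is the untrained (randomly initialized) policy.
curve = np.load("learning_curves/HalfCheetah/TD3_HalfCheetah-v1_0.npy")
steps = np.arange(curve.shape[0]) * 5000

plt.plot(steps, curve)
plt.xlabel("Time steps")
plt.ylabel("Average return over 10 evaluation episodes")
plt.title("TD3 on HalfCheetah-v1 (seed 0)")
plt.show()
```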
26 | 27 | Numerical results can be found in the paper, or from the learning curves. Video of the learned agent can be found [here](https://youtu.be/x33Vw-6vzso). 28 | 29 | ### Bibtex 30 | 31 | ``` 32 | @inproceedings{fujimoto2018addressing, 33 | title={Addressing Function Approximation Error in Actor-Critic Methods}, 34 | author={Fujimoto, Scott and Hoof, Herke and Meger, David}, 35 | booktitle={International Conference on Machine Learning}, 36 | pages={1582--1591}, 37 | year={2018} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | # Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3) 11 | # Paper: https://arxiv.org/abs/1802.09477 12 | 13 | 14 | class Actor(nn.Module): 15 | def __init__(self, state_dim, action_dim, max_action): 16 | super(Actor, self).__init__() 17 | 18 | self.l1 = nn.Linear(state_dim, 256) 19 | self.l2 = nn.Linear(256, 256) 20 | self.l3 = nn.Linear(256, action_dim) 21 | 22 | self.max_action = max_action 23 | 24 | 25 | def forward(self, state): 26 | a = F.relu(self.l1(state)) 27 | a = F.relu(self.l2(a)) 28 | return self.max_action * torch.tanh(self.l3(a)) 29 | 30 | 31 | class Critic(nn.Module): 32 | def __init__(self, state_dim, action_dim): 33 | super(Critic, self).__init__() 34 | 35 | # Q1 architecture 36 | self.l1 = nn.Linear(state_dim + action_dim, 256) 37 | self.l2 = nn.Linear(256, 256) 38 | self.l3 = nn.Linear(256, 1) 39 | 40 | # Q2 architecture 41 | self.l4 = nn.Linear(state_dim + action_dim, 256) 42 | self.l5 = nn.Linear(256, 256) 43 | self.l6 = nn.Linear(256, 1) 44 | 45 | 46 | def forward(self, state, action): 47 | sa = torch.cat([state, action], 1) 48 | 49 | q1 = F.relu(self.l1(sa)) 50 | q1 = F.relu(self.l2(q1)) 51 | q1 = self.l3(q1) 52 | 53 | q2 = F.relu(self.l4(sa)) 54 | q2 = F.relu(self.l5(q2)) 55 | q2 = self.l6(q2) 56 | return q1, q2 57 | 58 | 59 | def Q1(self, state, action): 60 | sa = torch.cat([state, action], 1) 61 | 62 | q1 = F.relu(self.l1(sa)) 63 | q1 = F.relu(self.l2(q1)) 64 | q1 = self.l3(q1) 65 | return q1 66 | 67 | 68 | class TD3(object): 69 | def __init__( 70 | self, 71 | state_dim, 72 | action_dim, 73 | max_action, 74 | discount=0.99, 75 | tau=0.005, 76 | policy_noise=0.2, 77 | noise_clip=0.5, 78 | policy_freq=2 79 | ): 80 | 81 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 82 | self.actor_target = copy.deepcopy(self.actor) 83 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 84 | 85 | self.critic = Critic(state_dim, action_dim).to(device) 86 | self.critic_target = copy.deepcopy(self.critic) 87 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 88 | 89 | self.max_action = max_action 90 | self.discount = discount 91 | self.tau = tau 92 | self.policy_noise = policy_noise 93 | self.noise_clip = noise_clip 94 | self.policy_freq = policy_freq 95 | 96 | self.total_it = 0 97 | 98 | 99 | def select_action(self, state): 100 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 101 | return self.actor(state).cpu().data.numpy().flatten() 102 | 103 | 104 | def train(self, replay_buffer, batch_size=256): 105 | self.total_it += 1 106 | 107 | # Sample replay buffer 108 | state, action, next_state, reward, not_done = 
replay_buffer.sample(batch_size)
109 | 
110 | 		with torch.no_grad():
111 | 			# Select action according to policy and add clipped noise
112 | 			noise = (
113 | 				torch.randn_like(action) * self.policy_noise
114 | 			).clamp(-self.noise_clip, self.noise_clip)
115 | 
116 | 			next_action = (
117 | 				self.actor_target(next_state) + noise
118 | 			).clamp(-self.max_action, self.max_action)
119 | 
120 | 			# Compute the target Q value
121 | 			target_Q1, target_Q2 = self.critic_target(next_state, next_action)
122 | 			target_Q = torch.min(target_Q1, target_Q2)
123 | 			target_Q = reward + not_done * self.discount * target_Q
124 | 
125 | 		# Get current Q estimates
126 | 		current_Q1, current_Q2 = self.critic(state, action)
127 | 
128 | 		# Compute critic loss
129 | 		critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
130 | 
131 | 		# Optimize the critic
132 | 		self.critic_optimizer.zero_grad()
133 | 		critic_loss.backward()
134 | 		self.critic_optimizer.step()
135 | 
136 | 		# Delayed policy updates
137 | 		if self.total_it % self.policy_freq == 0:
138 | 
139 | 			# Compute actor loss
140 | 			actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
141 | 
142 | 			# Optimize the actor
143 | 			self.actor_optimizer.zero_grad()
144 | 			actor_loss.backward()
145 | 			self.actor_optimizer.step()
146 | 
147 | 			# Update the frozen target models
148 | 			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
149 | 				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
150 | 
151 | 			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
152 | 				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
153 | 
154 | 
155 | 	def save(self, filename):
156 | 		torch.save(self.critic.state_dict(), filename + "_critic")
157 | 		torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
158 | 
159 | 		torch.save(self.actor.state_dict(), filename + "_actor")
160 | 		torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
161 | 
162 | 
163 | 	def load(self, filename):
164 | 		self.critic.load_state_dict(torch.load(filename + "_critic"))
165 | 		self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
166 | 		self.critic_target = copy.deepcopy(self.critic)
167 | 
168 | 		self.actor.load_state_dict(torch.load(filename + "_actor"))
169 | 		self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
170 | 		self.actor_target = copy.deepcopy(self.actor)
171 | 
--------------------------------------------------------------------------------
/learning_curves/Ant/TD3_Ant-v1_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_0.npy
--------------------------------------------------------------------------------
/learning_curves/Ant/TD3_Ant-v1_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_1.npy
--------------------------------------------------------------------------------
/learning_curves/Ant/TD3_Ant-v1_2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_2.npy
-------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/Ant/TD3_Ant-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Ant/TD3_Ant-v1_9.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_0.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_1.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_2.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_3.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/HalfCheetah/TD3_HalfCheetah-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/HalfCheetah/TD3_HalfCheetah-v1_9.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_0.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_1.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_2.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_3.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/Hopper/TD3_Hopper-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Hopper/TD3_Hopper-v1_9.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_0.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_1.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_2.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_3.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedDoublePendulum/TD3_InvertedDoublePendulum-v1_9.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_0.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_1.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_2.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_2.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/InvertedPendulum/TD3_InvertedPendulum-v1_9.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_0.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_1.npy 
-------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_2.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/Reacher/TD3_Reacher-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Reacher/TD3_Reacher-v1_9.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_0.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_1.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_2.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_2.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_3.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_4.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_5.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_6.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_7.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_8.npy -------------------------------------------------------------------------------- /learning_curves/Walker/TD3_Walker2d-v1_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sfujim/TD3/34770ccdcb51df6f6c7d85c1ba7e40d71b940d16/learning_curves/Walker/TD3_Walker2d-v1_9.npy -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import argparse 5 | import os 6 | 7 | import utils 8 | import TD3 9 | import OurDDPG 10 | import DDPG 11 | 12 | 13 | # Runs policy for X episodes and returns average reward 14 | # A fixed seed is used for the eval environment 15 | def eval_policy(policy, env_name, seed, eval_episodes=10): 16 | eval_env = gym.make(env_name) 17 | eval_env.seed(seed + 100) 18 | 19 | avg_reward = 0. 
20 | for _ in range(eval_episodes): 21 | state, done = eval_env.reset(), False 22 | while not done: 23 | action = policy.select_action(np.array(state)) 24 | state, reward, done, _ = eval_env.step(action) 25 | avg_reward += reward 26 | 27 | avg_reward /= eval_episodes 28 | 29 | print("---------------------------------------") 30 | print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}") 31 | print("---------------------------------------") 32 | return avg_reward 33 | 34 | 35 | if __name__ == "__main__": 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--policy", default="TD3") # Policy name (TD3, DDPG or OurDDPG) 39 | parser.add_argument("--env", default="HalfCheetah-v2") # OpenAI gym environment name 40 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 41 | parser.add_argument("--start_timesteps", default=25e3, type=int)# Time steps initial random policy is used 42 | parser.add_argument("--eval_freq", default=5e3, type=int) # How often (time steps) we evaluate 43 | parser.add_argument("--max_timesteps", default=1e6, type=int) # Max time steps to run environment 44 | parser.add_argument("--expl_noise", default=0.1, type=float) # Std of Gaussian exploration noise 45 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic 46 | parser.add_argument("--discount", default=0.99, type=float) # Discount factor 47 | parser.add_argument("--tau", default=0.005, type=float) # Target network update rate 48 | parser.add_argument("--policy_noise", default=0.2) # Noise added to target policy during critic update 49 | parser.add_argument("--noise_clip", default=0.5) # Range to clip target policy noise 50 | parser.add_argument("--policy_freq", default=2, type=int) # Frequency of delayed policy updates 51 | parser.add_argument("--save_model", action="store_true") # Save model and optimizer parameters 52 | parser.add_argument("--load_model", default="") # Model load file name, "" doesn't load, "default" uses file_name 53 | args = parser.parse_args() 54 | 55 | file_name = f"{args.policy}_{args.env}_{args.seed}" 56 | print("---------------------------------------") 57 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}") 58 | print("---------------------------------------") 59 | 60 | if not os.path.exists("./results"): 61 | os.makedirs("./results") 62 | 63 | if args.save_model and not os.path.exists("./models"): 64 | os.makedirs("./models") 65 | 66 | env = gym.make(args.env) 67 | 68 | # Set seeds 69 | env.seed(args.seed) 70 | env.action_space.seed(args.seed) 71 | torch.manual_seed(args.seed) 72 | np.random.seed(args.seed) 73 | 74 | state_dim = env.observation_space.shape[0] 75 | action_dim = env.action_space.shape[0] 76 | max_action = float(env.action_space.high[0]) 77 | 78 | kwargs = { 79 | "state_dim": state_dim, 80 | "action_dim": action_dim, 81 | "max_action": max_action, 82 | "discount": args.discount, 83 | "tau": args.tau, 84 | } 85 | 86 | # Initialize policy 87 | if args.policy == "TD3": 88 | # Target policy smoothing is scaled wrt the action scale 89 | kwargs["policy_noise"] = args.policy_noise * max_action 90 | kwargs["noise_clip"] = args.noise_clip * max_action 91 | kwargs["policy_freq"] = args.policy_freq 92 | policy = TD3.TD3(**kwargs) 93 | elif args.policy == "OurDDPG": 94 | policy = OurDDPG.DDPG(**kwargs) 95 | elif args.policy == "DDPG": 96 | policy = DDPG.DDPG(**kwargs) 97 | 98 | if args.load_model != "": 99 | policy_file = file_name if args.load_model == "default" 
else args.load_model 100 | policy.load(f"./models/{policy_file}") 101 | 102 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 103 | 104 | # Evaluate untrained policy 105 | evaluations = [eval_policy(policy, args.env, args.seed)] 106 | 107 | state, done = env.reset(), False 108 | episode_reward = 0 109 | episode_timesteps = 0 110 | episode_num = 0 111 | 112 | for t in range(int(args.max_timesteps)): 113 | 114 | episode_timesteps += 1 115 | 116 | # Select action randomly or according to policy 117 | if t < args.start_timesteps: 118 | action = env.action_space.sample() 119 | else: 120 | action = ( 121 | policy.select_action(np.array(state)) 122 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 123 | ).clip(-max_action, max_action) 124 | 125 | # Perform action 126 | next_state, reward, done, _ = env.step(action) 127 | done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0 128 | 129 | # Store data in replay buffer 130 | replay_buffer.add(state, action, next_state, reward, done_bool) 131 | 132 | state = next_state 133 | episode_reward += reward 134 | 135 | # Train agent after collecting sufficient data 136 | if t >= args.start_timesteps: 137 | policy.train(replay_buffer, args.batch_size) 138 | 139 | if done: 140 | # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True 141 | print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}") 142 | # Reset environment 143 | state, done = env.reset(), False 144 | episode_reward = 0 145 | episode_timesteps = 0 146 | episode_num += 1 147 | 148 | # Evaluate episode 149 | if (t + 1) % args.eval_freq == 0: 150 | evaluations.append(eval_policy(policy, args.env, args.seed)) 151 | np.save(f"./results/{file_name}", evaluations) 152 | if args.save_model: policy.save(f"./models/{file_name}") 153 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to reproduce results 4 | 5 | for ((i=0;i<10;i+=1)) 6 | do 7 | python main.py \ 8 | --policy "TD3" \ 9 | --env "HalfCheetah-v3" \ 10 | --seed $i 11 | 12 | python main.py \ 13 | --policy "TD3" \ 14 | --env "Hopper-v3" \ 15 | --seed $i 16 | 17 | python main.py \ 18 | --policy "TD3" \ 19 | --env "Walker2d-v3" \ 20 | --seed $i 21 | 22 | python main.py \ 23 | --policy "TD3" \ 24 | --env "Ant-v3" \ 25 | --seed $i 26 | 27 | python main.py \ 28 | --policy "TD3" \ 29 | --env "Humanoid-v3" \ 30 | --seed $i 31 | 32 | python main.py \ 33 | --policy "TD3" \ 34 | --env "InvertedPendulum-v2" \ 35 | --seed $i \ 36 | --start_timesteps 1000 37 | 38 | python main.py \ 39 | --policy "TD3" \ 40 | --env "InvertedDoublePendulum-v2" \ 41 | --seed $i \ 42 | --start_timesteps 1000 43 | 44 | python main.py \ 45 | --policy "TD3" \ 46 | --env "Reacher-v2" \ 47 | --seed $i \ 48 | --start_timesteps 1000 49 | done 50 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, max_size=int(1e6)): 7 | self.max_size = max_size 8 | self.ptr = 0 9 | self.size = 0 10 | 11 | self.state = np.zeros((max_size, state_dim)) 12 | self.action = np.zeros((max_size, action_dim)) 13 | self.next_state = np.zeros((max_size, 
state_dim)) 14 | self.reward = np.zeros((max_size, 1)) 15 | self.not_done = np.zeros((max_size, 1)) 16 | 17 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def add(self, state, action, next_state, reward, done): 21 | self.state[self.ptr] = state 22 | self.action[self.ptr] = action 23 | self.next_state[self.ptr] = next_state 24 | self.reward[self.ptr] = reward 25 | self.not_done[self.ptr] = 1. - done 26 | 27 | self.ptr = (self.ptr + 1) % self.max_size 28 | self.size = min(self.size + 1, self.max_size) 29 | 30 | 31 | def sample(self, batch_size): 32 | ind = np.random.randint(0, self.size, size=batch_size) 33 | 34 | return ( 35 | torch.FloatTensor(self.state[ind]).to(self.device), 36 | torch.FloatTensor(self.action[ind]).to(self.device), 37 | torch.FloatTensor(self.next_state[ind]).to(self.device), 38 | torch.FloatTensor(self.reward[ind]).to(self.device), 39 | torch.FloatTensor(self.not_done[ind]).to(self.device) 40 | ) --------------------------------------------------------------------------------
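As a quick usage sketch (not the procedure used for the paper results, which lives in main.py and run_experiments.sh), the snippet below shows one way the pieces above could be wired together. The environment name and step counts are placeholder choices, and it assumes the older Gym API used throughout this repository (env.reset() returning only the state).
```python
import gym
import numpy as np

import TD3
import utils

# Placeholder environment; any continuous-control Gym task works.
env = gym.make("Pendulum-v0")

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

policy = TD3.TD3(state_dim, action_dim, max_action)
replay_buffer = utils.ReplayBuffer(state_dim, action_dim)

state, done = env.reset(), False
for t in range(10000):  # tiny budget, just to exercise the code path
	# Pure random exploration at first, then the policy with Gaussian noise
	if t < 1000:
		action = env.action_space.sample()
	else:
		noise = np.random.normal(0, 0.1 * max_action, size=action_dim)
		action = (policy.select_action(np.array(state)) + noise).clip(-max_action, max_action)

	next_state, reward, done, _ = env.step(action)
	# (main.py additionally masks terminations caused by the time limit; omitted here for brevity)
	replay_buffer.add(state, action, next_state, reward, float(done))
	state = next_state

	if t >= 1000:
		policy.train(replay_buffer, batch_size=256)

	if done:
		state, done = env.reset(), False
```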