├── LICENSE
├── README.md
├── continuous
│   ├── LAP_TD3.py
│   ├── PAL_TD3.py
│   ├── PER_TD3.py
│   ├── README.md
│   ├── TD3.py
│   ├── main.py
│   └── utils.py
└── discrete
    ├── DDQN.py
    ├── LAP_DDQN.py
    ├── PAL_DDQN.py
    ├── PER_DDQN.py
    ├── README.md
    ├── main.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Scott Fujimoto
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An Equivalence between Loss Functions and Non-Uniform Sampling in Experience Replay
2 | 
3 | PyTorch implementation of Loss-Adjusted Prioritized (LAP) experience replay and Prioritized Approximation Loss (PAL). LAP is an improvement to prioritized experience replay which eliminates the importance sampling weights in a principled manner, by considering the relationship to the loss function. PAL is a uniformly sampled loss function with the same expected gradient as LAP.
4 | 
5 | The [paper](https://arxiv.org/abs/2007.06049) was presented at NeurIPS 2020. Code is provided for both continuous (with TD3) and discrete (with DDQN) domains.
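For reference, here is a minimal sketch of the two objectives as they appear in this repository (mirroring `LAP_TD3.huber` and `PAL_TD3.PAL`, with the continuous-control defaults `alpha=0.4`, `min_priority=1`). LAP samples transition `i` with probability proportional to `max(|TD error_i|, min_priority)^alpha` and trains on a Huber-style loss without importance sampling weights; PAL samples uniformly and reshapes the loss so that the expected gradient matches.

```
import torch

# LAP: priority used for non-uniform sampling (no importance sampling weights needed).
def lap_priority(td_error, alpha=0.4, min_priority=1.0):
    return td_error.abs().clamp(min=min_priority).pow(alpha)

# LAP: Huber-style loss with threshold min_priority (see LAP_TD3.huber / LAP_DDQN.huber).
def lap_loss(td_error, min_priority=1.0):
    x = td_error.abs()
    return torch.where(x < min_priority, 0.5 * x.pow(2), min_priority * x).mean()

# PAL: uniformly sampled loss with the same expected gradient, up to the
# normalization applied in PAL_TD3.train / PAL_DDQN.train (see PAL_TD3.PAL).
def pal_loss(td_error, alpha=0.4, min_priority=1.0):
    x = td_error.abs()
    return torch.where(
        x < min_priority,
        (min_priority ** alpha) * 0.5 * x.pow(2),
        min_priority * x.pow(1. + alpha) / (1. + alpha)
    ).mean()
```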
6 | 7 | ### Bibtex 8 | 9 | ``` 10 | @article{fujimoto2020equivalence, 11 | title={An Equivalence between Loss Functions and Non-Uniform Sampling in Experience Replay}, 12 | author={Fujimoto, Scott and Meger, David and Precup, Doina}, 13 | journal={Advances in Neural Information Processing Systems}, 14 | volume={33}, 15 | year={2020} 16 | } 17 | ``` 18 | -------------------------------------------------------------------------------- /continuous/LAP_TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class Actor(nn.Module): 12 | def __init__(self, state_dim, action_dim, max_action): 13 | super(Actor, self).__init__() 14 | 15 | self.l1 = nn.Linear(state_dim, 256) 16 | self.l2 = nn.Linear(256, 256) 17 | self.l3 = nn.Linear(256, action_dim) 18 | 19 | self.max_action = max_action 20 | 21 | 22 | def forward(self, state): 23 | a = F.relu(self.l1(state)) 24 | a = F.relu(self.l2(a)) 25 | return self.max_action * torch.tanh(self.l3(a)) 26 | 27 | 28 | def act(self, state): 29 | a = F.relu(self.l1(state)) 30 | a = F.relu(self.l2(a)) 31 | a = self.l3(a) 32 | return self.max_action * torch.tanh(a), a 33 | 34 | 35 | class Critic(nn.Module): 36 | def __init__(self, state_dim, action_dim): 37 | super(Critic, self).__init__() 38 | 39 | # Q1 architecture 40 | self.l1 = nn.Linear(state_dim + action_dim, 256) 41 | self.l2 = nn.Linear(256, 256) 42 | self.l3 = nn.Linear(256, 1) 43 | 44 | # Q2 architecture 45 | self.l4 = nn.Linear(state_dim + action_dim, 256) 46 | self.l5 = nn.Linear(256, 256) 47 | self.l6 = nn.Linear(256, 1) 48 | 49 | 50 | def forward(self, state, action): 51 | sa = torch.cat([state, action], 1) 52 | 53 | q1 = F.relu(self.l1(sa)) 54 | q1 = F.relu(self.l2(q1)) 55 | q1 = self.l3(q1) 56 | 57 | q2 = F.relu(self.l4(sa)) 58 | q2 = F.relu(self.l5(q2)) 59 | q2 = self.l6(q2) 60 | return q1, q2 61 | 62 | 63 | def Q1(self, state, action): 64 | sa = torch.cat([state, action], 1) 65 | 66 | q1 = F.relu(self.l1(sa)) 67 | q1 = F.relu(self.l2(q1)) 68 | q1 = self.l3(q1) 69 | return q1 70 | 71 | 72 | class LAP_TD3(object): 73 | def __init__( 74 | self, 75 | state_dim, 76 | action_dim, 77 | max_action, 78 | discount=0.99, 79 | tau=0.005, 80 | policy_noise=0.2, 81 | noise_clip=0.5, 82 | policy_freq=2, 83 | alpha=0.4, 84 | min_priority=1 85 | ): 86 | 87 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 88 | self.actor_target = copy.deepcopy(self.actor) 89 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 90 | 91 | self.critic = Critic(state_dim, action_dim).to(device) 92 | self.critic_target = copy.deepcopy(self.critic) 93 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 94 | 95 | self.max_action = max_action 96 | self.discount = discount 97 | self.tau = tau 98 | self.policy_noise = policy_noise 99 | self.noise_clip = noise_clip 100 | self.policy_freq = policy_freq 101 | self.alpha = alpha 102 | self.min_priority = min_priority 103 | 104 | self.total_it = 0 105 | 106 | 107 | def select_action(self, state, test=False): 108 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 109 | return self.actor(state).cpu().data.numpy().flatten() 110 | 111 | 112 | def train(self, replay_buffer, batch_size=256): 113 | self.total_it += 1 114 | 115 | # Sample replay buffer 116 | state, action, next_state, 
reward, not_done, ind, weights = replay_buffer.sample(batch_size) 117 | 118 | with torch.no_grad(): 119 | # Select action according to policy and add clipped noise 120 | noise = ( 121 | torch.randn_like(action) * self.policy_noise 122 | ).clamp(-self.noise_clip, self.noise_clip) 123 | 124 | next_action = ( 125 | self.actor_target(next_state) + noise 126 | ).clamp(-self.max_action, self.max_action) 127 | 128 | # Compute the target Q value 129 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 130 | target_Q = torch.min(target_Q1, target_Q2) 131 | target_Q = reward + not_done * self.discount * target_Q 132 | 133 | # Get current Q estimates 134 | current_Q1, current_Q2 = self.critic(state, action) 135 | 136 | td_loss1 = (current_Q1 - target_Q).abs() 137 | td_loss2 = (current_Q2 - target_Q).abs() 138 | 139 | # Compute critic loss 140 | critic_loss = self.huber(td_loss1) + self.huber(td_loss2) 141 | 142 | # Optimize the critic 143 | self.critic_optimizer.zero_grad() 144 | critic_loss.backward() 145 | self.critic_optimizer.step() 146 | 147 | priority = torch.max(td_loss1, td_loss2).clamp(min=self.min_priority).pow(self.alpha).cpu().data.numpy().flatten() 148 | 149 | replay_buffer.update_priority(ind, priority) 150 | 151 | # Delayed policy updates 152 | if self.total_it % self.policy_freq == 0: 153 | 154 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 155 | 156 | # Optimize the actor 157 | self.actor_optimizer.zero_grad() 158 | actor_loss.backward() 159 | self.actor_optimizer.step() 160 | 161 | # Update the frozen target models 162 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 163 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 164 | 165 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 166 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 167 | 168 | 169 | def huber(self, x): 170 | return torch.where(x < self.min_priority, 0.5 * x.pow(2), self.min_priority * x).mean() 171 | 172 | 173 | def save(self, filename): 174 | torch.save(self.critic.state_dict(), filename + "_critic") 175 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 176 | torch.save(self.actor.state_dict(), filename + "_actor") 177 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 178 | 179 | 180 | def load(self, filename): 181 | self.critic.load_state_dict(torch.load(filename + "_critic")) 182 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 183 | self.actor.load_state_dict(torch.load(filename + "_actor")) 184 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 185 | -------------------------------------------------------------------------------- /continuous/PAL_TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class Actor(nn.Module): 12 | def __init__(self, state_dim, action_dim, max_action): 13 | super(Actor, self).__init__() 14 | 15 | self.l1 = nn.Linear(state_dim, 256) 16 | self.l2 = nn.Linear(256, 256) 17 | self.l3 = nn.Linear(256, action_dim) 18 | 19 | self.max_action = max_action 20 | 21 | 22 | def forward(self, state): 23 | a = F.relu(self.l1(state)) 24 | a = 
F.relu(self.l2(a)) 25 | return self.max_action * torch.tanh(self.l3(a)) 26 | 27 | 28 | def act(self, state): 29 | a = F.relu(self.l1(state)) 30 | a = F.relu(self.l2(a)) 31 | a = self.l3(a) 32 | return self.max_action * torch.tanh(a), a 33 | 34 | 35 | class Critic(nn.Module): 36 | def __init__(self, state_dim, action_dim): 37 | super(Critic, self).__init__() 38 | 39 | # Q1 architecture 40 | self.l1 = nn.Linear(state_dim + action_dim, 256) 41 | self.l2 = nn.Linear(256, 256) 42 | self.l3 = nn.Linear(256, 1) 43 | 44 | # Q2 architecture 45 | self.l4 = nn.Linear(state_dim + action_dim, 256) 46 | self.l5 = nn.Linear(256, 256) 47 | self.l6 = nn.Linear(256, 1) 48 | 49 | 50 | def forward(self, state, action): 51 | sa = torch.cat([state, action], 1) 52 | 53 | q1 = F.relu(self.l1(sa)) 54 | q1 = F.relu(self.l2(q1)) 55 | q1 = self.l3(q1) 56 | 57 | q2 = F.relu(self.l4(sa)) 58 | q2 = F.relu(self.l5(q2)) 59 | q2 = self.l6(q2) 60 | return q1, q2 61 | 62 | 63 | def Q1(self, state, action): 64 | sa = torch.cat([state, action], 1) 65 | 66 | q1 = F.relu(self.l1(sa)) 67 | q1 = F.relu(self.l2(q1)) 68 | q1 = self.l3(q1) 69 | return q1 70 | 71 | 72 | class PAL_TD3(object): 73 | def __init__( 74 | self, 75 | state_dim, 76 | action_dim, 77 | max_action, 78 | discount=0.99, 79 | tau=0.005, 80 | policy_noise=0.2, 81 | noise_clip=0.5, 82 | policy_freq=2, 83 | alpha=0.4, 84 | min_priority=1 85 | ): 86 | 87 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 88 | self.actor_target = copy.deepcopy(self.actor) 89 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 90 | 91 | self.critic = Critic(state_dim, action_dim).to(device) 92 | self.critic_target = copy.deepcopy(self.critic) 93 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 94 | 95 | self.max_action = max_action 96 | self.discount = discount 97 | self.tau = tau 98 | self.policy_noise = policy_noise 99 | self.noise_clip = noise_clip 100 | self.policy_freq = policy_freq 101 | self.alpha = alpha 102 | self.min_priority = min_priority 103 | 104 | self.total_it = 0 105 | 106 | 107 | def select_action(self, state, test=False): 108 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 109 | return self.actor(state).cpu().data.numpy().flatten() 110 | 111 | 112 | def train(self, replay_buffer, batch_size=256): 113 | self.total_it += 1 114 | 115 | # Sample replay buffer 116 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 117 | 118 | with torch.no_grad(): 119 | # Select action according to policy and add clipped noise 120 | noise = ( 121 | torch.randn_like(action) * self.policy_noise 122 | ).clamp(-self.noise_clip, self.noise_clip) 123 | 124 | next_action = ( 125 | self.actor_target(next_state) + noise 126 | ).clamp(-self.max_action, self.max_action) 127 | 128 | # Compute the target Q value 129 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 130 | target_Q = torch.min(target_Q1, target_Q2) 131 | target_Q = reward + not_done * self.discount * target_Q 132 | 133 | # Get current Q estimates 134 | current_Q1, current_Q2 = self.critic(state, action) 135 | 136 | td_loss1 = (current_Q1 - target_Q) 137 | td_loss2 = (current_Q2 - target_Q) 138 | 139 | critic_loss = self.PAL(td_loss1) + self.PAL(td_loss2) 140 | critic_loss /= torch.max(td_loss1.abs(), td_loss2.abs()).clamp(min=self.min_priority).pow(self.alpha).mean().detach() 141 | 142 | # Optimize the critic 143 | self.critic_optimizer.zero_grad() 144 | critic_loss.backward() 145 | 
self.critic_optimizer.step() 146 | 147 | # Delayed policy updates 148 | if self.total_it % self.policy_freq == 0: 149 | 150 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 151 | 152 | # Optimize the actor 153 | self.actor_optimizer.zero_grad() 154 | actor_loss.backward() 155 | self.actor_optimizer.step() 156 | 157 | # Update the frozen target models 158 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 159 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 160 | 161 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 162 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 163 | 164 | 165 | # If min_priority=1, this can be simplified. 166 | def PAL(self, x): 167 | return torch.where( 168 | x.abs() < self.min_priority, 169 | (self.min_priority ** self.alpha) * 0.5 * x.pow(2), 170 | self.min_priority * x.abs().pow(1. + self.alpha)/(1. + self.alpha) 171 | ).mean() 172 | 173 | 174 | def save(self, filename): 175 | torch.save(self.critic.state_dict(), filename + "_critic") 176 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 177 | torch.save(self.actor.state_dict(), filename + "_actor") 178 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 179 | 180 | 181 | def load(self, filename): 182 | self.critic.load_state_dict(torch.load(filename + "_critic")) 183 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 184 | self.actor.load_state_dict(torch.load(filename + "_actor")) 185 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 186 | -------------------------------------------------------------------------------- /continuous/PER_TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class Actor(nn.Module): 12 | def __init__(self, state_dim, action_dim, max_action): 13 | super(Actor, self).__init__() 14 | 15 | self.l1 = nn.Linear(state_dim, 256) 16 | self.l2 = nn.Linear(256, 256) 17 | self.l3 = nn.Linear(256, action_dim) 18 | 19 | self.max_action = max_action 20 | 21 | 22 | def forward(self, state): 23 | a = F.relu(self.l1(state)) 24 | a = F.relu(self.l2(a)) 25 | return self.max_action * torch.tanh(self.l3(a)) 26 | 27 | 28 | def act(self, state): 29 | a = F.relu(self.l1(state)) 30 | a = F.relu(self.l2(a)) 31 | a = self.l3(a) 32 | return self.max_action * torch.tanh(a), a 33 | 34 | 35 | class Critic(nn.Module): 36 | def __init__(self, state_dim, action_dim): 37 | super(Critic, self).__init__() 38 | 39 | # Q1 architecture 40 | self.l1 = nn.Linear(state_dim + action_dim, 256) 41 | self.l2 = nn.Linear(256, 256) 42 | self.l3 = nn.Linear(256, 1) 43 | 44 | # Q2 architecture 45 | self.l4 = nn.Linear(state_dim + action_dim, 256) 46 | self.l5 = nn.Linear(256, 256) 47 | self.l6 = nn.Linear(256, 1) 48 | 49 | 50 | def forward(self, state, action): 51 | sa = torch.cat([state, action], 1) 52 | 53 | q1 = F.relu(self.l1(sa)) 54 | q1 = F.relu(self.l2(q1)) 55 | q1 = self.l3(q1) 56 | 57 | q2 = F.relu(self.l4(sa)) 58 | q2 = F.relu(self.l5(q2)) 59 | q2 = self.l6(q2) 60 | return q1, q2 61 | 62 | 63 | def Q1(self, state, action): 64 | sa = torch.cat([state, action], 1) 65 | 66 | q1 = 
F.relu(self.l1(sa)) 67 | q1 = F.relu(self.l2(q1)) 68 | q1 = self.l3(q1) 69 | return q1 70 | 71 | 72 | class PER_TD3(object): 73 | def __init__( 74 | self, 75 | state_dim, 76 | action_dim, 77 | max_action, 78 | discount=0.99, 79 | tau=0.005, 80 | policy_noise=0.2, 81 | noise_clip=0.5, 82 | policy_freq=2, 83 | alpha=0.6, 84 | ): 85 | 86 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 87 | self.actor_target = copy.deepcopy(self.actor) 88 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 89 | 90 | self.critic = Critic(state_dim, action_dim).to(device) 91 | self.critic_target = copy.deepcopy(self.critic) 92 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 93 | 94 | self.max_action = max_action 95 | self.discount = discount 96 | self.tau = tau 97 | self.policy_noise = policy_noise 98 | self.noise_clip = noise_clip 99 | self.policy_freq = policy_freq 100 | self.alpha = alpha 101 | 102 | self.total_it = 0 103 | 104 | 105 | def select_action(self, state, test=False): 106 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 107 | return self.actor(state).cpu().data.numpy().flatten() 108 | 109 | 110 | def train(self, replay_buffer, batch_size=256): 111 | self.total_it += 1 112 | 113 | # Sample replay buffer 114 | state, action, next_state, reward, not_done, ind, weights = replay_buffer.sample(batch_size) 115 | 116 | with torch.no_grad(): 117 | # Select action according to policy and add clipped noise 118 | noise = ( 119 | torch.randn_like(action) * self.policy_noise 120 | ).clamp(-self.noise_clip, self.noise_clip) 121 | 122 | next_action = ( 123 | self.actor_target(next_state) + noise 124 | ).clamp(-self.max_action, self.max_action) 125 | 126 | # Compute the target Q value 127 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 128 | target_Q = torch.min(target_Q1, target_Q2) 129 | target_Q = reward + not_done * self.discount * target_Q 130 | 131 | # Get current Q estimates 132 | current_Q1, current_Q2 = self.critic(state, action) 133 | 134 | td_loss1 = (current_Q1 - target_Q).abs() 135 | td_loss2 = (current_Q2 - target_Q).abs() 136 | 137 | # Compute critic loss 138 | critic_loss = ( 139 | (weights * F.mse_loss(current_Q1, target_Q, reduction='none')).mean() 140 | + (weights * F.mse_loss(current_Q2, target_Q, reduction='none')).mean() 141 | ) 142 | 143 | # Optimize the critic 144 | self.critic_optimizer.zero_grad() 145 | critic_loss.backward() 146 | self.critic_optimizer.step() 147 | 148 | priority = torch.max(td_loss1, td_loss2).pow(self.alpha).cpu().data.numpy().flatten() 149 | 150 | replay_buffer.update_priority(ind, priority) 151 | 152 | # Delayed policy updates 153 | if self.total_it % self.policy_freq == 0: 154 | 155 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 156 | 157 | # Optimize the actor 158 | self.actor_optimizer.zero_grad() 159 | actor_loss.backward() 160 | self.actor_optimizer.step() 161 | 162 | # Update the frozen target models 163 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 164 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 165 | 166 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 167 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 168 | 169 | 170 | def save(self, filename): 171 | torch.save(self.critic.state_dict(), filename + "_critic") 172 | torch.save(self.critic_optimizer.state_dict(), filename + 
"_critic_optimizer") 173 | torch.save(self.actor.state_dict(), filename + "_actor") 174 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 175 | 176 | 177 | def load(self, filename): 178 | self.critic.load_state_dict(torch.load(filename + "_critic")) 179 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 180 | self.actor.load_state_dict(torch.load(filename + "_actor")) 181 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 182 | -------------------------------------------------------------------------------- /continuous/README.md: -------------------------------------------------------------------------------- 1 | # LAP/PAL with TD3 for Continuous Control 2 | 3 | Code for Loss-Adjusted Prioritized (LAP) experience replay and Prioritized Approximation Loss (PAL) with TD3. 4 | 5 | Paper results were collected with [MuJoCo 2.0.2.9](http://www.mujoco.org/) on [OpenAI gym](https://github.com/openai/gym). Networks are trained using [PyTorch 1.2.0](https://github.com/pytorch/pytorch) and Python 3.7. 6 | 7 | Example command: 8 | ``` 9 | python main.py --policy "LAP_TD3" --env "HalfCheetah-v3" 10 | ``` 11 | 12 | Hyper-parameters can be modified with different arguments to main.py. OpenAI gym now defaults to MuJoCo 1.50. We found the performance of Humanoid-v3 is lower on this version, although relative order between algorithms is unchanged. 13 | -------------------------------------------------------------------------------- /continuous/TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class Actor(nn.Module): 12 | def __init__(self, state_dim, action_dim, max_action): 13 | super(Actor, self).__init__() 14 | 15 | self.l1 = nn.Linear(state_dim, 256) 16 | self.l2 = nn.Linear(256, 256) 17 | self.l3 = nn.Linear(256, action_dim) 18 | 19 | self.max_action = max_action 20 | 21 | 22 | def forward(self, state): 23 | a = F.relu(self.l1(state)) 24 | a = F.relu(self.l2(a)) 25 | return self.max_action * torch.tanh(self.l3(a)) 26 | 27 | 28 | class Critic(nn.Module): 29 | def __init__(self, state_dim, action_dim): 30 | super(Critic, self).__init__() 31 | 32 | # Q1 architecture 33 | self.l1 = nn.Linear(state_dim + action_dim, 256) 34 | self.l2 = nn.Linear(256, 256) 35 | self.l3 = nn.Linear(256, 1) 36 | 37 | # Q2 architecture 38 | self.l4 = nn.Linear(state_dim + action_dim, 256) 39 | self.l5 = nn.Linear(256, 256) 40 | self.l6 = nn.Linear(256, 1) 41 | 42 | 43 | def forward(self, state, action): 44 | sa = torch.cat([state, action], 1) 45 | 46 | q1 = F.relu(self.l1(sa)) 47 | q1 = F.relu(self.l2(q1)) 48 | q1 = self.l3(q1) 49 | 50 | q2 = F.relu(self.l4(sa)) 51 | q2 = F.relu(self.l5(q2)) 52 | q2 = self.l6(q2) 53 | return q1, q2 54 | 55 | 56 | def Q1(self, state, action): 57 | sa = torch.cat([state, action], 1) 58 | 59 | q1 = F.relu(self.l1(sa)) 60 | q1 = F.relu(self.l2(q1)) 61 | q1 = self.l3(q1) 62 | return q1 63 | 64 | 65 | class TD3(object): 66 | def __init__( 67 | self, 68 | state_dim, 69 | action_dim, 70 | max_action, 71 | discount=0.99, 72 | tau=0.005, 73 | policy_noise=0.2, 74 | noise_clip=0.5, 75 | policy_freq=2 76 | ): 77 | 78 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 79 | self.actor_target = copy.deepcopy(self.actor) 80 | self.actor_optimizer = 
torch.optim.Adam(self.actor.parameters(), lr=3e-4) 81 | 82 | self.critic = Critic(state_dim, action_dim).to(device) 83 | self.critic_target = copy.deepcopy(self.critic) 84 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 85 | 86 | self.max_action = max_action 87 | self.discount = discount 88 | self.tau = tau 89 | self.policy_noise = policy_noise 90 | self.noise_clip = noise_clip 91 | self.policy_freq = policy_freq 92 | 93 | self.total_it = 0 94 | 95 | 96 | def select_action(self, state, test=False): 97 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 98 | return self.actor(state).cpu().data.numpy().flatten() 99 | 100 | 101 | def train(self, replay_buffer, batch_size=256): 102 | self.total_it += 1 103 | 104 | # Sample replay buffer 105 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 106 | 107 | with torch.no_grad(): 108 | # Select action according to policy and add clipped noise 109 | noise = ( 110 | torch.randn_like(action) * self.policy_noise 111 | ).clamp(-self.noise_clip, self.noise_clip) 112 | 113 | next_action = ( 114 | self.actor_target(next_state) + noise 115 | ).clamp(-self.max_action, self.max_action) 116 | 117 | # Compute the target Q value 118 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 119 | target_Q = torch.min(target_Q1, target_Q2) 120 | target_Q = reward + not_done * self.discount * target_Q 121 | 122 | # Get current Q estimates 123 | current_Q1, current_Q2 = self.critic(state, action) 124 | 125 | # Compute critic loss 126 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 127 | 128 | # Optimize the critic 129 | self.critic_optimizer.zero_grad() 130 | critic_loss.backward() 131 | self.critic_optimizer.step() 132 | 133 | # Delayed policy updates 134 | if self.total_it % self.policy_freq == 0: 135 | 136 | # Compute actor loss 137 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 138 | 139 | # Optimize the actor 140 | self.actor_optimizer.zero_grad() 141 | actor_loss.backward() 142 | self.actor_optimizer.step() 143 | 144 | # Update the frozen target models 145 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 146 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 147 | 148 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 149 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 150 | 151 | 152 | def save(self, filename): 153 | torch.save(self.critic.state_dict(), filename + "_critic") 154 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 155 | torch.save(self.actor.state_dict(), filename + "_actor") 156 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 157 | 158 | 159 | def load(self, filename): 160 | self.critic.load_state_dict(torch.load(filename + "_critic")) 161 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 162 | self.actor.load_state_dict(torch.load(filename + "_actor")) 163 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 164 | -------------------------------------------------------------------------------- /continuous/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import argparse 5 | import os 6 | import time 7 | 8 | import utils 9 | import TD3 10 | 
import LAP_TD3 11 | import PAL_TD3 12 | import PER_TD3 13 | 14 | 15 | # Runs policy for X episodes and returns average reward 16 | def eval_policy(policy, env, seed, eval_episodes=10): 17 | eval_env = gym.make(env) 18 | eval_env.seed(seed + 100) 19 | 20 | avg_reward = 0. 21 | for _ in range(eval_episodes): 22 | state, done = eval_env.reset(), False 23 | while not done: 24 | action = policy.select_action(np.array(state), test=True) 25 | state, reward, done, _ = eval_env.step(action) 26 | avg_reward += reward 27 | 28 | avg_reward /= eval_episodes 29 | 30 | print("---------------------------------------") 31 | print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}") 32 | print("---------------------------------------") 33 | return avg_reward 34 | 35 | 36 | if __name__ == "__main__": 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--algorithm", default="LAP_TD3") # Algorithm nameu 40 | parser.add_argument("--env", default="HalfCheetah-v3") # OpenAI gym environment name 41 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 42 | parser.add_argument("--start_timesteps", default=25e3, type=int)# Time steps initial random policy is used 43 | parser.add_argument("--eval_freq", default=5e3, type=int) # How often (time steps) we evaluate 44 | parser.add_argument("--max_timesteps", default=3e6, type=int) # Max time steps to run environment 45 | parser.add_argument("--expl_noise", default=0.1) # Std of Gaussian exploration noise 46 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic 47 | parser.add_argument("--discount", default=0.99) # Discount factor 48 | parser.add_argument("--tau", default=0.005) # Target network update rate 49 | parser.add_argument("--policy_noise", default=0.2) # Noise added to target policy during critic update 50 | parser.add_argument("--noise_clip", default=0.5) # Range to clip target policy noise 51 | parser.add_argument("--policy_freq", default=2, type=int) # Frequency of delayed policy updates 52 | parser.add_argument("--alpha", default=0.4) # Priority = TD^alpha (only used by LAP/PAL) 53 | parser.add_argument("--min_priority", default=1, type=int) # Minimum priority (set to 1 in paper, only used by LAP/PAL) 54 | args = parser.parse_args() 55 | 56 | file_name = "%s_%s_%s" % (args.algorithm, args.env, str(args.seed)) 57 | print("---------------------------------------") 58 | print(f"Settings: {file_name}") 59 | print("---------------------------------------") 60 | 61 | if not os.path.exists("./results"): 62 | os.makedirs("./results") 63 | 64 | env = gym.make(args.env) 65 | 66 | # Set seeds 67 | env.seed(args.seed) 68 | torch.manual_seed(args.seed) 69 | np.random.seed(args.seed) 70 | 71 | state_dim = env.observation_space.shape[0] 72 | action_dim = env.action_space.shape[0] 73 | max_action = float(env.action_space.high[0]) 74 | 75 | kwargs = { 76 | "state_dim": state_dim, 77 | "action_dim": action_dim, 78 | "max_action": max_action, 79 | "discount": args.discount, 80 | "tau": args.tau, 81 | "policy_noise": args.policy_noise * max_action, 82 | "noise_clip": args.noise_clip * max_action, 83 | "policy_freq": args.policy_freq 84 | } 85 | 86 | # Initialize policy and replay buffer 87 | if args.algorithm == "TD3": 88 | policy = TD3.TD3(**kwargs) 89 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 90 | 91 | elif args.algorithm == "PER_TD3": 92 | policy = PER_TD3.PER_TD3(**kwargs) 93 | replay_buffer = utils.PrioritizedReplayBuffer(state_dim, action_dim) 
94 | 95 | kwargs["alpha"] = args.alpha 96 | kwargs["min_priority"] = args.min_priority 97 | 98 | if args.algorithm == "LAP_TD3": 99 | policy = LAP_TD3.LAP_TD3(**kwargs) 100 | replay_buffer = utils.PrioritizedReplayBuffer(state_dim, action_dim) 101 | 102 | elif args.algorithm == "PAL_TD3": 103 | policy = PAL_TD3.PAL_TD3(**kwargs) 104 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 105 | 106 | # Evaluate untrained policy 107 | evaluations = [eval_policy(policy, args.env, args.seed)] 108 | 109 | state, done = env.reset(), False 110 | episode_reward = 0 111 | episode_timesteps = 0 112 | episode_num = 0 113 | 114 | for t in range(int(args.max_timesteps)): 115 | 116 | episode_timesteps += 1 117 | 118 | # Select action randomly or according to policy 119 | if t < args.start_timesteps: 120 | action = env.action_space.sample() 121 | else: 122 | action = ( 123 | policy.select_action(np.array(state)) 124 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 125 | ).clip(-max_action, max_action) 126 | 127 | # Perform action 128 | next_state, reward, done, _ = env.step(action) 129 | done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0 130 | 131 | # Store data in replay buffer 132 | replay_buffer.add(state, action, next_state, reward, done_bool) 133 | 134 | state = next_state 135 | episode_reward += reward 136 | 137 | # Train agent after collecting sufficient data 138 | if t >= args.start_timesteps: #>= 139 | policy.train(replay_buffer, args.batch_size) 140 | 141 | if done: 142 | # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True 143 | print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}") 144 | state, done = env.reset(), False 145 | episode_reward = 0 146 | episode_timesteps = 0 147 | episode_num += 1 148 | 149 | # Evaluate episode 150 | if (t + 1) % args.eval_freq == 0: 151 | evaluations.append(eval_policy(policy, args.env, args.seed)) 152 | np.save("./results/%s" % (file_name), evaluations) -------------------------------------------------------------------------------- /continuous/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, max_size=int(1e6)): 7 | self.max_size = max_size 8 | self.ptr = 0 9 | self.size = 0 10 | 11 | self.state = np.zeros((max_size, state_dim)) 12 | self.action = np.zeros((max_size, action_dim)) 13 | self.next_state = np.zeros((max_size, state_dim)) 14 | self.reward = np.zeros((max_size, 1)) 15 | self.not_done = np.zeros((max_size, 1)) 16 | 17 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def add(self, state, action, next_state, reward, done): 21 | self.state[self.ptr] = state 22 | self.action[self.ptr] = action 23 | self.next_state[self.ptr] = next_state 24 | self.reward[self.ptr] = reward 25 | self.not_done[self.ptr] = 1. 
- done 26 | 27 | self.ptr = (self.ptr + 1) % self.max_size 28 | self.size = min(self.size + 1, self.max_size) 29 | 30 | 31 | def sample(self, batch_size): 32 | ind = np.random.randint(self.size, size=batch_size) 33 | 34 | return ( 35 | torch.FloatTensor(self.state[ind]).to(self.device), 36 | torch.FloatTensor(self.action[ind]).to(self.device), 37 | torch.FloatTensor(self.next_state[ind]).to(self.device), 38 | torch.FloatTensor(self.reward[ind]).to(self.device), 39 | torch.FloatTensor(self.not_done[ind]).to(self.device) 40 | ) 41 | 42 | 43 | class PrioritizedReplayBuffer(): 44 | def __init__(self, state_dim, action_dim, max_size=int(1e6)): 45 | self.max_size = max_size 46 | self.ptr = 0 47 | self.size = 0 48 | 49 | self.state = np.zeros((max_size, state_dim)) 50 | self.action = np.zeros((max_size, action_dim)) 51 | self.next_state = np.zeros((max_size, state_dim)) 52 | self.reward = np.zeros((max_size, 1)) 53 | self.not_done = np.zeros((max_size, 1)) 54 | 55 | self.tree = SumTree(max_size) 56 | self.max_priority = 1.0 57 | self.beta = 0.4 58 | 59 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 60 | 61 | 62 | def add(self, state, action, next_state, reward, done): 63 | self.state[self.ptr] = state 64 | self.action[self.ptr] = action 65 | self.next_state[self.ptr] = next_state 66 | self.reward[self.ptr] = reward 67 | self.not_done[self.ptr] = 1. - done 68 | 69 | self.tree.set(self.ptr, self.max_priority) 70 | 71 | self.ptr = (self.ptr + 1) % self.max_size 72 | self.size = min(self.size + 1, self.max_size) 73 | 74 | 75 | def sample(self, batch_size): 76 | ind = self.tree.sample(batch_size) 77 | 78 | weights = self.tree.levels[-1][ind] ** -self.beta 79 | weights /= weights.max() 80 | 81 | self.beta = min(self.beta + 2e-7, 1) # Hardcoded: 0.4 + 2e-7 * 3e6 = 1.0. Only used by PER. 
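# The last level of the sum tree stores the raw per-transition priorities p_i, so the
# weights computed above are p_i^(-beta), normalized by the largest weight in the sampled batch.
# Up to that normalization this matches the usual PER correction w_i = (N * P(i))^(-beta),
# since the shared factor (N / sum_j p_j)^(-beta) cancels when dividing by the batch maximum.
# LAP draws samples from this same buffer but does not use these weights in its loss.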
82 | 83 | return ( 84 | torch.FloatTensor(self.state[ind]).to(self.device), 85 | torch.FloatTensor(self.action[ind]).to(self.device), 86 | torch.FloatTensor(self.next_state[ind]).to(self.device), 87 | torch.FloatTensor(self.reward[ind]).to(self.device), 88 | torch.FloatTensor(self.not_done[ind]).to(self.device), 89 | ind, 90 | torch.FloatTensor(weights).to(self.device).reshape(-1, 1) 91 | ) 92 | 93 | 94 | def update_priority(self, ind, priority): 95 | self.max_priority = max(priority.max(), self.max_priority) 96 | self.tree.batch_set(ind, priority) 97 | 98 | 99 | class SumTree(object): 100 | def __init__(self, max_size): 101 | self.levels = [np.zeros(1)] 102 | # Tree construction 103 | # Double the number of nodes at each level 104 | level_size = 1 105 | while level_size < max_size: 106 | level_size *= 2 107 | self.levels.append(np.zeros(level_size)) 108 | 109 | 110 | # Batch binary search through sum tree 111 | # Sample a priority between 0 and the max priority 112 | # and then search the tree for the corresponding index 113 | def sample(self, batch_size): 114 | value = np.random.uniform(0, self.levels[0][0], size=batch_size) 115 | ind = np.zeros(batch_size, dtype=int) 116 | 117 | for nodes in self.levels[1:]: 118 | ind *= 2 119 | left_sum = nodes[ind] 120 | 121 | is_greater = np.greater(value, left_sum) 122 | # If value > left_sum -> go right (+1), else go left (+0) 123 | ind += is_greater 124 | # If we go right, we only need to consider the values in the right tree 125 | # so we subtract the sum of values in the left tree 126 | value -= left_sum * is_greater 127 | 128 | return ind 129 | 130 | 131 | def set(self, ind, new_priority): 132 | priority_diff = new_priority - self.levels[-1][ind] 133 | 134 | for nodes in self.levels[::-1]: 135 | np.add.at(nodes, ind, priority_diff) 136 | ind //= 2 137 | 138 | 139 | def batch_set(self, ind, new_priority): 140 | # Confirm we don't increment a node twice 141 | ind, unique_ind = np.unique(ind, return_index=True) 142 | priority_diff = new_priority[unique_ind] - self.levels[-1][ind] 143 | 144 | for nodes in self.levels[::-1]: 145 | np.add.at(nodes, ind, priority_diff) 146 | ind //= 2 -------------------------------------------------------------------------------- /discrete/DDQN.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # Used for Atari 9 | class Conv_Q(nn.Module): 10 | def __init__(self, frames, num_actions): 11 | super(Conv_Q, self).__init__() 12 | self.c1 = nn.Conv2d(frames, 32, kernel_size=8, stride=4) 13 | self.c2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 14 | self.c3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 15 | self.l1 = nn.Linear(3136, 512) 16 | self.l2 = nn.Linear(512, num_actions) 17 | 18 | 19 | def forward(self, state): 20 | q = F.relu(self.c1(state)) 21 | q = F.relu(self.c2(q)) 22 | q = F.relu(self.c3(q)) 23 | q = F.relu(self.l1(q.reshape(-1, 3136))) 24 | return self.l2(q) 25 | 26 | 27 | # Used for Box2D / Toy problems 28 | class FC_Q(nn.Module): 29 | def __init__(self, state_dim, num_actions): 30 | super(FC_Q, self).__init__() 31 | self.l1 = nn.Linear(state_dim, 256) 32 | self.l2 = nn.Linear(256, 256) 33 | self.l3 = nn.Linear(256, num_actions) 34 | 35 | 36 | def forward(self, state): 37 | q = F.relu(self.l1(state)) 38 | q = F.relu(self.l2(q)) 39 | return self.l3(q) 40 | 41 | 42 | class DDQN(object): 43 | def __init__( 44 | self, 45 | is_atari, 46 | num_actions, 
47 | state_dim, 48 | device, 49 | discount=0.99, 50 | optimizer="Adam", 51 | optimizer_parameters={}, 52 | polyak_target_update=False, 53 | target_update_frequency=8e3, 54 | tau=0.005, 55 | initial_eps = 1, 56 | end_eps = 0.001, 57 | eps_decay_period = 25e4, 58 | eval_eps=0.001, 59 | ): 60 | 61 | self.device = device 62 | 63 | # Determine network type 64 | self.Q = Conv_Q(4, num_actions).to(self.device) if is_atari else FC_Q(state_dim, num_actions).to(self.device) 65 | self.Q_target = copy.deepcopy(self.Q) 66 | self.Q_optimizer = getattr(torch.optim, optimizer)(self.Q.parameters(), **optimizer_parameters) 67 | 68 | self.discount = discount 69 | 70 | # Target update rule 71 | self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update 72 | self.target_update_frequency = target_update_frequency 73 | self.tau = tau 74 | 75 | # Decay for eps 76 | self.initial_eps = initial_eps 77 | self.end_eps = end_eps 78 | self.slope = (self.end_eps - self.initial_eps) / eps_decay_period 79 | 80 | # Evaluation hyper-parameters 81 | self.state_shape = (-1,) + state_dim if is_atari else (-1, state_dim) 82 | self.eval_eps = eval_eps 83 | self.num_actions = num_actions 84 | 85 | # Number of training iterations 86 | self.iterations = 0 87 | 88 | 89 | def select_action(self, state, eval=False): 90 | eps = self.eval_eps if eval \ 91 | else max(self.slope * self.iterations + self.initial_eps, self.end_eps) 92 | 93 | # Select action according to policy with probability (1-eps) 94 | # otherwise, select random action 95 | if np.random.uniform(0,1) > eps: 96 | with torch.no_grad(): 97 | state = torch.FloatTensor(state).reshape(self.state_shape).to(self.device) 98 | return int(self.Q(state).argmax(1)) 99 | else: 100 | return np.random.randint(self.num_actions) 101 | 102 | 103 | def train(self, replay_buffer): 104 | # Sample replay buffer 105 | state, action, next_state, reward, done = replay_buffer.sample() 106 | 107 | # Compute the target Q value 108 | with torch.no_grad(): 109 | next_action = self.Q(next_state).argmax(1, keepdim=True) 110 | target_Q = ( 111 | reward + done * self.discount * 112 | self.Q_target(next_state).gather(1, next_action).reshape(-1, 1) 113 | ) 114 | 115 | # Get current Q estimate 116 | current_Q = self.Q(state).gather(1, action) 117 | 118 | # Compute Q loss 119 | Q_loss = F.smooth_l1_loss(current_Q, target_Q) 120 | 121 | # Optimize the Q network 122 | self.Q_optimizer.zero_grad() 123 | Q_loss.backward() 124 | self.Q_optimizer.step() 125 | 126 | # Update target network by polyak or full copy every X iterations. 
127 | self.iterations += 1 128 | self.maybe_update_target() 129 | 130 | 131 | def polyak_target_update(self): 132 | for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()): 133 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 134 | 135 | 136 | def copy_target_update(self): 137 | if self.iterations % self.target_update_frequency == 0: 138 | self.Q_target.load_state_dict(self.Q.state_dict()) 139 | 140 | 141 | def save(self, filename): 142 | torch.save(self.iterations, filename + "iterations") 143 | torch.save(self.Q.state_dict(), f"{filename}Q_{self.iterations}") 144 | torch.save(self.Q_optimizer.state_dict(), filename + "optimizer") 145 | 146 | 147 | def load(self, filename): 148 | self.iterations = torch.load(filename + "iterations") 149 | self.Q.load_state_dict(torch.load(f"{filename}Q_{self.iterations}")) 150 | self.Q_target = copy.deepcopy(self.Q) 151 | self.Q_optimizer.load_state_dict(torch.load(filename + "optimizer")) -------------------------------------------------------------------------------- /discrete/LAP_DDQN.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # Used for Atari 9 | class Conv_Q(nn.Module): 10 | def __init__(self, frames, num_actions): 11 | super(Conv_Q, self).__init__() 12 | self.c1 = nn.Conv2d(frames, 32, kernel_size=8, stride=4) 13 | self.c2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 14 | self.c3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 15 | self.l1 = nn.Linear(3136, 512) 16 | self.l2 = nn.Linear(512, num_actions) 17 | 18 | 19 | def forward(self, state): 20 | q = F.relu(self.c1(state)) 21 | q = F.relu(self.c2(q)) 22 | q = F.relu(self.c3(q)) 23 | q = F.relu(self.l1(q.reshape(-1, 3136))) 24 | return self.l2(q) 25 | 26 | 27 | # Used for Box2D / Toy problems 28 | class FC_Q(nn.Module): 29 | def __init__(self, state_dim, num_actions): 30 | super(FC_Q, self).__init__() 31 | self.l1 = nn.Linear(state_dim, 256) 32 | self.l2 = nn.Linear(256, 256) 33 | self.l3 = nn.Linear(256, num_actions) 34 | 35 | 36 | def forward(self, state): 37 | q = F.relu(self.l1(state)) 38 | q = F.relu(self.l2(q)) 39 | return self.l3(q) 40 | 41 | 42 | class LAP_DDQN(object): 43 | def __init__( 44 | self, 45 | is_atari, 46 | num_actions, 47 | state_dim, 48 | device, 49 | discount=0.99, 50 | optimizer="Adam", 51 | optimizer_parameters={}, 52 | polyak_target_update=False, 53 | target_update_frequency=8e3, 54 | tau=0.005, 55 | initial_eps = 1, 56 | end_eps = 0.001, 57 | eps_decay_period = 25e4, 58 | eval_eps=0.001, 59 | alpha=0.6, 60 | min_priority=1e-2 61 | ): 62 | 63 | self.device = device 64 | 65 | # Determine network type 66 | self.Q = Conv_Q(4, num_actions).to(self.device) if is_atari else FC_Q(state_dim, num_actions).to(self.device) 67 | self.Q_target = copy.deepcopy(self.Q) 68 | self.Q_optimizer = getattr(torch.optim, optimizer)(self.Q.parameters(), **optimizer_parameters) 69 | 70 | self.discount = discount 71 | 72 | # LAP hyper-parameters 73 | self.alpha = alpha 74 | self.min_priority = min_priority 75 | 76 | # Target update rule 77 | self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update 78 | self.target_update_frequency = target_update_frequency 79 | self.tau = tau 80 | 81 | # Decay for eps 82 | self.initial_eps = initial_eps 83 | self.end_eps = end_eps 84 | self.slope = (self.end_eps - self.initial_eps) / 
eps_decay_period 85 | 86 | # Evaluation hyper-parameters 87 | self.state_shape = (-1,) + state_dim if is_atari else (-1, state_dim) 88 | self.eval_eps = eval_eps 89 | self.num_actions = num_actions 90 | 91 | # Number of training iterations 92 | self.iterations = 0 93 | 94 | 95 | def select_action(self, state, eval=False): 96 | eps = self.eval_eps if eval \ 97 | else max(self.slope * self.iterations + self.initial_eps, self.end_eps) 98 | 99 | # Select action according to policy with probability (1-eps) 100 | # otherwise, select random action 101 | if np.random.uniform(0,1) > eps: 102 | with torch.no_grad(): 103 | state = torch.FloatTensor(state).reshape(self.state_shape).to(self.device) 104 | return int(self.Q(state).argmax(1)) 105 | else: 106 | return np.random.randint(self.num_actions) 107 | 108 | 109 | def train(self, replay_buffer): 110 | # Sample replay buffer 111 | state, action, next_state, reward, done, ind, weights = replay_buffer.sample() 112 | 113 | # Compute the target Q value 114 | with torch.no_grad(): 115 | next_action = self.Q(next_state).argmax(1, keepdim=True) 116 | target_Q = ( 117 | reward + done * self.discount * 118 | self.Q_target(next_state).gather(1, next_action).reshape(-1, 1) 119 | ) 120 | 121 | # Get current Q estimate 122 | current_Q = self.Q(state).gather(1, action) 123 | 124 | td_loss = (current_Q - target_Q).abs() 125 | Q_loss = self.huber(td_loss) 126 | 127 | # Optimize the Q network 128 | self.Q_optimizer.zero_grad() 129 | Q_loss.backward() 130 | self.Q_optimizer.step() 131 | 132 | # Update target network by polyak or full copy every X iterations. 133 | self.iterations += 1 134 | self.maybe_update_target() 135 | 136 | priority = td_loss.clamp(min=self.min_priority).pow(self.alpha).cpu().data.numpy().flatten() 137 | replay_buffer.update_priority(ind, priority) 138 | 139 | 140 | def huber(self, x): 141 | return torch.where(x < self.min_priority, 0.5 * x.pow(2), self.min_priority * x).mean() 142 | 143 | 144 | def polyak_target_update(self): 145 | for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()): 146 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 147 | 148 | 149 | def copy_target_update(self): 150 | if self.iterations % self.target_update_frequency == 0: 151 | self.Q_target.load_state_dict(self.Q.state_dict()) 152 | 153 | 154 | def save(self, filename): 155 | torch.save(self.iterations, filename + "iterations") 156 | torch.save(self.Q.state_dict(), f"{filename}Q_{self.iterations}") 157 | torch.save(self.Q_optimizer.state_dict(), filename + "optimizer") 158 | 159 | 160 | def load(self, filename): 161 | self.iterations = torch.load(filename + "iterations") 162 | self.Q.load_state_dict(torch.load(f"{filename}Q_{self.iterations}")) 163 | self.Q_target = copy.deepcopy(self.Q) 164 | self.Q_optimizer.load_state_dict(torch.load(filename + "optimizer")) -------------------------------------------------------------------------------- /discrete/PAL_DDQN.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # Used for Atari 9 | class Conv_Q(nn.Module): 10 | def __init__(self, frames, num_actions): 11 | super(Conv_Q, self).__init__() 12 | self.c1 = nn.Conv2d(frames, 32, kernel_size=8, stride=4) 13 | self.c2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 14 | self.c3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 15 | self.l1 = nn.Linear(3136, 512) 16 
| self.l2 = nn.Linear(512, num_actions) 17 | 18 | 19 | def forward(self, state): 20 | q = F.relu(self.c1(state)) 21 | q = F.relu(self.c2(q)) 22 | q = F.relu(self.c3(q)) 23 | q = F.relu(self.l1(q.reshape(-1, 3136))) 24 | return self.l2(q) 25 | 26 | 27 | # Used for Box2D / Toy problems 28 | class FC_Q(nn.Module): 29 | def __init__(self, state_dim, num_actions): 30 | super(FC_Q, self).__init__() 31 | self.l1 = nn.Linear(state_dim, 256) 32 | self.l2 = nn.Linear(256, 256) 33 | self.l3 = nn.Linear(256, num_actions) 34 | 35 | 36 | def forward(self, state): 37 | q = F.relu(self.l1(state)) 38 | q = F.relu(self.l2(q)) 39 | return self.l3(q) 40 | 41 | 42 | class PAL_DDQN(object): 43 | def __init__( 44 | self, 45 | is_atari, 46 | num_actions, 47 | state_dim, 48 | device, 49 | discount=0.99, 50 | optimizer="Adam", 51 | optimizer_parameters={}, 52 | polyak_target_update=False, 53 | target_update_frequency=8e3, 54 | tau=0.005, 55 | initial_eps = 1, 56 | end_eps = 0.001, 57 | eps_decay_period = 25e4, 58 | eval_eps=0.001, 59 | alpha=0.6, 60 | min_priority=1e-2 61 | ): 62 | 63 | self.device = device 64 | 65 | # Determine network type 66 | self.Q = Conv_Q(4, num_actions).to(self.device) if is_atari else FC_Q(state_dim, num_actions).to(self.device) 67 | self.Q_target = copy.deepcopy(self.Q) 68 | self.Q_optimizer = getattr(torch.optim, optimizer)(self.Q.parameters(), **optimizer_parameters) 69 | 70 | self.discount = discount 71 | 72 | # PAL hyper-parameters 73 | self.alpha = alpha 74 | self.min_priority = min_priority 75 | 76 | # Target update rule 77 | self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update 78 | self.target_update_frequency = target_update_frequency 79 | self.tau = tau 80 | 81 | # Decay for eps 82 | self.initial_eps = initial_eps 83 | self.end_eps = end_eps 84 | self.slope = (self.end_eps - self.initial_eps) / eps_decay_period 85 | 86 | # Evaluation hyper-parameters 87 | self.state_shape = (-1,) + state_dim if is_atari else (-1, state_dim) 88 | self.eval_eps = eval_eps 89 | self.num_actions = num_actions 90 | 91 | # Number of training iterations 92 | self.iterations = 0 93 | 94 | 95 | def select_action(self, state, eval=False): 96 | eps = self.eval_eps if eval \ 97 | else max(self.slope * self.iterations + self.initial_eps, self.end_eps) 98 | 99 | # Select action according to policy with probability (1-eps) 100 | # otherwise, select random action 101 | if np.random.uniform(0,1) > eps: 102 | with torch.no_grad(): 103 | state = torch.FloatTensor(state).reshape(self.state_shape).to(self.device) 104 | return int(self.Q(state).argmax(1)) 105 | else: 106 | return np.random.randint(self.num_actions) 107 | 108 | 109 | def train(self, replay_buffer): 110 | # Sample replay buffer 111 | state, action, next_state, reward, done = replay_buffer.sample() 112 | 113 | # Compute the target Q value 114 | with torch.no_grad(): 115 | next_action = self.Q(next_state).argmax(1, keepdim=True) 116 | target_Q = ( 117 | reward + done * self.discount * 118 | self.Q_target(next_state).gather(1, next_action).reshape(-1, 1) 119 | ) 120 | 121 | # Get current Q estimate 122 | current_Q = self.Q(state).gather(1, action) 123 | 124 | td_loss = (current_Q - target_Q).abs() 125 | weight = td_loss.clamp(min=self.min_priority).pow(self.alpha).mean().detach() 126 | 127 | # Compute critic loss 128 | Q_loss = self.PAL(td_loss)/weight.detach() 129 | 130 | # Optimize the Q 131 | self.Q_optimizer.zero_grad() 132 | Q_loss.backward() 133 | self.Q_optimizer.step() 134 | 135 | # Update target 
network by polyak or full copy every X iterations. 136 | self.iterations += 1 137 | self.maybe_update_target() 138 | 139 | 140 | def PAL(self, x): 141 | return torch.where( 142 | x.abs() < self.min_priority, 143 | (self.min_priority ** self.alpha) * 0.5 * x.pow(2), 144 | self.min_priority * x.abs().pow(1. + self.alpha)/(1. + self.alpha) 145 | ).mean() 146 | 147 | 148 | def polyak_target_update(self): 149 | for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()): 150 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 151 | 152 | 153 | def copy_target_update(self): 154 | if self.iterations % self.target_update_frequency == 0: 155 | self.Q_target.load_state_dict(self.Q.state_dict()) 156 | 157 | 158 | def save(self, filename): 159 | torch.save(self.iterations, filename + "iterations") 160 | torch.save(self.Q.state_dict(), f"{filename}Q_{self.iterations}") 161 | torch.save(self.Q_optimizer.state_dict(), filename + "optimizer") 162 | 163 | 164 | def load(self, filename): 165 | self.iterations = torch.load(filename + "iterations") 166 | self.Q.load_state_dict(torch.load(f"{filename}Q_{self.iterations}")) 167 | self.Q_target = copy.deepcopy(self.Q) 168 | self.Q_optimizer.load_state_dict(torch.load(filename + "optimizer")) -------------------------------------------------------------------------------- /discrete/PER_DDQN.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # Used for Atari 9 | class Conv_Q(nn.Module): 10 | def __init__(self, frames, num_actions): 11 | super(Conv_Q, self).__init__() 12 | self.c1 = nn.Conv2d(frames, 32, kernel_size=8, stride=4) 13 | self.c2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 14 | self.c3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 15 | self.l1 = nn.Linear(3136, 512) 16 | self.l2 = nn.Linear(512, num_actions) 17 | 18 | 19 | def forward(self, state): 20 | q = F.relu(self.c1(state)) 21 | q = F.relu(self.c2(q)) 22 | q = F.relu(self.c3(q)) 23 | q = F.relu(self.l1(q.reshape(-1, 3136))) 24 | return self.l2(q) 25 | 26 | 27 | # Used for Box2D / Toy problems 28 | class FC_Q(nn.Module): 29 | def __init__(self, state_dim, num_actions): 30 | super(FC_Q, self).__init__() 31 | self.l1 = nn.Linear(state_dim, 256) 32 | self.l2 = nn.Linear(256, 256) 33 | self.l3 = nn.Linear(256, num_actions) 34 | 35 | 36 | def forward(self, state): 37 | q = F.relu(self.l1(state)) 38 | q = F.relu(self.l2(q)) 39 | return self.l3(q) 40 | 41 | 42 | class PER_DDQN(object): 43 | def __init__( 44 | self, 45 | is_atari, 46 | num_actions, 47 | state_dim, 48 | device, 49 | discount=0.99, 50 | optimizer="Adam", 51 | optimizer_parameters={}, 52 | polyak_target_update=False, 53 | target_update_frequency=8e3, 54 | tau=0.005, 55 | initial_eps = 1, 56 | end_eps = 0.001, 57 | eps_decay_period = 25e4, 58 | eval_eps=0.001, 59 | ): 60 | 61 | self.device = device 62 | 63 | # Determine network type 64 | self.Q = Conv_Q(4, num_actions).to(self.device) if is_atari else FC_Q(state_dim, num_actions).to(self.device) 65 | self.Q_target = copy.deepcopy(self.Q) 66 | self.Q_optimizer = getattr(torch.optim, optimizer)(self.Q.parameters(), **optimizer_parameters) 67 | 68 | self.discount = discount 69 | 70 | # Target update rule 71 | self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update 72 | self.target_update_frequency = target_update_frequency 73 | 
self.tau = tau 74 | 75 | # Decay for eps 76 | self.initial_eps = initial_eps 77 | self.end_eps = end_eps 78 | self.slope = (self.end_eps - self.initial_eps) / eps_decay_period 79 | 80 | # Evaluation hyper-parameters 81 | self.state_shape = (-1, 4, 84, 84) if is_atari else (-1, state_dim) ### need to pass framesize 82 | self.eval_eps = eval_eps 83 | self.num_actions = num_actions 84 | 85 | # Number of training iterations 86 | self.iterations = 0 87 | 88 | 89 | def select_action(self, state, eval=False): 90 | eps = self.eval_eps if eval \ 91 | else max(self.slope * self.iterations + self.initial_eps, self.end_eps) 92 | 93 | # Select action according to policy with probability (1-eps) 94 | # otherwise, select random action 95 | if np.random.uniform(0,1) > eps: 96 | with torch.no_grad(): 97 | state = torch.FloatTensor(state).reshape(self.state_shape).to(self.device) 98 | return int(self.Q(state).argmax(1)) 99 | else: 100 | return np.random.randint(self.num_actions) 101 | 102 | 134 | def train(self, replay_buffer): 135 | # Sample replay buffer 136 | state, action, next_state, reward, done, ind, weights = replay_buffer.sample() 137 | 138 | # Compute the target Q value 139 | with torch.no_grad(): 140 | next_action = self.Q(next_state).argmax(1, keepdim=True) 141 | target_Q = ( 142 | reward + done * self.discount * 143 | self.Q_target(next_state).gather(1, next_action).reshape(-1, 1) 144 | ) 145 | 146 | # Get current Q estimate 147 | current_Q = self.Q(state).gather(1, action) 148 | 149 | # Compute Q loss 150 | Q_loss = (weights * F.smooth_l1_loss(current_Q, target_Q, reduction='none')).mean() 151 | 152 | # Optimize the Q network 153 | self.Q_optimizer.zero_grad() 154 | Q_loss.backward() 155 | self.Q_optimizer.step() 156 | 157 | # Update target network by polyak or full copy every X iterations.
127 |         self.iterations += 1
128 |         self.maybe_update_target()
129 | 
130 |         priority = ((current_Q - target_Q).abs() + 1e-10).pow(0.6).cpu().data.numpy().flatten()
131 |         replay_buffer.update_priority(ind, priority)
132 | 
133 | 
134 |     def polyak_target_update(self):
135 |         for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()):
136 |             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
137 | 
138 | 
139 |     def copy_target_update(self):
140 |         if self.iterations % self.target_update_frequency == 0:
141 |             self.Q_target.load_state_dict(self.Q.state_dict())
142 | 
143 | 
144 |     def save(self, filename):
145 |         torch.save(self.iterations, filename + "iterations")
146 |         torch.save(self.Q.state_dict(), f"{filename}Q_{self.iterations}")
147 |         torch.save(self.Q_optimizer.state_dict(), filename + "optimizer")
148 | 
149 | 
150 |     def load(self, filename):
151 |         self.iterations = torch.load(filename + "iterations")
152 |         self.Q.load_state_dict(torch.load(f"{filename}Q_{self.iterations}"))
153 |         self.Q_target = copy.deepcopy(self.Q)
154 |         self.Q_optimizer.load_state_dict(torch.load(filename + "optimizer"))
--------------------------------------------------------------------------------
/discrete/README.md:
--------------------------------------------------------------------------------
 1 | # LAP/PAL with DDQN for Discrete Action Domains
 2 | 
 3 | Code for Loss-Adjusted Prioritized (LAP) experience replay and Prioritized Approximation Loss (PAL) with Double DQN.
 4 | 
 5 | Paper results were collected with [OpenAI gym](https://github.com/openai/gym). Networks are trained using [PyTorch 1.2.0](https://github.com/pytorch/pytorch) and Python 3.7.
 6 | 
 7 | Example command:
 8 | ```
 9 | python main.py --algorithm "LAP_DDQN" --env "PongNoFrameskip-v0"
10 | ```
11 | 
12 | Hyper-parameters can be modified through the command-line arguments and the parameter dicts in main.py. The code can also run non-Atari environments, but performance in these domains is mostly untested.
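A non-Atari run uses the same entry point, e.g. (the environment name below is only illustrative and assumes the corresponding gym dependencies are installed):
```
python main.py --algorithm "PAL_DDQN" --env "LunarLander-v2" --seed 1
```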
13 | -------------------------------------------------------------------------------- /discrete/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import importlib 4 | import json 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import DDQN 11 | import PER_DDQN 12 | import LAP_DDQN 13 | import PAL_DDQN 14 | import utils 15 | 16 | 17 | def main(env, replay_buffer, is_atari, state_dim, num_actions, args, parameters, device): 18 | # Initialize and load policy 19 | kwargs = { 20 | "is_atari": is_atari, 21 | "num_actions": num_actions, 22 | "state_dim": state_dim, 23 | "device": device, 24 | "discount": parameters["discount"], 25 | "optimizer": parameters["optimizer"], 26 | "optimizer_parameters": parameters["optimizer_parameters"], 27 | "polyak_target_update": parameters["polyak_target_update"], 28 | "target_update_frequency": parameters["target_update_freq"], 29 | "tau": parameters["tau"], 30 | "initial_eps": parameters["initial_eps"], 31 | "end_eps": parameters["end_eps"], 32 | "eps_decay_period": parameters["eps_decay_period"], 33 | "eval_eps": parameters["eval_eps"] 34 | } 35 | 36 | if args.algorithm == "DDQN": 37 | policy = DDQN.DDQN(**kwargs) 38 | elif args.algorithm == "PER_DDQN": 39 | policy = PER_DDQN.PER_DDQN(**kwargs) 40 | 41 | kwargs["alpha"] = parameters["alpha"] 42 | kwargs["min_priority"] = parameters["min_priority"] 43 | 44 | if args.algorithm == "LAP_DDQN": 45 | policy = LAP_DDQN.LAP_DDQN(**kwargs) 46 | elif args.algorithm == "PAL_DDQN": 47 | policy = PAL_DDQN.PAL_DDQN(**kwargs) 48 | 49 | evaluations = [] 50 | 51 | state, done = env.reset(), False 52 | episode_start = True 53 | episode_reward = 0 54 | episode_timesteps = 0 55 | episode_num = 0 56 | 57 | # Interact with the environment for max_timesteps 58 | for t in range(int(args.max_timesteps)): 59 | 60 | episode_timesteps += 1 61 | 62 | #if args.train_behavioral: 63 | if t < parameters["start_timesteps"]: 64 | action = env.action_space.sample() 65 | else: 66 | action = policy.select_action(np.array(state)) 67 | 68 | # Perform action and log results 69 | next_state, reward, done, info = env.step(action) 70 | episode_reward += reward 71 | 72 | # Only consider "done" if episode terminates due to failure condition 73 | done_float = float(done) if episode_timesteps < env._max_episode_steps else 0 74 | 75 | # For atari, info[0] = clipped reward, info[1] = done_float 76 | if is_atari: 77 | reward = info[0] 78 | done_float = info[1] 79 | 80 | # Store data in replay buffer 81 | replay_buffer.add(state, action, next_state, reward, done_float, done, episode_start) 82 | state = copy.copy(next_state) 83 | episode_start = False 84 | 85 | # Train agent after collecting sufficient data 86 | if t >= parameters["start_timesteps"] and (t + 1) % parameters["train_freq"] == 0: 87 | policy.train(replay_buffer) 88 | 89 | if done: 90 | # +1 to account for 0 indexing. 
+0 on ep_timesteps since it will increment +1 even if done=True 91 | print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}") 92 | # Reset environment 93 | state, done = env.reset(), False 94 | episode_start = True 95 | episode_reward = 0 96 | episode_timesteps = 0 97 | episode_num += 1 98 | 99 | # Evaluate episode 100 | if (t + 1) % parameters["eval_freq"] == 0: 101 | evaluations.append(eval_policy(policy, args.env, args.seed)) 102 | np.save(f"./results/{setting}.npy", evaluations) 103 | 104 | 105 | # Runs policy for X episodes and returns average reward 106 | # A fixed seed is used for the eval environment 107 | def eval_policy(policy, env_name, seed, eval_episodes=10): 108 | eval_env, _, _, _ = utils.make_env(env_name, atari_preprocessing) 109 | eval_env.seed(seed + 100) 110 | 111 | avg_reward = 0. 112 | for _ in range(eval_episodes): 113 | state, done = eval_env.reset(), False 114 | while not done: 115 | action = policy.select_action(np.array(state), eval=True) 116 | state, reward, done, _ = eval_env.step(action) 117 | avg_reward += reward 118 | 119 | avg_reward /= eval_episodes 120 | 121 | print("---------------------------------------") 122 | print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}") 123 | print("---------------------------------------") 124 | return avg_reward 125 | 126 | 127 | if __name__ == "__main__": 128 | 129 | # Atari Specific 130 | atari_preprocessing = { 131 | "frame_skip": 4, 132 | "frame_size": 84, 133 | "state_history": 4, 134 | "done_on_life_loss": False, 135 | "reward_clipping": True, 136 | "max_episode_timesteps": 27e3 137 | } 138 | 139 | atari_parameters = { 140 | # LAP/PAL 141 | "alpha": 0.6, 142 | "min_priority": 1e-2, 143 | # Exploration 144 | "start_timesteps": 2e4, 145 | "initial_eps": 1, 146 | "end_eps": 1e-2, 147 | "eps_decay_period": 25e4, 148 | # Evaluation 149 | "eval_freq": 5e4, 150 | "eval_eps": 1e-3, 151 | # Learning 152 | "discount": 0.99, 153 | "buffer_size": 1e6, 154 | "batch_size": 32, 155 | "optimizer": "RMSprop", 156 | "optimizer_parameters": { 157 | "lr": 0.0000625, 158 | "alpha": 0.95, 159 | "centered": True, 160 | "eps": 0.00001 161 | }, 162 | "train_freq": 4, 163 | "polyak_target_update": False, 164 | "target_update_freq": 8e3, 165 | "tau": 1 166 | } 167 | 168 | regular_parameters = { 169 | # LAP/PAL 170 | "alpha": 0.4, 171 | "min_priority": 1, 172 | # Exploration 173 | "start_timesteps": 1e3, 174 | "initial_eps": 0.1, 175 | "end_eps": 0.1, 176 | "eps_decay_period": 1, 177 | # Evaluation 178 | "eval_freq": 5e3, 179 | "eval_eps": 0, 180 | # Learning 181 | "discount": 0.99, 182 | "buffer_size": 1e6, 183 | "batch_size": 64, 184 | "optimizer": "Adam", 185 | "optimizer_parameters": { 186 | "lr": 3e-4 187 | }, 188 | "train_freq": 1, 189 | "polyak_target_update": True, 190 | "target_update_freq": 1, 191 | "tau": 0.005 192 | } 193 | 194 | # Load parameters 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--algorithm", default="LAP_DDQN") # OpenAI gym environment name 197 | parser.add_argument("--env", default="PongNoFrameskip-v0") # OpenAI gym environment name #PongNoFrameskip-v0 198 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 199 | parser.add_argument("--buffer_name", default="Default") # Prepends name to filename 200 | parser.add_argument("--max_timesteps", default=50e6, type=int) # Max time steps to run environment or train for 201 | args = parser.parse_args() 202 | 203 | 
print("---------------------------------------") 204 | print(f"Setting: Algorithm: {args.algorithm}, Env: {args.env}, Seed: {args.seed}") 205 | print("---------------------------------------") 206 | 207 | setting = f"{args.algorithm}_{args.env}_{args.seed}" 208 | 209 | if not os.path.exists("./results"): 210 | os.makedirs("./results") 211 | 212 | # Make env and determine properties 213 | env, is_atari, state_dim, num_actions = utils.make_env(args.env, atari_preprocessing) 214 | parameters = atari_parameters if is_atari else regular_parameters 215 | 216 | env.seed(args.seed) 217 | torch.manual_seed(args.seed) 218 | np.random.seed(args.seed) 219 | 220 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 221 | 222 | # Initialize buffer 223 | prioritized = True if args.algorithm == "PER_DDQN" or args.algorithm == "LAP_DDQN" else False 224 | replay_buffer = utils.ReplayBuffer( 225 | state_dim, 226 | prioritized, 227 | is_atari, 228 | atari_preprocessing, 229 | parameters["batch_size"], 230 | parameters["buffer_size"], 231 | device 232 | ) 233 | 234 | main(env, replay_buffer, is_atari, state_dim, num_actions, args, parameters, device) 235 | -------------------------------------------------------------------------------- /discrete/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import gym 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def ReplayBuffer(state_dim, prioritized, is_atari, atari_preprocessing, batch_size, buffer_size, device): 8 | if is_atari: 9 | return PrioritizedAtariBuffer(state_dim, atari_preprocessing, batch_size, buffer_size, device, prioritized) 10 | else: 11 | return PrioritizedStandardBuffer(state_dim, batch_size, buffer_size, device, prioritized) 12 | 13 | 14 | class PrioritizedAtariBuffer(object): 15 | def __init__(self, state_dim, atari_preprocessing, batch_size, buffer_size, device, prioritized): 16 | self.batch_size = batch_size 17 | self.max_size = int(buffer_size) 18 | self.device = device 19 | 20 | self.state_history = atari_preprocessing["state_history"] 21 | 22 | self.ptr = 0 23 | self.size = 0 24 | 25 | self.state = np.zeros(( 26 | self.max_size + 1, 27 | atari_preprocessing["frame_size"], 28 | atari_preprocessing["frame_size"] 29 | ), dtype=np.uint8) 30 | 31 | self.action = np.zeros((self.max_size, 1), dtype=np.int64) 32 | self.reward = np.zeros((self.max_size, 1)) 33 | 34 | # not_done only consider "done" if episode terminates due to failure condition 35 | # if episode terminates due to timelimit, the transition is not added to the buffer 36 | self.not_done = np.zeros((self.max_size, 1)) 37 | self.first_timestep = np.zeros(self.max_size, dtype=np.uint8) 38 | 39 | self.prioritized = prioritized 40 | 41 | if self.prioritized: 42 | self.tree = SumTree(self.max_size) 43 | self.max_priority = 1.0 44 | self.beta = 0.4 45 | 46 | 47 | def add(self, state, action, next_state, reward, done, env_done, first_timestep): 48 | # If dones don't match, env has reset due to timelimit 49 | # and we don't add the transition to the buffer 50 | if done != env_done: 51 | return 52 | 53 | self.state[self.ptr] = state[0] 54 | self.action[self.ptr] = action 55 | self.reward[self.ptr] = reward 56 | self.not_done[self.ptr] = 1. 
- done 57 | self.first_timestep[self.ptr] = first_timestep 58 | 59 | self.ptr = (self.ptr + 1) % self.max_size 60 | self.size = min(self.size + 1, self.max_size) 61 | 62 | if self.prioritized: 63 | self.tree.set(self.ptr, self.max_priority) 64 | 65 | 66 | def sample(self): 67 | ind = self.tree.sample(self.batch_size) if self.prioritized \ 68 | else np.random.randint(0, self.size, size=self.batch_size) 69 | 70 | # Note + is concatenate here 71 | state = np.zeros(((self.batch_size, self.state_history) + self.state.shape[1:]), dtype=np.uint8) 72 | next_state = np.array(state) 73 | 74 | state_not_done = 1. 75 | next_not_done = 1. 76 | for i in range(self.state_history): 77 | 78 | # Wrap around if the buffer is filled 79 | if self.size == self.max_size: 80 | j = (ind - i) % self.max_size 81 | k = (ind - i + 1) % self.max_size 82 | else: 83 | j = ind - i 84 | k = (ind - i + 1).clip(min=0) 85 | # If j == -1, then we set state_not_done to 0. 86 | state_not_done *= (j + 1).clip(min=0, max=1).reshape(-1, 1, 1) 87 | j = j.clip(min=0) 88 | 89 | # State should be all 0s if the episode terminated previously 90 | state[:, i] = self.state[j] * state_not_done 91 | next_state[:, i] = self.state[k] * next_not_done 92 | 93 | # If this was the first timestep, make everything previous = 0 94 | next_not_done *= state_not_done 95 | state_not_done *= (1. - self.first_timestep[j]).reshape(-1, 1, 1) 96 | 97 | batch = ( 98 | torch.ByteTensor(state).to(self.device).float(), 99 | torch.LongTensor(self.action[ind]).to(self.device), 100 | torch.ByteTensor(next_state).to(self.device).float(), 101 | torch.FloatTensor(self.reward[ind]).to(self.device), 102 | torch.FloatTensor(self.not_done[ind]).to(self.device) 103 | ) 104 | 105 | if self.prioritized: 106 | weights = np.array(self.tree.nodes[-1][ind]) ** -self.beta 107 | weights /= weights.max() 108 | self.beta = min(self.beta + 4.8e-8, 1) # Hardcoded: 0.4 + 4.8e-8 * 12.5e6 = 1.0. Only used by PER. 109 | batch += (ind, torch.FloatTensor(weights).to(self.device).reshape(-1, 1)) 110 | 111 | return batch 112 | 113 | 114 | def update_priority(self, ind, priority): 115 | self.max_priority = max(priority.max(), self.max_priority) 116 | self.tree.batch_set(ind, priority) 117 | 118 | 119 | # Replay buffer for standard gym tasks 120 | class PrioritizedStandardBuffer(): 121 | def __init__(self, state_dim, batch_size, buffer_size, device, prioritized): 122 | self.batch_size = batch_size 123 | self.max_size = int(buffer_size) 124 | self.device = device 125 | 126 | self.ptr = 0 127 | self.size = 0 128 | 129 | self.state = np.zeros((self.max_size, state_dim)) 130 | self.action = np.zeros((self.max_size, 1)) 131 | self.next_state = np.array(self.state) 132 | self.reward = np.zeros((self.max_size, 1)) 133 | self.not_done = np.zeros((self.max_size, 1)) 134 | 135 | self.prioritized = prioritized 136 | 137 | if self.prioritized: 138 | self.tree = SumTree(self.max_size) 139 | self.max_priority = 1.0 140 | self.beta = 0.4 141 | 142 | 143 | def add(self, state, action, next_state, reward, done, env_done, first_timestep): 144 | self.state[self.ptr] = state 145 | self.action[self.ptr] = action 146 | self.next_state[self.ptr] = next_state 147 | self.reward[self.ptr] = reward 148 | self.not_done[self.ptr] = 1. 
- done 149 | 150 | if self.prioritized: 151 | self.tree.set(self.ptr, self.max_priority) 152 | 153 | self.ptr = (self.ptr + 1) % self.max_size 154 | self.size = min(self.size + 1, self.max_size) 155 | 156 | 157 | def sample(self): 158 | ind = self.tree.sample(self.batch_size) if self.prioritized \ 159 | else np.random.randint(0, self.size, size=self.batch_size) 160 | 161 | batch = ( 162 | torch.FloatTensor(self.state[ind]).to(self.device), 163 | torch.LongTensor(self.action[ind]).to(self.device), 164 | torch.FloatTensor(self.next_state[ind]).to(self.device), 165 | torch.FloatTensor(self.reward[ind]).to(self.device), 166 | torch.FloatTensor(self.not_done[ind]).to(self.device) 167 | ) 168 | 169 | if self.prioritized: 170 | weights = np.array(self.tree.nodes[-1][ind]) ** -self.beta 171 | weights /= weights.max() 172 | self.beta = min(self.beta + 2e-7, 1) # Hardcoded: 0.4 + 2e-7 * 3e6 = 1.0. Only used by PER. 173 | batch += (ind, torch.FloatTensor(weights).to(self.device).reshape(-1, 1)) 174 | 175 | return batch 176 | 177 | 178 | def update_priority(self, ind, priority): 179 | self.max_priority = max(priority.max(), self.max_priority) 180 | self.tree.batch_set(ind, priority) 181 | 182 | 183 | class SumTree(object): 184 | def __init__(self, max_size): 185 | self.nodes = [] 186 | # Tree construction 187 | # Double the number of nodes at each level 188 | level_size = 1 189 | for _ in range(int(np.ceil(np.log2(max_size))) + 1): 190 | nodes = np.zeros(level_size) 191 | self.nodes.append(nodes) 192 | level_size *= 2 193 | 194 | 195 | # Batch binary search through sum tree 196 | # Sample a priority between 0 and the max priority 197 | # and then search the tree for the corresponding index 198 | def sample(self, batch_size): 199 | query_value = np.random.uniform(0, self.nodes[0][0], size=batch_size) 200 | node_index = np.zeros(batch_size, dtype=int) 201 | 202 | for nodes in self.nodes[1:]: 203 | node_index *= 2 204 | left_sum = nodes[node_index] 205 | 206 | is_greater = np.greater(query_value, left_sum) 207 | # If query_value > left_sum -> go right (+1), else go left (+0) 208 | node_index += is_greater 209 | # If we go right, we only need to consider the values in the right tree 210 | # so we subtract the sum of values in the left tree 211 | query_value -= left_sum * is_greater 212 | 213 | return node_index 214 | 215 | 216 | def set(self, node_index, new_priority): 217 | priority_diff = new_priority - self.nodes[-1][node_index] 218 | 219 | for nodes in self.nodes[::-1]: 220 | np.add.at(nodes, node_index, priority_diff) 221 | node_index //= 2 222 | 223 | 224 | def batch_set(self, node_index, new_priority): 225 | # Confirm we don't increment a node twice 226 | node_index, unique_index = np.unique(node_index, return_index=True) 227 | priority_diff = new_priority[unique_index] - self.nodes[-1][node_index] 228 | 229 | for nodes in self.nodes[::-1]: 230 | np.add.at(nodes, node_index, priority_diff) 231 | node_index //= 2 232 | 233 | 234 | # Atari Preprocessing 235 | # Code is based on https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py 236 | class AtariPreprocessing(object): 237 | def __init__( 238 | self, 239 | env, 240 | frame_skip=4, 241 | frame_size=84, 242 | state_history=4, 243 | done_on_life_loss=False, 244 | reward_clipping=True, # Clips to a range of -1,1 245 | max_episode_timesteps=27000 246 | ): 247 | self.env = env.env 248 | self.done_on_life_loss = done_on_life_loss 249 | self.frame_skip = frame_skip 250 | self.frame_size = frame_size 251 | self.reward_clipping = 
reward_clipping 252 | self._max_episode_steps = max_episode_timesteps 253 | self.observation_space = np.zeros((frame_size, frame_size)) 254 | self.action_space = self.env.action_space 255 | 256 | self.lives = 0 257 | self.episode_length = 0 258 | 259 | # Tracks previous 2 frames 260 | self.frame_buffer = np.zeros( 261 | (2, 262 | self.env.observation_space.shape[0], 263 | self.env.observation_space.shape[1]), 264 | dtype=np.uint8 265 | ) 266 | # Tracks previous 4 states 267 | self.state_buffer = np.zeros((state_history, frame_size, frame_size), dtype=np.uint8) 268 | 269 | 270 | def reset(self): 271 | self.env.reset() 272 | self.lives = self.env.ale.lives() 273 | self.episode_length = 0 274 | self.env.ale.getScreenGrayscale(self.frame_buffer[0]) 275 | self.frame_buffer[1] = 0 276 | 277 | self.state_buffer[0] = self.adjust_frame() 278 | self.state_buffer[1:] = 0 279 | return self.state_buffer 280 | 281 | 282 | # Takes single action is repeated for frame_skip frames (usually 4) 283 | # Reward is accumulated over those frames 284 | def step(self, action): 285 | total_reward = 0. 286 | self.episode_length += 1 287 | 288 | for frame in range(self.frame_skip): 289 | _, reward, done, _ = self.env.step(action) 290 | total_reward += reward 291 | 292 | if self.done_on_life_loss: 293 | crt_lives = self.env.ale.lives() 294 | done = True if crt_lives < self.lives else done 295 | self.lives = crt_lives 296 | 297 | if done: 298 | break 299 | 300 | # Second last and last frame 301 | f = frame + 2 - self.frame_skip 302 | if f >= 0: 303 | self.env.ale.getScreenGrayscale(self.frame_buffer[f]) 304 | 305 | self.state_buffer[1:] = self.state_buffer[:-1] 306 | self.state_buffer[0] = self.adjust_frame() 307 | 308 | done_float = float(done) 309 | if self.episode_length >= self._max_episode_steps: 310 | done = True 311 | 312 | return self.state_buffer, total_reward, done, [np.clip(total_reward, -1, 1), done_float] 313 | 314 | 315 | def adjust_frame(self): 316 | # Take maximum over last two frames 317 | np.maximum( 318 | self.frame_buffer[0], 319 | self.frame_buffer[1], 320 | out=self.frame_buffer[0] 321 | ) 322 | 323 | # Resize 324 | image = cv2.resize( 325 | self.frame_buffer[0], 326 | (self.frame_size, self.frame_size), 327 | interpolation=cv2.INTER_AREA 328 | ) 329 | return np.array(image, dtype=np.uint8) 330 | 331 | 332 | def seed(self, seed): 333 | self.env.seed(seed) 334 | 335 | 336 | # Create environment, add wrapper if necessary and create env_properties 337 | def make_env(env_name, atari_preprocessing): 338 | env = gym.make(env_name) 339 | 340 | is_atari = gym.envs.registry.spec(env_name).entry_point == 'gym.envs.atari:AtariEnv' 341 | env = AtariPreprocessing(env, **atari_preprocessing) if is_atari else env 342 | 343 | state_dim = ( 344 | atari_preprocessing["state_history"], 345 | atari_preprocessing["frame_size"], 346 | atari_preprocessing["frame_size"] 347 | ) if is_atari else env.observation_space.shape[0] 348 | 349 | return ( 350 | env, 351 | is_atari, 352 | state_dim, 353 | env.action_space.n 354 | ) --------------------------------------------------------------------------------