├── LICENSE
├── README.md
├── SumTree.py
├── cartpole_per.py
└── prioritized_memory.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 RLCode

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# per
PER (Prioritized Experience Replay) implementation in PyTorch
--------------------------------------------------------------------------------
/SumTree.py:
--------------------------------------------------------------------------------
import numpy


# SumTree
# a binary tree data structure where the parent's value is the sum of its children
class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = numpy.zeros(2 * capacity - 1)
        self.data = numpy.zeros(capacity, dtype=object)
        self.n_entries = 0

    # propagate a priority change up to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find the leaf node whose segment contains s
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample for a value s in [0, total()]
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])
--------------------------------------------------------------------------------
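
A minimal sketch (not part of the repository) of how the SumTree above can be exercised on its own; the capacity and priorities are arbitrary:

from SumTree import SumTree

tree = SumTree(capacity=4)
tree.add(p=1.0, data="a")              # first transition, low priority
tree.add(p=3.0, data="b")              # second transition, higher priority

print(tree.total())                    # 4.0, the sum of all stored priorities
idx, priority, data = tree.get(2.5)    # any s in (1.0, 4.0] falls in "b"'s segment
tree.update(idx, 5.0)                  # raising a leaf priority propagates up to the root
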
/cartpole_per.py:
--------------------------------------------------------------------------------
import sys
import gym
import torch
import pylab
import random
import numpy as np
from collections import deque
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
from prioritized_memory import Memory

EPISODES = 500

# approximate the Q function with a neural network:
# the state is the input and the Q value of each action is the output
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_size, 24),
            nn.ReLU(),
            nn.Linear(24, 24),
            nn.ReLU(),
            nn.Linear(24, action_size)
        )

    def forward(self, x):
        return self.fc(x)


# DQN agent for CartPole
# it uses a neural network to approximate the Q function,
# a prioritized experience replay memory and a target Q network
class DQNAgent():
    def __init__(self, state_size, action_size):
        # if you want to watch CartPole learning, change this to True
        self.render = False
        self.load_model = False

        # get size of state and action
        self.state_size = state_size
        self.action_size = action_size

        # these are the hyperparameters for the DQN
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.memory_size = 20000
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.explore_step = 5000
        self.epsilon_decay = (self.epsilon - self.epsilon_min) / self.explore_step
        self.batch_size = 64
        self.train_start = 1000

        # create prioritized replay memory using SumTree
        self.memory = Memory(self.memory_size)

        # create main model and target model
        self.model = DQN(state_size, action_size)
        self.model.apply(self.weights_init)
        self.target_model = DQN(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.learning_rate)

        # initialize target model
        self.update_target_model()

        if self.load_model:
            self.model = torch.load('save_model/cartpole_dqn')

    # Xavier weight initialization
    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            torch.nn.init.xavier_uniform_(m.weight)

    # after some time interval, update the target model to match the main model
    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    # get action from the model using an epsilon-greedy policy
    def get_action(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        else:
            state = torch.from_numpy(state)
            state = Variable(state).float().cpu()
            q_value = self.model(state)
            _, action = torch.max(q_value, 1)
            return int(action)

    # save the sample (error, (s, a, r, s', done)) to the replay memory
    def append_sample(self, state, action, reward, next_state, done):
        target = self.model(Variable(torch.FloatTensor(state))).data
        old_val = target[0][action]
        target_val = self.target_model(Variable(torch.FloatTensor(next_state))).data
        if done:
            target[0][action] = reward
        else:
            target[0][action] = reward + self.discount_factor * torch.max(target_val)

        error = abs(old_val - target[0][action])

        self.memory.add(error, (state, action, reward, next_state, done))
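
    # For reference, the priority math implemented in prioritized_memory.py:
    #   priority      p_i  = (|TD error| + e) ** a      with e = 0.01, a = 0.6
    #   sample prob.  P(i) = p_i / sum_k p_k
    #   IS weight     w_i  = (N * P(i)) ** (-beta), normalized by max(w),
    #                 with beta annealed from 0.4 towards 1 at each sampling step
    # Transitions with larger TD error are replayed more often, while the
    # importance-sampling weight corrects the bias this introduces in the update.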
    # pick samples from the prioritized replay memory (with batch_size)
    def train_model(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon -= self.epsilon_decay

        mini_batch, idxs, is_weights = self.memory.sample(self.batch_size)
        mini_batch = np.array(mini_batch).transpose()

        states = np.vstack(mini_batch[0])
        actions = list(mini_batch[1])
        rewards = list(mini_batch[2])
        next_states = np.vstack(mini_batch[3])
        dones = mini_batch[4]

        # bool to binary
        dones = dones.astype(int)

        # Q function of the current state
        states = torch.Tensor(states)
        states = Variable(states).float()
        pred = self.model(states)

        # one-hot encoding of the taken actions
        a = torch.LongTensor(actions).view(-1, 1)

        one_hot_action = torch.FloatTensor(self.batch_size, self.action_size).zero_()
        one_hot_action.scatter_(1, a, 1)

        pred = torch.sum(pred.mul(Variable(one_hot_action)), dim=1)

        # Q function of the next state
        next_states = torch.Tensor(next_states)
        next_states = Variable(next_states).float()
        next_pred = self.target_model(next_states).data

        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)

        # Q-learning: get the maximum Q value at s' from the target model
        target = rewards + (1 - dones) * self.discount_factor * next_pred.max(1)[0]
        target = Variable(target)

        errors = torch.abs(pred - target).data.numpy()

        # update priorities with the new TD errors
        for i in range(self.batch_size):
            idx = idxs[i]
            self.memory.update(idx, errors[i])

        self.optimizer.zero_grad()

        # MSE loss, weighted per sample by the importance-sampling weights
        loss = (torch.FloatTensor(is_weights) * F.mse_loss(pred, target, reduction='none')).mean()
        loss.backward()

        # and train
        self.optimizer.step()


if __name__ == "__main__":
    # in CartPole-v1 the maximum episode length is 500
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    model = DQN(state_size, action_size)

    agent = DQNAgent(state_size, action_size)
    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0

        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()

            # get action for the current state and take one step in the environment
            action = agent.get_action(state)

            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])
            # if an action ends the episode before the 500-step limit, give a penalty of -10
            reward = reward if not done or score == 499 else -10

            # save the sample to the replay memory
            agent.append_sample(state, action, reward, next_state, done)
            # train at every time step once enough samples have been collected
            if agent.memory.tree.n_entries >= agent.train_start:
                agent.train_model()

            score += reward
            state = next_state

            if done:
                # every episode, update the target model to match the main model
                agent.update_target_model()

                # every episode, plot the play time
                score = score if score == 500 else score + 10
                scores.append(score)
                episodes.append(e)
                pylab.plot(episodes, scores, 'b')
                pylab.savefig("./save_graph/cartpole_dqn.png")
                print("episode:", e, " score:", score, " memory length:",
                      agent.memory.tree.n_entries, " epsilon:", agent.epsilon)

                # if the mean score of the last 10 episodes is greater than 490,
                # stop training
                if np.mean(scores[-min(10, len(scores)):]) > 490:
                    torch.save(agent.model, "./save_model/cartpole_dqn")
                    sys.exit()
--------------------------------------------------------------------------------
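
Note: cartpole_per.py writes its training curve to ./save_graph/ and its trained model to ./save_model/ but does not create those folders itself. A minimal sketch of preparing them before a run (assuming the default paths are kept):

import os

os.makedirs("./save_graph", exist_ok=True)
os.makedirs("./save_model", exist_ok=True)
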
/prioritized_memory.py:
--------------------------------------------------------------------------------
import random
import numpy as np
from SumTree import SumTree

class Memory:  # stored as ( s, a, r, s_ ) in SumTree
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        return (np.abs(error) + self.e) ** self.a

    def add(self, error, sample):
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            a = segment * i
            b = segment * (i + 1)

            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = priorities / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()

        return batch, idxs, is_weight

    def update(self, idx, error):
        p = self._get_priority(error)
        self.tree.update(idx, p)
--------------------------------------------------------------------------------
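
For completeness, a minimal usage sketch (not part of the repository) of the Memory class on its own, assuming transitions are stored as (state, action, reward, next_state, done) tuples as in cartpole_per.py:

from prioritized_memory import Memory

memory = Memory(10000)
for _ in range(200):
    transition = ([0.0, 0.0, 0.0, 0.0], 0, 1.0, [0.0, 0.0, 0.0, 0.0], False)
    memory.add(1.0, transition)                 # new samples enter with a placeholder TD error

batch, idxs, is_weights = memory.sample(32)     # transitions, tree indices, IS weights
# ...compute fresh TD errors for the sampled batch, then refresh their priorities:
for idx, new_error in zip(idxs, [0.5] * len(idxs)):
    memory.update(idx, new_error)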