├── LICENSE
├── README.md
├── graph.png
└── pg.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Daochang Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-policy-gradient-example

Train an agent for CartPole-v0 using naive policy gradient (REINFORCE).

Inspired by [Andrej Karpathy's blog](https://karpathy.github.io/2016/05/31/rl/).

Code partly adapted from the [PyTorch DQN Tutorial](http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html).

Solved in 500 episodes (average reward):

![Average reward over training](./graph.png "Avg Reward")
--------------------------------------------------------------------------------
/graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Finspire13/pytorch-policy-gradient-example/eded52d0d4f56b2d6f72394c083dbd70029294fb/graph.png
--------------------------------------------------------------------------------
/pg.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli
from torch.autograd import Variable
from itertools import count
import matplotlib.pyplot as plt
import numpy as np
import gym


class PolicyNet(nn.Module):
    def __init__(self):
        super(PolicyNet, self).__init__()

        self.fc1 = nn.Linear(4, 24)
        self.fc2 = nn.Linear(24, 36)
        self.fc3 = nn.Linear(36, 1)  # Probability of taking action 1

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x
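
# REINFORCE recap, matching the training loop in main() below:
#   grad J(theta) ~= E[ grad log pi_theta(a | s) * G_t ]
# where G_t is the discounted return from step t. Actions are sampled from a
# Bernoulli distribution over the network's single sigmoid output, and each
# step contributes a loss of -log_prob(action) * (discounted, normalized)
# return, so stepping the optimizer performs gradient ascent on expected return.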


def main():

    # Plot duration curve:
    # From http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
    episode_durations = []

    def plot_durations():
        plt.figure(2)
        plt.clf()
        durations_t = torch.FloatTensor(episode_durations)
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Duration')
        plt.plot(durations_t.numpy())
        # Take 100-episode averages and plot them too
        if len(durations_t) >= 100:
            means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
            means = torch.cat((torch.zeros(99), means))
            plt.plot(means.numpy())

        plt.pause(0.001)  # pause a bit so that plots are updated

    # Parameters
    num_episode = 5000
    batch_size = 5
    learning_rate = 0.01
    gamma = 0.99

    env = gym.make('CartPole-v0')
    policy_net = PolicyNet()
    optimizer = torch.optim.RMSprop(policy_net.parameters(), lr=learning_rate)

    # Batch history
    state_pool = []
    action_pool = []
    reward_pool = []
    steps = 0

    for e in range(num_episode):

        state = env.reset()
        state = torch.from_numpy(state).float()
        state = Variable(state)
        env.render(mode='rgb_array')

        for t in count():

            probs = policy_net(state)
            m = Bernoulli(probs)
            action = m.sample()

            action = action.data.numpy().astype(int)[0]
            next_state, reward, done, _ = env.step(action)
            env.render(mode='rgb_array')

            # To mark boundaries between episodes
            if done:
                reward = 0

            state_pool.append(state)
            action_pool.append(float(action))
            reward_pool.append(reward)

            state = next_state
            state = torch.from_numpy(state).float()
            state = Variable(state)

            steps += 1

            if done:
                episode_durations.append(t + 1)
                plot_durations()
                break

        # Update policy
        if e > 0 and e % batch_size == 0:

            # Discount reward
            running_add = 0
            for i in reversed(range(steps)):
                if reward_pool[i] == 0:
                    running_add = 0
                else:
                    running_add = running_add * gamma + reward_pool[i]
                reward_pool[i] = running_add

            # Normalize reward
            reward_mean = np.mean(reward_pool)
            reward_std = np.std(reward_pool)
            for i in range(steps):
                reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std

            # Gradient descent
            optimizer.zero_grad()

            for i in range(steps):
                state = state_pool[i]
                action = Variable(torch.FloatTensor([action_pool[i]]))
                reward = reward_pool[i]

                probs = policy_net(state)
                m = Bernoulli(probs)
                loss = -m.log_prob(action) * reward  # Negative score function x reward
                loss.backward()

            optimizer.step()

            state_pool = []
            action_pool = []
            reward_pool = []
            steps = 0


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------