├── README.md
├── ddpg.py
├── dqn.py
├── img
│   ├── ddpg.png
│   ├── ddpg_heatmap.png
│   ├── dqn.png
│   ├── dqn_heatmap.png
│   ├── ppo.png
│   ├── ppo_d.png
│   ├── ppo_d_heatmap.png
│   └── ppo_heatmap.png
├── log
│   ├── ddpg_training_records.pkl
│   ├── dqn_training_records.pkl
│   ├── ppo_d_training_records.pkl
│   └── ppo_training_records.pkl
├── param
│   ├── ddpg_anet_params.pkl
│   ├── ddpg_cnet_params.pkl
│   ├── dqn_net_params.pkl
│   ├── ppo_anet_params.pkl
│   ├── ppo_cnet_params.pkl
│   ├── ppo_d_anet_params.pkl
│   └── ppo_d_cnet_params.pkl
├── plot_heatmap.py
├── ppo.py
└── ppo_d.py
/README.md:
--------------------------------------------------------------------------------
1 | # Reinforcement Learning Methods with PyTorch
2 | Try different reinforcement learning methods with PyTorch on OpenAI Gym! All algorithms are validated on Pendulum-v0.
3 | 
4 | ## Requirements
5 | To run the code, you need:
6 | - torch 0.4
7 | - gym 0.10
8 | 
9 | ## Method
10 | Four algorithms are implemented:
11 | - DDQN (double DQN) with a discretized action space
12 | - DDPG with a continuous action space
13 | - PPO with a discretized action space
14 | - PPO with a continuous action space
15 | Note that the PPO implementations estimate advantages with the one-step TD error of the value function, which differs from the generalized advantage estimation (GAE) used in the original paper (a short sketch of the estimator is given at the end of the Result section).
16 | 
17 | ## Result
18 | The moving-average episode rewards are shown below:
19 | 
20 | ![dqn](img/dqn.png)
21 | ![ddpg](img/ddpg.png)
22 | ![ppo_d](img/ppo_d.png)
23 | ![ppo](img/ppo.png)
24 | 
25 | The value and action heatmaps are shown below:
26 | 
27 | ![dqn_heatmap](img/dqn_heatmap.png)
28 | ![ddpg_heatmap](img/ddpg_heatmap.png)
29 | ![ppo_d_heatmap](img/ppo_d_heatmap.png)
30 | ![ppo_heatmap](img/ppo_heatmap.png)
31 | 
32 | From the results, we find that the value-based algorithms are more data-efficient because they are off-policy. The discretized action space is easier to train,
33 | but the resulting policy looks jerky (the pendulum trembles).
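As mentioned in the Method section, the advantage used by `ppo.py` and `ppo_d.py` is simply the one-step TD residual of the learned value function. The helper below is an illustrative sketch (the function itself is not part of this code base); the computation mirrors what `Agent.update` does, with `cnet` standing for the critic network:

```python
import torch


def one_step_advantage(cnet, s, r, s_, gamma=0.9):
    """Advantage estimate used here: the TD(0) residual r + gamma * V(s') - V(s).

    It corresponds to GAE with lambda = 0, i.e. lower variance but more bias
    than the truncated GAE estimator of the original PPO paper.
    """
    with torch.no_grad():
        target_v = r + gamma * cnet(s_)   # bootstrapped one-step return
        adv = target_v - cnet(s)          # advantage of the sampled action
    return adv, target_v                  # target_v is also the critic's regression target
```

As in `ppo.py`, no `done` mask is used in the target, since Pendulum-v0 has no terminal states apart from the time limit.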
34 | 35 | ## Reference 36 | - [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461) 37 | - [Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971) 38 | - [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) -------------------------------------------------------------------------------- /ddpg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from collections import namedtuple 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | import gym 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.distributions import Normal 14 | 15 | parser = argparse.ArgumentParser(description='Solve the Pendulum-v0 with DDPG') 16 | parser.add_argument( 17 | '--gamma', type=float, default=0.9, metavar='G', help='discount factor (default: 0.9)') 18 | 19 | parser.add_argument('--seed', type=int, default=0, metavar='N', help='random seed (default: 0)') 20 | parser.add_argument('--render', action='store_true', help='render the environment') 21 | parser.add_argument( 22 | '--log-interval', 23 | type=int, 24 | default=10, 25 | metavar='N', 26 | help='interval between training status logs (default: 10)') 27 | args = parser.parse_args() 28 | 29 | torch.manual_seed(args.seed) 30 | np.random.seed(args.seed) 31 | 32 | TrainingRecord = namedtuple('TrainingRecord', ['ep', 'reward']) 33 | Transition = namedtuple('Transition', ['s', 'a', 'r', 's_']) 34 | 35 | 36 | class ActorNet(nn.Module): 37 | 38 | def __init__(self): 39 | super(ActorNet, self).__init__() 40 | self.fc = nn.Linear(3, 100) 41 | self.mu_head = nn.Linear(100, 1) 42 | 43 | def forward(self, s): 44 | x = F.relu(self.fc(s)) 45 | u = 2.0 * F.tanh(self.mu_head(x)) 46 | return u 47 | 48 | 49 | class CriticNet(nn.Module): 50 | 51 | def __init__(self): 52 | super(CriticNet, self).__init__() 53 | self.fc = nn.Linear(4, 100) 54 | self.v_head = nn.Linear(100, 1) 55 | 56 | def forward(self, s, a): 57 | x = F.relu(self.fc(torch.cat([s, a], dim=1))) 58 | state_value = self.v_head(x) 59 | return state_value 60 | 61 | 62 | class Memory(): 63 | 64 | data_pointer = 0 65 | isfull = False 66 | 67 | def __init__(self, capacity): 68 | self.memory = np.empty(capacity, dtype=object) 69 | self.capacity = capacity 70 | 71 | def update(self, transition): 72 | self.memory[self.data_pointer] = transition 73 | self.data_pointer += 1 74 | if self.data_pointer == self.capacity: 75 | self.data_pointer = 0 76 | self.isfull = True 77 | 78 | def sample(self, batch_size): 79 | return np.random.choice(self.memory, batch_size) 80 | 81 | 82 | class Agent(): 83 | 84 | max_grad_norm = 0.5 85 | 86 | def __init__(self): 87 | self.training_step = 0 88 | self.var = 1. 
89 | self.eval_cnet, self.target_cnet = CriticNet().float(), CriticNet().float() 90 | self.eval_anet, self.target_anet = ActorNet().float(), ActorNet().float() 91 | self.memory = Memory(2000) 92 | self.optimizer_c = optim.Adam(self.eval_cnet.parameters(), lr=1e-3) 93 | self.optimizer_a = optim.Adam(self.eval_anet.parameters(), lr=3e-4) 94 | 95 | def select_action(self, state): 96 | state = torch.from_numpy(state).float().unsqueeze(0) 97 | mu = self.eval_anet(state) 98 | dist = Normal(mu, torch.tensor(self.var, dtype=torch.float)) 99 | action = dist.sample() 100 | action = action.clamp(-2.0, 2.0)  # keep the action within Pendulum's torque bounds 101 | return (action.item(),) 102 | 103 | def save_param(self): 104 | torch.save(self.eval_anet.state_dict(), 'param/ddpg_anet_params.pkl') 105 | torch.save(self.eval_cnet.state_dict(), 'param/ddpg_cnet_params.pkl') 106 | 107 | def store_transition(self, transition): 108 | self.memory.update(transition) 109 | 110 | def update(self): 111 | self.training_step += 1 112 | 113 | transitions = self.memory.sample(32) 114 | s = torch.tensor([t.s for t in transitions], dtype=torch.float) 115 | a = torch.tensor([t.a for t in transitions], dtype=torch.float).view(-1, 1) 116 | r = torch.tensor([t.r for t in transitions], dtype=torch.float).view(-1, 1) 117 | s_ = torch.tensor([t.s_ for t in transitions], dtype=torch.float) 118 | 119 | with torch.no_grad(): 120 | q_target = r + args.gamma * self.target_cnet(s_, self.target_anet(s_)) 121 | q_eval = self.eval_cnet(s, a) 122 | 123 | # update critic net 124 | self.optimizer_c.zero_grad() 125 | c_loss = F.smooth_l1_loss(q_eval, q_target) 126 | c_loss.backward() 127 | nn.utils.clip_grad_norm_(self.eval_cnet.parameters(), self.max_grad_norm) 128 | self.optimizer_c.step() 129 | 130 | # update actor net 131 | self.optimizer_a.zero_grad() 132 | a_loss = -self.eval_cnet(s, self.eval_anet(s)).mean() 133 | a_loss.backward() 134 | nn.utils.clip_grad_norm_(self.eval_anet.parameters(), self.max_grad_norm) 135 | self.optimizer_a.step() 136 | 137 | if self.training_step % 200 == 0: 138 | self.target_cnet.load_state_dict(self.eval_cnet.state_dict()) 139 | if self.training_step % 201 == 0: 140 | self.target_anet.load_state_dict(self.eval_anet.state_dict()) 141 | 142 | self.var = max(self.var * 0.999, 0.01) 143 | 144 | return q_eval.mean().item() 145 | 146 | 147 | def main(): 148 | env = gym.make('Pendulum-v0') 149 | env.seed(args.seed) 150 | 151 | agent = Agent() 152 | 153 | training_records = [] 154 | running_reward, running_q = -1000, 0 155 | for i_ep in range(1000): 156 | score = 0 157 | state = env.reset() 158 | 159 | for t in range(200): 160 | action = agent.select_action(state) 161 | state_, reward, done, _ = env.step(action) 162 | score += reward 163 | if args.render: 164 | env.render() 165 | agent.store_transition(Transition(state, action, (reward + 8) / 8, state_)) 166 | state = state_ 167 | if agent.memory.isfull: 168 | q = agent.update() 169 | running_q = 0.99 * running_q + 0.01 * q 170 | 171 | running_reward = running_reward * 0.9 + score * 0.1 172 | training_records.append(TrainingRecord(i_ep, running_reward)) 173 | 174 | if i_ep % args.log_interval == 0: 175 | print('Ep {}\tAverage score: {:.2f}\tAverage Q: {:.2f}'.format( 176 | i_ep, running_reward, running_q)) 177 | if running_reward > -200: 178 | print("Solved! 
Running reward is now {}!".format(running_reward)) 179 | env.close() 180 | agent.save_param() 181 | with open('log/ddpg_training_records.pkl', 'wb') as f: 182 | pickle.dump(training_records, f) 183 | break 184 | 185 | env.close() 186 | 187 | plt.plot([r.ep for r in training_records], [r.reward for r in training_records]) 188 | plt.title('DDPG') 189 | plt.xlabel('Episode') 190 | plt.ylabel('Moving averaged episode reward') 191 | plt.savefig("img/ddpg.png") 192 | plt.show() 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | -------------------------------------------------------------------------------- /dqn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from collections import namedtuple 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | import gym 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | 14 | parser = argparse.ArgumentParser(description='Solve the Pendulum-v0 with DQN') 15 | parser.add_argument( 16 | '--gamma', type=float, default=0.9, metavar='G', help='discount factor (default: 0.9)') 17 | parser.add_argument( 18 | '--num_actions', type=int, default=5, metavar='N', help='discretize action space (default: 5)') 19 | parser.add_argument('--seed', type=int, default=0, metavar='N', help='random seed (default: 0)') 20 | parser.add_argument('--render', action='store_true', help='render the environment') 21 | parser.add_argument( 22 | '--log-interval', 23 | type=int, 24 | default=10, 25 | metavar='N', 26 | help='interval between training status logs (default: 10)') 27 | args = parser.parse_args() 28 | 29 | torch.manual_seed(args.seed) 30 | np.random.seed(args.seed) 31 | 32 | TrainingRecord = namedtuple('TrainingRecord', ['ep', 'reward']) 33 | Transition = namedtuple('Transition', ['s', 'a', 'r', 's_']) 34 | 35 | 36 | class Net(nn.Module): 37 | 38 | def __init__(self): 39 | super(Net, self).__init__() 40 | self.fc = nn.Linear(3, 100) 41 | self.a_head = nn.Linear(100, args.num_actions) 42 | self.v_head = nn.Linear(100, 1) 43 | 44 | def forward(self, x): 45 | x = F.tanh(self.fc(x)) 46 | a = self.a_head(x) - self.a_head(x).mean(1, keepdim=True) 47 | v = self.v_head(x) 48 | action_scores = a + v 49 | return action_scores 50 | 51 | 52 | class Memory(): 53 | 54 | data_pointer = 0 55 | isfull = False 56 | 57 | def __init__(self, capacity): 58 | self.memory = np.empty(capacity, dtype=object) 59 | self.capacity = capacity 60 | 61 | def update(self, transition): 62 | self.memory[self.data_pointer] = transition 63 | self.data_pointer += 1 64 | if self.data_pointer == self.capacity: 65 | self.data_pointer = 0 66 | self.isfull = True 67 | 68 | def sample(self, batch_size): 69 | return np.random.choice(self.memory, batch_size) 70 | 71 | 72 | class Agent(): 73 | 74 | action_list = [(i * 4 - 2,) for i in range(args.num_actions)] 75 | max_grad_norm = 0.5 76 | 77 | def __init__(self): 78 | self.training_step = 0 79 | self.epsilon = 1 80 | self.eval_net, self.target_net = Net().float(), Net().float() 81 | self.memory = Memory(2000) 82 | self.optimizer = optim.Adam(self.eval_net.parameters(), lr=1e-3) 83 | 84 | def select_action(self, state): 85 | state = torch.from_numpy(state).float().unsqueeze(0) 86 | if np.random.random() < self.epsilon: 87 | action_index = np.random.randint(args.num_actions) 88 | else: 89 | q_values = self.eval_net(state)  # greedy action w.r.t. the online Q-network 90 | action_index = q_values.max(1)[1].item() 91 | return self.action_list[action_index], 
action_index 92 | 93 | def save_param(self): 94 | torch.save(self.eval_net.state_dict(), 'param/dqn_net_params.pkl') 95 | 96 | def store_transition(self, transition): 97 | self.memory.update(transition) 98 | 99 | def update(self): 100 | self.training_step += 1 101 | 102 | transitions = self.memory.sample(32) 103 | s = torch.tensor([t.s for t in transitions], dtype=torch.float) 104 | a = torch.tensor([t.a for t in transitions], dtype=torch.long).view(-1, 1) 105 | r = torch.tensor([t.r for t in transitions], dtype=torch.float).view(-1, 1) 106 | s_ = torch.tensor([t.s_ for t in transitions], dtype=torch.float) 107 | 108 | # natural dqn 109 | # q_eval = self.eval_net(s).gather(1, a) 110 | # with torch.no_grad(): 111 | # q_target = r + args.gamma * self.target_net(s_).max(1, keepdim=True)[0] 112 | 113 | # double dqn 114 | with torch.no_grad(): 115 | a_ = self.eval_net(s_).max(1, keepdim=True)[1] 116 | q_target = r + args.gamma * self.target_net(s_).gather(1, a_) 117 | q_eval = self.eval_net(s).gather(1, a) 118 | 119 | self.optimizer.zero_grad() 120 | loss = F.smooth_l1_loss(q_eval, q_target) 121 | loss.backward() 122 | nn.utils.clip_grad_norm_(self.eval_net.parameters(), self.max_grad_norm) 123 | self.optimizer.step() 124 | 125 | if self.training_step % 200 == 0: 126 | self.target_net.load_state_dict(self.eval_net.state_dict()) 127 | 128 | self.epsilon = max(self.epsilon * 0.999, 0.01) 129 | 130 | return q_eval.mean().item() 131 | 132 | 133 | def main(): 134 | env = gym.make('Pendulum-v0') 135 | env.seed(args.seed) 136 | 137 | agent = Agent() 138 | 139 | training_records = [] 140 | running_reward, running_q = -1000, 0 141 | for i_ep in range(100): 142 | score = 0 143 | state = env.reset() 144 | 145 | for t in range(200): 146 | action, action_index = agent.select_action(state) 147 | state_, reward, done, _ = env.step(action) 148 | score += reward 149 | if args.render: 150 | env.render() 151 | agent.store_transition(Transition(state, action_index, (reward + 8) / 8, state_)) 152 | state = state_ 153 | if agent.memory.isfull: 154 | q = agent.update() 155 | running_q = 0.99 * running_q + 0.01 * q 156 | 157 | running_reward = running_reward * 0.9 + score * 0.1 158 | training_records.append(TrainingRecord(i_ep, running_reward)) 159 | 160 | if i_ep % args.log_interval == 0: 161 | print('Ep {}\tAverage score: {:.2f}\tAverage Q: {:.2f}'.format( 162 | i_ep, running_reward, running_q)) 163 | if running_reward > -200: 164 | print("Solved! 
Running reward is now {}!".format(running_reward)) 165 | env.close() 166 | agent.save_param() 167 | with open('log/dqn_training_records.pkl', 'wb') as f: 168 | pickle.dump(training_records, f) 169 | break 170 | 171 | env.close() 172 | 173 | plt.plot([r.ep for r in training_records], [r.reward for r in training_records]) 174 | plt.title('DQN') 175 | plt.xlabel('Episode') 176 | plt.ylabel('Moving averaged episode reward') 177 | plt.savefig("img/dqn.png") 178 | plt.show() 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /img/ddpg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/ddpg.png -------------------------------------------------------------------------------- /img/ddpg_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/ddpg_heatmap.png -------------------------------------------------------------------------------- /img/dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/dqn.png -------------------------------------------------------------------------------- /img/dqn_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/dqn_heatmap.png -------------------------------------------------------------------------------- /img/ppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/ppo.png -------------------------------------------------------------------------------- /img/ppo_d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/ppo_d.png -------------------------------------------------------------------------------- /img/ppo_d_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/ppo_d_heatmap.png -------------------------------------------------------------------------------- /img/ppo_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/img/ppo_heatmap.png -------------------------------------------------------------------------------- /log/ddpg_training_records.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/log/ddpg_training_records.pkl -------------------------------------------------------------------------------- /log/dqn_training_records.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/log/dqn_training_records.pkl 
-------------------------------------------------------------------------------- /log/ppo_d_training_records.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/log/ppo_d_training_records.pkl -------------------------------------------------------------------------------- /log/ppo_training_records.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/log/ppo_training_records.pkl -------------------------------------------------------------------------------- /param/ddpg_anet_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/ddpg_anet_params.pkl -------------------------------------------------------------------------------- /param/ddpg_cnet_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/ddpg_cnet_params.pkl -------------------------------------------------------------------------------- /param/dqn_net_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/dqn_net_params.pkl -------------------------------------------------------------------------------- /param/ppo_anet_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/ppo_anet_params.pkl -------------------------------------------------------------------------------- /param/ppo_cnet_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/ppo_cnet_params.pkl -------------------------------------------------------------------------------- /param/ppo_d_anet_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/ppo_d_anet_params.pkl -------------------------------------------------------------------------------- /param/ppo_d_cnet_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/simple-pytorch-rl/49432b8b771435028b3e2c6d17dc7c7ae957cf41/param/ppo_d_cnet_params.pkl -------------------------------------------------------------------------------- /plot_heatmap.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | # import parser 6 | # parser.add_argument('--algo', type=str, default='ppo') 7 | 8 | 9 | def dqn_heatmap(): 10 | from dqn import Net 11 | 12 | x_pxl, y_pxl = 300, 400 13 | 14 | state = torch.Tensor([[np.cos(theta), np.sin(theta), thetadot] 15 | for thetadot in np.linspace(-8, 8, y_pxl) 16 | for theta in np.linspace(-np.pi, np.pi, x_pxl)]) 17 | 18 | net = Net() 19 | net.load_state_dict(torch.load('param/dqn_net_params.pkl')) 20 | 
q = net(state) 21 | value_map = q.max(1)[0].view(y_pxl, x_pxl).detach().numpy() 22 | action_map = q.max(1)[1].view(y_pxl, x_pxl).detach().numpy() / 10 * 4 - 2 23 | 24 | fig = plt.figure() 25 | fig.suptitle('DQN') 26 | ax = fig.add_subplot(121) 27 | im = ax.imshow(value_map, cmap=plt.cm.spring, interpolation='bicubic') 28 | plt.colorbar(im, shrink=0.5) 29 | ax.set_title('Value Map') 30 | ax.set_xlabel('$\\theta$') 31 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 32 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 33 | ax.set_ylabel('$\\dot{\\theta}$') 34 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 35 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 36 | 37 | ax = fig.add_subplot(122) 38 | im = ax.imshow(action_map, cmap=plt.cm.winter, interpolation='bicubic') 39 | plt.colorbar(im, shrink=0.5) 40 | ax.set_title('Action Map') 41 | ax.set_xlabel('$\\theta$') 42 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 43 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 44 | ax.set_ylabel('$\\dot{\\theta}$') 45 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 46 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 47 | plt.tight_layout() 48 | plt.savefig('img/dqn_heatmap.png') 49 | plt.show() 50 | 51 | 52 | def ddpg_heatmap(): 53 | from ddpg import ActorNet, CriticNet 54 | 55 | x_pxl, y_pxl = 300, 400 56 | 57 | state = torch.Tensor([[np.cos(theta), np.sin(theta), thetadot] 58 | for thetadot in np.linspace(-8, 8, y_pxl) 59 | for theta in np.linspace(-np.pi, np.pi, x_pxl)]) 60 | 61 | anet = ActorNet() 62 | anet.load_state_dict(torch.load('param/ddpg_anet_params.pkl')) 63 | action_map = anet(state).view(y_pxl, x_pxl).detach().numpy() 64 | 65 | cnet = CriticNet() 66 | cnet.load_state_dict(torch.load('param/ddpg_cnet_params.pkl')) 67 | value_map = cnet(state, anet(state)).view(y_pxl, x_pxl).detach().numpy() 68 | 69 | fig = plt.figure() 70 | fig.suptitle('DDPG') 71 | ax = fig.add_subplot(121) 72 | im = ax.imshow(value_map, cmap=plt.cm.spring, interpolation='bicubic') 73 | plt.colorbar(im, shrink=0.5) 74 | ax.set_title('Value Map') 75 | ax.set_xlabel('$\\theta$') 76 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 77 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 78 | ax.set_ylabel('$\\dot{\\theta}$') 79 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 80 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 81 | 82 | ax = fig.add_subplot(122) 83 | im = ax.imshow(action_map, cmap=plt.cm.winter, interpolation='bicubic') 84 | plt.colorbar(im, shrink=0.5) 85 | ax.set_title('Action Map') 86 | ax.set_xlabel('$\\theta$') 87 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 88 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 89 | ax.set_ylabel('$\\dot{\\theta}$') 90 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 91 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 92 | plt.tight_layout() 93 | plt.savefig('img/ddpg_heatmap.png') 94 | plt.show() 95 | 96 | 97 | def ppo_heatmap(): 98 | from ppo import ActorNet, CriticNet 99 | 100 | x_pxl, y_pxl = 300, 400 101 | 102 | state = torch.Tensor([[np.cos(theta), np.sin(theta), thetadot] 103 | for thetadot in np.linspace(-8, 8, y_pxl) 104 | for theta in np.linspace(-np.pi, np.pi, x_pxl)]) 105 | cnet = CriticNet() 106 | cnet.load_state_dict(torch.load('param/ppo_cnet_params.pkl')) 107 | value_map = cnet(state).view(y_pxl, x_pxl).detach().numpy() 108 | 109 | anet = ActorNet() 110 | anet.load_state_dict(torch.load('param/ppo_anet_params.pkl')) 111 | action_map = anet(state)[0].view(y_pxl, x_pxl).detach().numpy() 112 | 113 | 
fig = plt.figure() 114 | fig.suptitle('PPO') 115 | ax = fig.add_subplot(121) 116 | im = ax.imshow(value_map, cmap=plt.cm.spring, interpolation='bicubic') 117 | plt.colorbar(im, shrink=0.5) 118 | ax.set_title('Value Map') 119 | ax.set_xlabel('$\\theta$') 120 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 121 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 122 | ax.set_ylabel('$\\dot{\\theta}$') 123 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 124 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 125 | 126 | ax = fig.add_subplot(122) 127 | im = ax.imshow(action_map, cmap=plt.cm.winter, interpolation='bicubic') 128 | plt.colorbar(im, shrink=0.5) 129 | ax.set_title('Action Map') 130 | ax.set_xlabel('$\\theta$') 131 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 132 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 133 | ax.set_ylabel('$\\dot{\\theta}$') 134 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 135 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 136 | plt.tight_layout() 137 | plt.savefig('img/ppo_heatmap.png') 138 | plt.show() 139 | 140 | 141 | def ppo_d_heatmap(): 142 | from ppo_d import ActorNet, CriticNet 143 | 144 | x_pxl, y_pxl = 300, 400 145 | 146 | state = torch.Tensor([[np.cos(theta), np.sin(theta), thetadot] 147 | for thetadot in np.linspace(-8, 8, y_pxl) 148 | for theta in np.linspace(-np.pi, np.pi, x_pxl)]) 149 | 150 | anet = ActorNet() 151 | anet.load_state_dict(torch.load('param/ppo_d_anet_params.pkl')) 152 | action_map = anet(state).max(1)[1].view(y_pxl, x_pxl).detach().numpy() / 10 * 4 - 2 153 | 154 | cnet = CriticNet() 155 | cnet.load_state_dict(torch.load('param/ppo_d_cnet_params.pkl')) 156 | value_map = cnet(state).view(y_pxl, x_pxl).detach().numpy() 157 | 158 | fig = plt.figure() 159 | fig.suptitle('PPO_d') 160 | ax = fig.add_subplot(121) 161 | im = ax.imshow(value_map, cmap=plt.cm.spring, interpolation='bicubic') 162 | plt.colorbar(im, shrink=0.5) 163 | ax.set_title('Value Map') 164 | ax.set_xlabel('$\\theta$') 165 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 166 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 167 | ax.set_ylabel('$\\dot{\\theta}$') 168 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 169 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 170 | 171 | ax = fig.add_subplot(122) 172 | im = ax.imshow(action_map, cmap=plt.cm.winter, interpolation='bicubic') 173 | plt.colorbar(im, shrink=0.5) 174 | ax.set_title('Action Map') 175 | ax.set_xlabel('$\\theta$') 176 | ax.set_xticks(np.linspace(0, x_pxl, 5)) 177 | ax.set_xticklabels(['$-\\pi$', '$-\\pi/2$', '$0$', '$\\pi/2$', '$\\pi$']) 178 | ax.set_ylabel('$\\dot{\\theta}$') 179 | ax.set_yticks(np.linspace(0, y_pxl, 5)) 180 | ax.set_yticklabels(['-8', '-4', '0', '4', '8']) 181 | plt.tight_layout() 182 | plt.savefig('img/ppo_d_heatmap.png') 183 | plt.show() 184 | 185 | 186 | dqn_heatmap() 187 | ddpg_heatmap() 188 | ppo_heatmap() 189 | ppo_d_heatmap() -------------------------------------------------------------------------------- /ppo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from collections import namedtuple 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | import gym 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.distributions import Normal 13 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 14 | 15 | parser = argparse.ArgumentParser(description='Solve the Pendulum-v0 
with PPO') 16 | parser.add_argument( 17 | '--gamma', type=float, default=0.9, metavar='G', help='discount factor (default: 0.9)') 18 | parser.add_argument('--seed', type=int, default=0, metavar='N', help='random seed (default: 0)') 19 | parser.add_argument('--render', action='store_true', help='render the environment') 20 | parser.add_argument( 21 | '--log-interval', 22 | type=int, 23 | default=10, 24 | metavar='N', 25 | help='interval between training status logs (default: 10)') 26 | args = parser.parse_args() 27 | 28 | torch.manual_seed(args.seed) 29 | 30 | TrainingRecord = namedtuple('TrainingRecord', ['ep', 'reward']) 31 | Transition = namedtuple('Transition', ['s', 'a', 'a_log_p', 'r', 's_']) 32 | 33 | 34 | class ActorNet(nn.Module): 35 | 36 | def __init__(self): 37 | super(ActorNet, self).__init__() 38 | self.fc = nn.Linear(3, 100) 39 | self.mu_head = nn.Linear(100, 1) 40 | self.sigma_head = nn.Linear(100, 1) 41 | 42 | def forward(self, x): 43 | x = F.relu(self.fc(x)) 44 | mu = 2.0 * F.tanh(self.mu_head(x)) 45 | sigma = F.softplus(self.sigma_head(x)) 46 | return (mu, sigma) 47 | 48 | 49 | class CriticNet(nn.Module): 50 | 51 | def __init__(self): 52 | super(CriticNet, self).__init__() 53 | self.fc = nn.Linear(3, 100) 54 | self.v_head = nn.Linear(100, 1) 55 | 56 | def forward(self, x): 57 | x = F.relu(self.fc(x)) 58 | state_value = self.v_head(x) 59 | return state_value 60 | 61 | 62 | class Agent(): 63 | 64 | clip_param = 0.2 65 | max_grad_norm = 0.5 66 | ppo_epoch = 10 67 | buffer_capacity, batch_size = 1000, 32 68 | 69 | def __init__(self): 70 | self.training_step = 0 71 | self.anet = ActorNet().float() 72 | self.cnet = CriticNet().float() 73 | self.buffer = [] 74 | self.counter = 0 75 | 76 | self.optimizer_a = optim.Adam(self.anet.parameters(), lr=1e-4) 77 | self.optimizer_c = optim.Adam(self.cnet.parameters(), lr=3e-4) 78 | 79 | def select_action(self, state): 80 | state = torch.from_numpy(state).float().unsqueeze(0) 81 | with torch.no_grad(): 82 | (mu, sigma) = self.anet(state) 83 | dist = Normal(mu, sigma) 84 | action = dist.sample() 85 | action_log_prob = dist.log_prob(action) 86 | action.clamp(-2.0, 2.0) 87 | return action.item(), action_log_prob.item() 88 | 89 | def get_value(self, state): 90 | 91 | state = torch.from_numpy(state).float().unsqueeze(0) 92 | with torch.no_grad(): 93 | state_value = self.cnet(state) 94 | return state_value.item() 95 | 96 | def save_param(self): 97 | torch.save(self.anet.state_dict(), 'param/ppo_anet_params.pkl') 98 | torch.save(self.cnet.state_dict(), 'param/ppo_cnet_params.pkl') 99 | 100 | def store(self, transition): 101 | self.buffer.append(transition) 102 | self.counter += 1 103 | return self.counter % self.buffer_capacity == 0 104 | 105 | def update(self): 106 | self.training_step += 1 107 | 108 | s = torch.tensor([t.s for t in self.buffer], dtype=torch.float) 109 | a = torch.tensor([t.a for t in self.buffer], dtype=torch.float).view(-1, 1) 110 | r = torch.tensor([t.r for t in self.buffer], dtype=torch.float).view(-1, 1) 111 | s_ = torch.tensor([t.s_ for t in self.buffer], dtype=torch.float) 112 | 113 | old_action_log_probs = torch.tensor( 114 | [t.a_log_p for t in self.buffer], dtype=torch.float).view(-1, 1) 115 | 116 | r = (r - r.mean()) / (r.std() + 1e-5) 117 | with torch.no_grad(): 118 | target_v = r + args.gamma * self.cnet(s_) 119 | 120 | adv = (target_v - self.cnet(s)).detach() 121 | 122 | for _ in range(self.ppo_epoch): 123 | for index in BatchSampler( 124 | SubsetRandomSampler(range(self.buffer_capacity)), self.batch_size, False): 
125 | 126 | (mu, sigma) = self.anet(s[index]) 127 | dist = Normal(mu, sigma) 128 | action_log_probs = dist.log_prob(a[index]) 129 | ratio = torch.exp(action_log_probs - old_action_log_probs[index]) 130 | 131 | surr1 = ratio * adv[index] 132 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 133 | 1.0 + self.clip_param) * adv[index] 134 | action_loss = -torch.min(surr1, surr2).mean() 135 | 136 | self.optimizer_a.zero_grad() 137 | action_loss.backward() 138 | nn.utils.clip_grad_norm_(self.anet.parameters(), self.max_grad_norm) 139 | self.optimizer_a.step() 140 | 141 | value_loss = F.smooth_l1_loss(self.cnet(s[index]), target_v[index]) 142 | self.optimizer_c.zero_grad() 143 | value_loss.backward() 144 | nn.utils.clip_grad_norm_(self.cnet.parameters(), self.max_grad_norm) 145 | self.optimizer_c.step() 146 | 147 | del self.buffer[:] 148 | 149 | 150 | def main(): 151 | env = gym.make('Pendulum-v0') 152 | env.seed(args.seed) 153 | 154 | agent = Agent() 155 | 156 | training_records = [] 157 | running_reward = -1000 158 | state = env.reset() 159 | for i_ep in range(1000): 160 | score = 0 161 | state = env.reset() 162 | 163 | for t in range(200): 164 | action, action_log_prob = agent.select_action(state) 165 | state_, reward, done, _ = env.step([action]) 166 | if args.render: 167 | env.render() 168 | if agent.store(Transition(state, action, action_log_prob, (reward + 8) / 8, state_)): 169 | agent.update() 170 | score += reward 171 | state = state_ 172 | 173 | running_reward = running_reward * 0.9 + score * 0.1 174 | training_records.append(TrainingRecord(i_ep, running_reward)) 175 | 176 | if i_ep % args.log_interval == 0: 177 | print('Ep {}\tMoving average score: {:.2f}\t'.format(i_ep, running_reward)) 178 | if running_reward > -200: 179 | print("Solved! Moving average score is now {}!".format(running_reward)) 180 | env.close() 181 | agent.save_param() 182 | with open('log/ppo_training_records.pkl', 'wb') as f: 183 | pickle.dump(training_records, f) 184 | break 185 | 186 | plt.plot([r.ep for r in training_records], [r.reward for r in training_records]) 187 | plt.title('PPO') 188 | plt.xlabel('Episode') 189 | plt.ylabel('Moving averaged episode reward') 190 | plt.savefig("img/ppo.png") 191 | plt.show() 192 | 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /ppo_d.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from collections import namedtuple 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | import gym 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.distributions import Categorical 13 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 14 | 15 | parser = argparse.ArgumentParser(description='Solve the Pendulum-v0 with PPO (discrete)') 16 | parser.add_argument( 17 | '--gamma', type=float, default=0.9, metavar='G', help='discount factor (default: 0.9)') 18 | parser.add_argument( 19 | '--num-actions', type=int, default=11, metavar='N', help='discretize action space (default:11)') 20 | parser.add_argument('--seed', type=int, default=0, metavar='N', help='random seed (default: 0)') 21 | parser.add_argument('--render', action='store_true', help='render the environment') 22 | parser.add_argument( 23 | '--log-interval', 24 | type=int, 25 | default=10, 26 | metavar='N', 27 | help='interval between training status logs (default: 10)') 28 | args 
= parser.parse_args() 29 | 30 | torch.manual_seed(args.seed) 31 | 32 | TrainingRecord = namedtuple('TrainingRecord', ['ep', 'reward']) 33 | Transition = namedtuple('Transition', ['s', 'a', 'a_p', 'r', 's_']) 34 | 35 | 36 | class ActorNet(nn.Module): 37 | 38 | def __init__(self): 39 | super(ActorNet, self).__init__() 40 | self.fc = nn.Linear(3, 100) 41 | self.a_head = nn.Linear(100, args.num_actions) 42 | 43 | def forward(self, x): 44 | x = F.relu(self.fc(x)) 45 | action_score = self.a_head(x) 46 | return F.softmax(action_score, dim=-1) 47 | 48 | 49 | class CriticNet(nn.Module): 50 | 51 | def __init__(self): 52 | super(CriticNet, self).__init__() 53 | self.fc = nn.Linear(3, 100) 54 | self.a_head = nn.Linear(100, args.num_actions) 55 | self.v_head = nn.Linear(100, 1) 56 | 57 | def forward(self, x): 58 | x = F.relu(self.fc(x)) 59 | state_value = self.v_head(x) 60 | return state_value 61 | 62 | 63 | class Agent(): 64 | 65 | action_list = [(i * 4 - 2,) for i in range(args.num_actions)] 66 | clip_param = 0.2 67 | max_grad_norm = 0.5 68 | ppo_epoch = 10 69 | buffer_capacity, batch_size = 1000, 32 70 | 71 | def __init__(self): 72 | self.training_step = 0 73 | self.anet = ActorNet().float() 74 | self.cnet = CriticNet().float() 75 | 76 | self.buffer = [] 77 | self.counter = 0 78 | 79 | self.optimizer_a = optim.Adam(self.anet.parameters(), lr=1e-3) 80 | self.optimizer_c = optim.Adam(self.cnet.parameters(), lr=3e-3) 81 | 82 | def select_action(self, state): 83 | state = torch.from_numpy(state).float().unsqueeze(0) 84 | probs = self.anet(state) 85 | m = Categorical(probs) 86 | action = m.sample() 87 | return self.action_list[action.item()], action.item(), probs[:, action].item() 88 | 89 | def store(self, transition): 90 | self.buffer.append(transition) 91 | self.counter += 1 92 | return self.counter % self.buffer_capacity == 0 93 | 94 | def save_param(self): 95 | torch.save(self.anet.state_dict(), 'param/ppo_d_anet_params.pkl') 96 | torch.save(self.cnet.state_dict(), 'param/ppo_d_cnet_params.pkl') 97 | 98 | def update(self): 99 | self.training_step += 1 100 | 101 | s = torch.tensor([t.s for t in self.buffer], dtype=torch.float) 102 | a = torch.tensor([t.a for t in self.buffer], dtype=torch.long).view(-1, 1) 103 | r = torch.tensor([t.r for t in self.buffer], dtype=torch.float).view(-1, 1) 104 | s_ = torch.tensor([t.s_ for t in self.buffer], dtype=torch.float) 105 | old_action_probs = torch.tensor([t.a_p for t in self.buffer], dtype=torch.float).view(-1, 1) 106 | 107 | r = (r - r.mean()) / (r.std() + 1e-5) 108 | 109 | with torch.no_grad(): 110 | target_v = r + args.gamma * self.cnet(s_) 111 | adv = (target_v - self.cnet(s)).detach() 112 | 113 | for _ in range(self.ppo_epoch): 114 | for index in BatchSampler( 115 | SubsetRandomSampler(range(self.buffer_capacity)), self.batch_size, False): 116 | action_probs = self.anet(s[index]).gather(1, a[index]) 117 | ratio = action_probs / old_action_probs[index] 118 | 119 | surr1 = ratio * adv[index] 120 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 121 | 1.0 + self.clip_param) * adv[index] 122 | action_loss = -torch.min(surr1, surr2).mean() 123 | self.optimizer_a.zero_grad() 124 | action_loss.backward() 125 | nn.utils.clip_grad_norm_(self.anet.parameters(), self.max_grad_norm) 126 | self.optimizer_a.step() 127 | 128 | value_loss = F.smooth_l1_loss(self.cnet(s[index]), target_v[index]) 129 | self.optimizer_c.zero_grad() 130 | value_loss.backward() 131 | nn.utils.clip_grad_norm_(self.cnet.parameters(), self.max_grad_norm) 132 | self.optimizer_c.step() 133 | 134 
| del self.buffer[:] 135 | 136 | 137 | def main(): 138 | env = gym.make('Pendulum-v0') 139 | env.seed(args.seed) 140 | 141 | agent = Agent() 142 | 143 | training_records = [] 144 | running_reward = -1000 145 | state = env.reset() 146 | for i_ep in range(1000): 147 | score = 0 148 | state = env.reset() 149 | 150 | for t in range(200): 151 | action, action_index, action_prob = agent.select_action(state) 152 | state_, reward, done, _ = env.step(action) 153 | if args.render: 154 | env.render() 155 | if agent.store(Transition(state, action_index, action_prob, (reward + 8) / 8, state_)): 156 | agent.update() 157 | score += reward 158 | state = state_ 159 | 160 | running_reward = running_reward * 0.9 + score * 0.1 161 | training_records.append(TrainingRecord(i_ep, running_reward)) 162 | 163 | if i_ep % args.log_interval == 0: 164 | print('Ep {}\tMoving average score: {:.2f}\t'.format(i_ep, running_reward)) 165 | if running_reward > -200: 166 | print("Solved! Moving average score is now {}!".format(running_reward)) 167 | env.close() 168 | agent.save_param() 169 | with open('log/ppo_d_training_records.pkl', 'wb') as f: 170 | pickle.dump(training_records, f) 171 | break 172 | 173 | plt.plot([r.ep for r in training_records], [r.reward for r in training_records]) 174 | plt.title('PPO (discrete)') 175 | plt.xlabel('Episode') 176 | plt.ylabel('Moving averaged episode reward') 177 | plt.savefig("img/ppo_d.png") 178 | plt.show() 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | --------------------------------------------------------------------------------