├── LICENSE
├── README.md
├── continuous_A3C.py
├── discrete_A3C.py
├── results
│   ├── cartpole.png
│   └── pendulum.png
├── shared_adam.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simple implementation of Reinforcement Learning (A3C) using Pytorch

This is a toy example of using multiprocessing in Python to asynchronously train a
neural network to play the discrete-action [CartPole](https://gym.openai.com/envs/CartPole-v0/) and
continuous-action [Pendulum](https://gym.openai.com/envs/Pendulum-v0/) games.
The asynchronous algorithm I used is called [Asynchronous Advantage Actor-Critic](https://arxiv.org/pdf/1602.01783.pdf), or A3C.

I believe this is among the simplest toy implementations you can find at the moment (2018-01).

## What are the main focuses of this implementation?

* Pytorch + multiprocessing (NOT threading) for parallel training
* Both discrete and continuous action environments
* Simple code that is easy to dig into (less than 200 lines)

## Reason for using [Pytorch](http://pytorch.org/) instead of [Tensorflow](https://www.tensorflow.org/)

Both are great for building customized neural networks, but Tensorflow does not work well
with Python multiprocessing.
I have an implementation of [Tensorflow A3C built on threading](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/10_A3C).
I even tried [distributed Tensorflow](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/10_A3C/A3C_distributed_tf.py),
but the distributed version is meant for cluster computing, which I don't have,
and on a single machine it is slower than the threading version I wrote.

Fortunately, Pytorch offers good [multiprocessing compatibility](http://pytorch.org/docs/master/notes/multiprocessing.html).
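The core pattern is a global network whose parameters live in shared memory, plus several worker processes that sync with it. Here is a minimal sketch of that pattern (illustration only; the `worker` function and the tiny linear net are placeholders, not the code in this repo):

```python
import torch
import torch.multiprocessing as mp


def worker(gnet):
    # each process keeps its own local copy and periodically syncs with the shared net
    lnet = torch.nn.Linear(4, 2)
    lnet.load_state_dict(gnet.state_dict())
    # ... interact with the environment, compute gradients, push them to gnet ...


if __name__ == "__main__":
    gnet = torch.nn.Linear(4, 2)
    gnet.share_memory()  # put the parameters into shared memory
    workers = [mp.Process(target=worker, args=(gnet,)) for _ in range(4)]
    [w.start() for w in workers]
    [w.join() for w in workers]
```

The two training scripts below follow this structure, and `SharedAdam` additionally keeps the optimizer state in shared memory.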
I went through many Pytorch A3C examples ([here](https://github.com/ikostrikov/pytorch-a3c), [here](https://github.com/jingweiz/pytorch-rl)
and [here](https://github.com/ShangtongZhang/DeepRL)). They are great, but too complicated to dig into quickly.
That is my motivation for writing this simple example code.

BTW, if you are interested in learning Pytorch, [here](https://github.com/MorvanZhou/PyTorch-Tutorial)
is my simple tutorial code with many visualizations. I also made a Tensorflow version of the same tutorial, available [here](https://github.com/MorvanZhou/Tensorflow-Tutorial).

## Codes & Results

* [shared_adam.py](/shared_adam.py): Adam optimizer whose state is kept in shared memory for parallel training
* [utils.py](/utils.py): helper functions used by both training scripts
* [discrete_A3C.py](/discrete_A3C.py): CartPole, neural net and training for a discrete action space
* [continuous_A3C.py](/continuous_A3C.py): Pendulum, neural net and training for a continuous action space

CartPole result
![cartpole](/results/cartpole.png)

Pendulum result
![pendulum](/results/pendulum.png)

## Dependencies

* pytorch >= 0.4.0
* numpy
* gym
* matplotlib
--------------------------------------------------------------------------------
/continuous_A3C.py:
--------------------------------------------------------------------------------
"""
Reinforcement Learning (A3C) using Pytorch + multiprocessing.
A minimal implementation for a continuous action space.

View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""

import torch
import torch.nn as nn
from utils import v_wrap, set_init, push_and_pull, record
import torch.nn.functional as F
import torch.multiprocessing as mp
from shared_adam import SharedAdam
import gym
import math, os
os.environ["OMP_NUM_THREADS"] = "1"

UPDATE_GLOBAL_ITER = 5
GAMMA = 0.9
MAX_EP = 3000
MAX_EP_STEP = 200

env = gym.make('Pendulum-v0')
N_S = env.observation_space.shape[0]
N_A = env.action_space.shape[0]


class Net(nn.Module):
    def __init__(self, s_dim, a_dim):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.a1 = nn.Linear(s_dim, 200)
        self.mu = nn.Linear(200, a_dim)
        self.sigma = nn.Linear(200, a_dim)
        self.c1 = nn.Linear(s_dim, 100)
        self.v = nn.Linear(100, 1)
        set_init([self.a1, self.mu, self.sigma, self.c1, self.v])
        self.distribution = torch.distributions.Normal

    def forward(self, x):
        a1 = F.relu6(self.a1(x))
        mu = 2 * torch.tanh(self.mu(a1))            # Pendulum actions lie in [-2, 2]
        sigma = F.softplus(self.sigma(a1)) + 0.001  # avoid sigma = 0
        c1 = F.relu6(self.c1(x))
        values = self.v(c1)
        return mu, sigma, values

    def choose_action(self, s):
        self.eval()
        mu, sigma, _ = self.forward(s)
        m = self.distribution(mu.view(1, ).data, sigma.view(1, ).data)
        return m.sample().numpy()

    def loss_func(self, s, a, v_t):
        self.train()
        mu, sigma, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)

        m = self.distribution(mu, sigma)
        log_prob = m.log_prob(a)
        entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(m.scale)  # exploration
        exp_v = log_prob * td.detach() + 0.005 * entropy
        a_loss = -exp_v
        total_loss = (a_loss + c_loss).mean()
        return total_loss

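# Note on the loss above: td = v_t - V(s) is the advantage estimate, c_loss is the
# squared critic error, and the actor maximizes log_prob * advantage plus a small
# entropy bonus. For a Normal(mu, sigma) policy the differential entropy is
#     0.5 * log(2 * pi * e * sigma^2) = 0.5 + 0.5 * log(2 * pi) + log(sigma),
# which is exactly the `entropy` term that loss_func weights by 0.005 to encourage
# exploration.
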
class Worker(mp.Process):
    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
        super(Worker, self).__init__()
        self.name = 'w%i' % name
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.lnet = Net(N_S, N_A)           # local network
        self.env = gym.make('Pendulum-v0').unwrapped

    def run(self):
        total_step = 1
        while self.g_ep.value < MAX_EP:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            for t in range(MAX_EP_STEP):
                if self.name == 'w0':
                    self.env.render()
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a.clip(-2, 2))
                if t == MAX_EP_STEP - 1:
                    done = True
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append((r+8.1)/8.1)    # normalize reward to roughly [-1, 1]

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    # sync
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                    buffer_s, buffer_a, buffer_r = [], [], []

                    if done:  # done and print information
                        record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name)
                        break
                s = s_
                total_step += 1

        self.res_queue.put(None)


if __name__ == "__main__":
    gnet = Net(N_S, N_A)        # global network
    gnet.share_memory()         # share the global parameters in multiprocessing
    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.95, 0.999))      # global optimizer
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

    # parallel training
    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
    [w.start() for w in workers]
    res = []                    # record episode reward to plot
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
        else:
            break
    [w.join() for w in workers]

    import matplotlib.pyplot as plt
    plt.plot(res)
    plt.ylabel('Moving average ep reward')
    plt.xlabel('Step')
    plt.show()
--------------------------------------------------------------------------------
/discrete_A3C.py:
--------------------------------------------------------------------------------
"""
Reinforcement Learning (A3C) using Pytorch + multiprocessing.
A minimal implementation for a discrete action space.

View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""

import torch
import torch.nn as nn
from utils import v_wrap, set_init, push_and_pull, record
import torch.nn.functional as F
import torch.multiprocessing as mp
from shared_adam import SharedAdam
import gym
import os
os.environ["OMP_NUM_THREADS"] = "1"

UPDATE_GLOBAL_ITER = 5
GAMMA = 0.9
MAX_EP = 3000

env = gym.make('CartPole-v0')
N_S = env.observation_space.shape[0]
N_A = env.action_space.n


class Net(nn.Module):
    def __init__(self, s_dim, a_dim):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.pi1 = nn.Linear(s_dim, 128)
        self.pi2 = nn.Linear(128, a_dim)
        self.v1 = nn.Linear(s_dim, 128)
        self.v2 = nn.Linear(128, 1)
        set_init([self.pi1, self.pi2, self.v1, self.v2])
        self.distribution = torch.distributions.Categorical

    def forward(self, x):
        pi1 = torch.tanh(self.pi1(x))
        logits = self.pi2(pi1)
        v1 = torch.tanh(self.v1(x))
        values = self.v2(v1)
        return logits, values

    def choose_action(self, s):
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)

        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss


class Worker(mp.Process):
    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
        super(Worker, self).__init__()
        self.name = 'w%02i' % name
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.lnet = Net(N_S, N_A)           # local network
        self.env = gym.make('CartPole-v0').unwrapped

    def run(self):
        total_step = 1
        while self.g_ep.value < MAX_EP:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
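            # roll out one episode: act with the local net, and every UPDATE_GLOBAL_ITER
            # steps (or when the episode ends) push gradients to the global net and pull
            # its latest weights back into the local net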
            while True:
                if self.name == 'w00':
                    self.env.render()
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                if done: r = -1
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    # sync
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                    buffer_s, buffer_a, buffer_r = [], [], []

                    if done:  # done and print information
                        record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name)
                        break
                s = s_
                total_step += 1
        self.res_queue.put(None)


if __name__ == "__main__":
    gnet = Net(N_S, N_A)        # global network
    gnet.share_memory()         # share the global parameters in multiprocessing
    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

    # parallel training
    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
    [w.start() for w in workers]
    res = []                    # record episode reward to plot
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
        else:
            break
    [w.join() for w in workers]

    import matplotlib.pyplot as plt
    plt.plot(res)
    plt.ylabel('Moving average ep reward')
    plt.xlabel('Step')
    plt.show()
--------------------------------------------------------------------------------
/results/cartpole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/pytorch-A3C/5ab27abee2c3ac3ca921ac393bfcbda4e0a91745/results/cartpole.png
--------------------------------------------------------------------------------
/results/pendulum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/pytorch-A3C/5ab27abee2c3ac3ca921ac393bfcbda4e0a91745/results/pendulum.png
--------------------------------------------------------------------------------
/shared_adam.py:
--------------------------------------------------------------------------------
"""
Shared optimizer: its state (the Adam moment estimates) is kept in shared memory so that
all worker processes update the same statistics.
"""

import torch


class SharedAdam(torch.optim.Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # State initialization: create the Adam buffers up front so they can be shared
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)

                # move the buffers into shared memory so every process sees the same state
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""
Helper functions that are used by both training scripts.
"""

from torch import nn
import torch
import numpy as np


def v_wrap(np_array, dtype=np.float32):
    if np_array.dtype != dtype:
        np_array = np_array.astype(dtype)
    return torch.from_numpy(np_array)


def set_init(layers):
    for layer in layers:
        nn.init.normal_(layer.weight, mean=0., std=0.1)
        nn.init.constant_(layer.bias, 0.)


def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
    if done:
        v_s_ = 0.               # terminal state has zero value
    else:
        v_s_ = lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]

    buffer_v_target = []
    for r in br[::-1]:    # reverse buffer r
        v_s_ = r + gamma * v_s_
        buffer_v_target.append(v_s_)
    buffer_v_target.reverse()

    loss = lnet.loss_func(
        v_wrap(np.vstack(bs)),
        v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
        v_wrap(np.array(buffer_v_target)[:, None]))

    # calculate local gradients and push them to the global network
    opt.zero_grad()
    loss.backward()
    for lp, gp in zip(lnet.parameters(), gnet.parameters()):
        gp._grad = lp.grad
    opt.step()

    # pull global parameters
    lnet.load_state_dict(gnet.state_dict())


def record(global_ep, global_ep_r, ep_r, res_queue, name):
    with global_ep.get_lock():
        global_ep.value += 1
    with global_ep_r.get_lock():
        if global_ep_r.value == 0.:
            global_ep_r.value = ep_r
        else:
            global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
    res_queue.put(global_ep_r.value)
    print(
        name,
        "Ep:", global_ep.value,
        "| Ep_r: %.0f" % global_ep_r.value,
    )
--------------------------------------------------------------------------------
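A standalone sketch (illustration only, not one of the repo's files) of the reversed discounting used in push_and_pull: with gamma = 0.9, buffered rewards [1., 1., 1.] and a bootstrap value v_s_ = 0.5 for the state reached after the last step, the n-step value targets are accumulated back-to-front and then reversed.

gamma, v_s_ = 0.9, 0.5
br = [1., 1., 1.]
buffer_v_target = []
for r in br[::-1]:          # newest transition first
    v_s_ = r + gamma * v_s_
    buffer_v_target.append(v_s_)
buffer_v_target.reverse()   # back to chronological order
print(buffer_v_target)      # approximately [3.0745, 2.305, 1.45]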