├── .gitignore
├── LICENSE.md
├── README.md
├── envs.py
├── images
│   └── PongReward.png
├── main.py
├── model.py
├── my_optim.py
├── test.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*.pyc
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Ilya Kostrikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-a3c

This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) from ["Asynchronous Methods for Deep Reinforcement Learning"](https://arxiv.org/pdf/1602.01783v1.pdf).

This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).
In contrast to the starter agent, it uses an optimizer with shared statistics as in the original paper.

Please use this bibtex if you want to cite this repository in your publications:

    @misc{pytorchaaac,
      author = {Kostrikov, Ilya},
      title = {PyTorch Implementations of Asynchronous Advantage Actor Critic},
      year = {2018},
      publisher = {GitHub},
      journal = {GitHub repository},
      howpublished = {\url{https://github.com/ikostrikov/pytorch-a3c}},
    }

## A2C

I **highly recommend** checking out the synchronous version and other algorithms: [pytorch-a2c-ppo-acktr](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr).

In my experience, A2C works better than A3C, and ACKTR works better than both of them. Moreover, PPO is a great algorithm for continuous control. Thus, I recommend trying A2C/PPO/ACKTR first and using A3C only if you specifically need it for some reason.

Also see the [OpenAI blog](https://blog.openai.com/baselines-acktr-a2c/) for more information.

## Contributions

Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request.

## Usage

```bash
# Works only with Python 3.
python3 main.py --env-name "PongDeterministic-v4" --num-processes 16
```

This command also runs evaluation in a separate process, in addition to the 16 training processes.
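
For a quick sanity check outside of training, you can run a single forward pass by hand. The snippet below is an illustrative sketch rather than part of the training pipeline; it only uses the shapes and interfaces defined in `envs.py` and `model.py`, and picks the greedy action the same way `test.py` does:

```python
import torch

from envs import create_atari_env
from model import ActorCritic

env = create_atari_env("PongDeterministic-v4")
model = ActorCritic(env.observation_space.shape[0], env.action_space)

state = torch.from_numpy(env.reset())  # processed frame, shape (1, 42, 42)
hx = torch.zeros(1, 256)               # LSTM hidden state
cx = torch.zeros(1, 256)               # LSTM cell state

with torch.no_grad():
    value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))

action = logit.argmax(dim=-1).item()   # greedy action, as in test.py
print(value.item(), action)
```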

## Results

With 16 processes, PongDeterministic-v4 converges in about 15 minutes.

![PongDeterministic-v4](images/PongReward.png)

For BreakoutDeterministic-v4 it takes several hours or more.
--------------------------------------------------------------------------------
/envs.py:
--------------------------------------------------------------------------------
import cv2
import gym
import numpy as np
from gym.spaces.box import Box


# Taken from https://github.com/openai/universe-starter-agent
def create_atari_env(env_id):
    env = gym.make(env_id)
    env = AtariRescale42x42(env)
    env = NormalizedEnv(env)
    return env


def _process_frame42(frame):
    frame = frame[34:34 + 160, :160]
    # Resize by half, then down to 42x42 (essentially mipmapping). If
    # we resize directly we lose pixels that, when mapped to 42x42,
    # aren't close enough to the pixel boundary.
    frame = cv2.resize(frame, (80, 80))
    frame = cv2.resize(frame, (42, 42))
    frame = frame.mean(2, keepdims=True)
    frame = frame.astype(np.float32)
    frame *= (1.0 / 255.0)
    frame = np.moveaxis(frame, -1, 0)
    return frame


class AtariRescale42x42(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(AtariRescale42x42, self).__init__(env)
        self.observation_space = Box(0.0, 1.0, [1, 42, 42])

    def _observation(self, observation):
        return _process_frame42(observation)


class NormalizedEnv(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(NormalizedEnv, self).__init__(env)
        self.state_mean = 0
        self.state_std = 0
        self.alpha = 0.9999
        self.num_steps = 0

    def _observation(self, observation):
        # Exponentially weighted running estimates of the observation mean
        # and std, with the same bias correction as in Adam.
        self.num_steps += 1
        self.state_mean = self.state_mean * self.alpha + \
            observation.mean() * (1 - self.alpha)
        self.state_std = self.state_std * self.alpha + \
            observation.std() * (1 - self.alpha)

        unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
        unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))

        return (observation - unbiased_mean) / (unbiased_std + 1e-8)
--------------------------------------------------------------------------------
/images/PongReward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ikostrikov/pytorch-a3c/48d95844755e2c3e2c7e48bbd1a7141f7212b63f/images/PongReward.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import argparse
import os

import torch
import torch.multiprocessing as mp

import my_optim
from envs import create_atari_env
from model import ActorCritic
from test import test
from train import train

# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Training settings
parser = argparse.ArgumentParser(description='A3C')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate (default: 0.0001)')
parser.add_argument('--gamma', type=float, default=0.99,
                    help='discount factor for rewards (default: 0.99)')
parser.add_argument('--gae-lambda', type=float, default=1.00,
                    help='lambda parameter for GAE (default: 1.00)')
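# Note: with --gae-lambda left at 1.0, the GAE estimator in train.py reduces
# to the plain bootstrapped advantage R - V(s).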
parser.add_argument('--entropy-coef', type=float, default=0.01,
                    help='entropy term coefficient (default: 0.01)')
parser.add_argument('--value-loss-coef', type=float, default=0.5,
                    help='value loss coefficient (default: 0.5)')
parser.add_argument('--max-grad-norm', type=float, default=50,
                    help='max norm for gradient clipping (default: 50)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--num-processes', type=int, default=4,
                    help='how many training processes to use (default: 4)')
parser.add_argument('--num-steps', type=int, default=20,
                    help='number of forward steps in A3C (default: 20)')
parser.add_argument('--max-episode-length', type=int, default=1000000,
                    help='maximum length of an episode (default: 1000000)')
parser.add_argument('--env-name', default='PongDeterministic-v4',
                    help='environment to train on (default: PongDeterministic-v4)')
parser.add_argument('--no-shared', action='store_true', default=False,
                    help='use an optimizer without shared statistics.')


if __name__ == '__main__':
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['CUDA_VISIBLE_DEVICES'] = ""

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    env = create_atari_env(args.env_name)
    shared_model = ActorCritic(
        env.observation_space.shape[0], env.action_space)
    shared_model.share_memory()

    if args.no_shared:
        optimizer = None
    else:
        optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
        optimizer.share_memory()

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    # The evaluation process gets rank == num_processes, so its seed differs
    # from all training workers.
    p = mp.Process(target=test, args=(args.num_processes, args, shared_model, counter))
    p.start()
    processes.append(p)

    for rank in range(0, args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalized_columns_initializer(weights, std=1.0):
    out = torch.randn(weights.size())
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True))
    return out


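# weights_init (below) uses a Xavier/Glorot-style uniform range,
# sqrt(6 / (fan_in + fan_out)), for both conv and linear layers.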
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)


class ActorCritic(torch.nn.Module):
    def __init__(self, num_inputs, action_space):
        super(ActorCritic, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        # Four stride-2 convs reduce the 42x42 input to 3x3 feature maps.
        self.lstm = nn.LSTMCell(32 * 3 * 3, 256)

        num_outputs = action_space.n
        self.critic_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, num_outputs)

        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

        self.train()

    def forward(self, inputs):
        inputs, (hx, cx) = inputs
        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))

        x = x.view(-1, 32 * 3 * 3)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        return self.critic_linear(x), self.actor_linear(x), (hx, cx)
--------------------------------------------------------------------------------
/my_optim.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.optim as optim


class SharedAdam(optim.Adam):
    """Implements Adam algorithm with shared states.
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
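        # The step/exp_avg/exp_avg_sq tensors live in shared memory (see
        # share_memory()), so all worker processes update the same Adam
        # statistics, Hogwild-style, without any locking.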
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step'].item()
                bias_correction2 = 1 - beta2 ** state['step'].item()
                step_size = group['lr'] * math.sqrt(
                    bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import time
from collections import deque

import torch
import torch.nn.functional as F

from envs import create_atari_env
from model import ActorCritic


def test(rank, args, shared_model, counter):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=100)
    episode_length = 0
    while True:
        episode_length += 1
        # Sync with the shared model
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = torch.zeros(1, 256)
            hx = torch.zeros(1, 256)
        else:
            cx = cx.detach()
            hx = hx.detach()

        with torch.no_grad():
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
        prob = F.softmax(logit, dim=-1)
        action = prob.max(1, keepdim=True)[1].numpy()

        state, reward, done, _ = env.step(action[0, 0])
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # a quick hack to prevent the agent from getting stuck:
        # end the episode if the same action was chosen 100 times in a row
        actions.append(action[0, 0])
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                counter.value, counter.value / (time.time() - start_time),
                reward_sum, episode_length))
            reward_sum = 0
            episode_length = 0
            actions.clear()
            state = env.reset()
            time.sleep(60)

        state = torch.from_numpy(state)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.optim as optim

from envs import create_atari_env
from model import ActorCritic


def ensure_shared_grads(model, shared_model):
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


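# Advantages in train() are computed with Generalized Advantage Estimation:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   gae_t   = delta_t + gamma * gae_lambda * gae_{t+1}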
def train(rank, args, shared_model, counter, lock, optimizer=None):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    done = True

    episode_length = 0
    while True:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = torch.zeros(1, 256)
            hx = torch.zeros(1, 256)
        else:
            cx = cx.detach()
            hx = hx.detach()

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for step in range(args.num_steps):
            episode_length += 1
            value, logit, (hx, cx) = model((state.unsqueeze(0),
                                            (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(1, keepdim=True)
            entropies.append(entropy)

            action = prob.multinomial(num_samples=1).detach()
            log_prob = log_prob.gather(1, action)

            state, reward, done, _ = env.step(action.numpy())
            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 1), -1)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
            R = value.detach()

        values.append(R)
        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1] - values[i]
            gae = gae * args.gamma * args.gae_lambda + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]

        optimizer.zero_grad()

        (policy_loss + args.value_loss_coef * value_loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)
        optimizer.step()
--------------------------------------------------------------------------------