├── LICENSE
├── README.md
├── a3c.py
├── a3c_ale.py
├── ale.py
├── async.py
├── copy_param.py
├── demo_a3c_ale.py
├── demo_a3c_doom.py
├── doom_env.py
├── dqn_head.py
├── dqn_phi.py
├── environment.py
├── init_like_torch.py
├── nonbias_weight_decay.py
├── plot_scores.py
├── policy.py
├── policy_output.py
├── prepare_output_dir.py
├── random_seed.py
├── rmsprop_async.py
├── run_a3c.py
├── train_a3c_doom.py
├── trained_model
│   ├── breakout_ff
│   │   ├── 80000000_finish.h5
│   │   ├── animation.gif
│   │   ├── scores.txt
│   │   └── scores.txt.png
│   └── space_invaders_lstm
│       ├── 80000000_finish.h5
│       ├── animation.gif
│       ├── scores.txt
│       └── scores.txt.png
└── v_function.py

/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2016 Yasuhiro Fujita
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Async-RL
2 | 
3 | *(2017/02/25) The A3C implementation in this repository has now been ported into [ChainerRL](https://github.com/pfnet/chainerrl), a Chainer-based deep reinforcement learning library, with some enhancements such as support for continuous actions via Gaussian policies and n-step Q-learning, so I recommend using ChainerRL instead of this repository.*
4 | 
5 | ![A3C FF playing Breakout](https://raw.githubusercontent.com/muupan/async-rl/master/trained_model/breakout_ff/animation.gif)
6 | ![A3C LSTM playing Space Invaders](https://raw.githubusercontent.com/muupan/async-rl/master/trained_model/space_invaders_lstm/animation.gif)
7 | 
8 | This is a repository where I attempt to reproduce the results of [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). Currently I have only replicated A3C FF/LSTM for Atari.
9 | 
10 | Any feedback is welcome :)
11 | 
12 | ## Supported Features
13 | 
14 | - A3C FF/LSTM (discrete action spaces only)
15 | - Atari environment
16 | - ViZDoom environment (experimental)
17 | 
18 | ## Current Status
19 | 
20 | ### A3C FF
21 | 
22 | I trained A3C FF on ALE's Breakout with 36 processes (AWS EC2 c4.8xlarge) for 80 million training steps, which took about 17 hours. The mean and median scores of test runs during training are plotted below; ten test runs were performed every 1 million training steps (counted by the global shared counter). The results seem slightly worse than those reported in the paper.
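
These plots can be regenerated from the bundled score logs with `plot_scores.py` (which additionally requires `pandas` and `matplotlib`); for example:

```
python plot_scores.py trained_model/breakout_ff/scores.txt --title Breakout
```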
23 | 
24 | 
25 | 
26 | The trained model is uploaded at `trained_model/breakout_ff/80000000_finish.h5`, so you can watch it play Breakout with the following command (replace `<path-to-breakout-rom>` with the path to your Breakout ROM file):
27 | 
28 | ```
29 | python demo_a3c_ale.py <path-to-breakout-rom> trained_model/breakout_ff/80000000_finish.h5
30 | ```
31 | 
32 | The animation gif above is an episode I cherry-picked from 10 demo runs using that model.
33 | 
34 | ### A3C LSTM
35 | 
36 | I also trained A3C LSTM for ALE's Space Invaders in the same manner as A3C FF. Training A3C LSTM took about 24 hours for 80 million training steps.
37 | 
38 | 
39 | 
40 | The trained model is uploaded at `trained_model/space_invaders_lstm/80000000_finish.h5`, so you can watch it play Space Invaders with the following command (replace `<path-to-space-invaders-rom>` with the path to your Space Invaders ROM file):
41 | 
42 | ```
43 | python demo_a3c_ale.py <path-to-space-invaders-rom> trained_model/space_invaders_lstm/80000000_finish.h5 --use-lstm
44 | ```
45 | 
46 | The animation gif above is an episode I cherry-picked from 10 demo runs using that model.
47 | 
48 | ### Implementation details
49 | 
50 | Dr. Mnih kindly confirmed their implementation details and some hyperparameters by e-mail. I have summarized them in the wiki: https://github.com/muupan/async-rl/wiki
51 | 
52 | ## Requirements
53 | 
54 | - Python 3.5.1
55 | - chainer 1.8.1
56 | - cached-property 1.3.0
57 | - h5py 2.5.0
58 | - Arcade-Learning-Environment
59 | 
60 | ## Training
61 | 
62 | ```
63 | python a3c_ale.py <n-processes> <path-to-rom> [--use-lstm]
64 | ```
65 | 
66 | `a3c_ale.py` will save the best-so-far models and test scores into the output directory.
67 | 
68 | Unfortunately, this script currently seems to have some bugs. Please see issues [#5](https://github.com/muupan/async-rl/issues/5) and [#6](https://github.com/muupan/async-rl/issues/6); I'm working on a fix.
69 | 
70 | ## Evaluation
71 | 
72 | ```
73 | python demo_a3c_ale.py <path-to-rom> <path-to-model> [--use-lstm]
74 | ```
75 | 
76 | ## Similar Projects
77 | 
78 | - https://github.com/miyosuda/async_deep_reinforce
79 | 
--------------------------------------------------------------------------------
/a3c.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from logging import getLogger
3 | import os
4 | 
5 | import numpy as np
6 | import chainer
7 | from chainer import serializers
8 | from chainer import functions as F
9 | 
10 | import copy_param
11 | 
12 | logger = getLogger(__name__)
13 | 
14 | 
15 | class A3CModel(chainer.Link):
16 | 
17 |     def pi_and_v(self, state, keep_same_state=False):
18 |         raise NotImplementedError()
19 | 
20 |     def reset_state(self):
21 |         pass
22 | 
23 |     def unchain_backward(self):
24 |         pass
25 | 
26 | 
27 | class A3C(object):
28 |     """A3C: Asynchronous Advantage Actor-Critic.
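
    Each worker process holds a local copy of the globally shared model.
    act() accumulates policy and value losses for at most t_max steps (or
    until a terminal state), copies the resulting gradients to the shared
    model, applies the optimizer update, and then re-synchronizes the local
    copy with the shared parameters.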
29 | 30 | See http://arxiv.org/abs/1602.01783 31 | """ 32 | 33 | def __init__(self, model, optimizer, t_max, gamma, beta=1e-2, 34 | process_idx=0, clip_reward=True, phi=lambda x: x, 35 | pi_loss_coef=1.0, v_loss_coef=0.5, 36 | keep_loss_scale_same=False): 37 | 38 | # Globally shared model 39 | self.shared_model = model 40 | 41 | # Thread specific model 42 | self.model = copy.deepcopy(self.shared_model) 43 | 44 | self.optimizer = optimizer 45 | self.t_max = t_max 46 | self.gamma = gamma 47 | self.beta = beta 48 | self.process_idx = process_idx 49 | self.clip_reward = clip_reward 50 | self.phi = phi 51 | self.pi_loss_coef = pi_loss_coef 52 | self.v_loss_coef = v_loss_coef 53 | self.keep_loss_scale_same = keep_loss_scale_same 54 | 55 | self.t = 0 56 | self.t_start = 0 57 | self.past_action_log_prob = {} 58 | self.past_action_entropy = {} 59 | self.past_states = {} 60 | self.past_rewards = {} 61 | self.past_values = {} 62 | 63 | def sync_parameters(self): 64 | copy_param.copy_param(target_link=self.model, 65 | source_link=self.shared_model) 66 | 67 | def act(self, state, reward, is_state_terminal): 68 | 69 | if self.clip_reward: 70 | reward = np.clip(reward, -1, 1) 71 | 72 | if not is_state_terminal: 73 | statevar = chainer.Variable(np.expand_dims(self.phi(state), 0)) 74 | 75 | self.past_rewards[self.t - 1] = reward 76 | 77 | if (is_state_terminal and self.t_start < self.t) \ 78 | or self.t - self.t_start == self.t_max: 79 | 80 | assert self.t_start < self.t 81 | 82 | if is_state_terminal: 83 | R = 0 84 | else: 85 | _, vout = self.model.pi_and_v(statevar, keep_same_state=True) 86 | R = float(vout.data) 87 | 88 | pi_loss = 0 89 | v_loss = 0 90 | for i in reversed(range(self.t_start, self.t)): 91 | R *= self.gamma 92 | R += self.past_rewards[i] 93 | v = self.past_values[i] 94 | if self.process_idx == 0: 95 | logger.debug('s:%s v:%s R:%s', 96 | self.past_states[i].data.sum(), v.data, R) 97 | advantage = R - v 98 | # Accumulate gradients of policy 99 | log_prob = self.past_action_log_prob[i] 100 | entropy = self.past_action_entropy[i] 101 | 102 | # Log probability is increased proportionally to advantage 103 | pi_loss -= log_prob * float(advantage.data) 104 | # Entropy is maximized 105 | pi_loss -= self.beta * entropy 106 | # Accumulate gradients of value function 107 | 108 | v_loss += (v - R) ** 2 / 2 109 | 110 | if self.pi_loss_coef != 1.0: 111 | pi_loss *= self.pi_loss_coef 112 | 113 | if self.v_loss_coef != 1.0: 114 | v_loss *= self.v_loss_coef 115 | 116 | # Normalize the loss of sequences truncated by terminal states 117 | if self.keep_loss_scale_same and \ 118 | self.t - self.t_start < self.t_max: 119 | factor = self.t_max / (self.t - self.t_start) 120 | pi_loss *= factor 121 | v_loss *= factor 122 | 123 | if self.process_idx == 0: 124 | logger.debug('pi_loss:%s v_loss:%s', pi_loss.data, v_loss.data) 125 | 126 | total_loss = pi_loss + F.reshape(v_loss, pi_loss.data.shape) 127 | 128 | # Compute gradients using thread-specific model 129 | self.model.zerograds() 130 | total_loss.backward() 131 | # Copy the gradients to the globally shared model 132 | self.shared_model.zerograds() 133 | copy_param.copy_grad( 134 | target_link=self.shared_model, source_link=self.model) 135 | # Update the globally shared model 136 | if self.process_idx == 0: 137 | norm = self.optimizer.compute_grads_norm() 138 | logger.debug('grad norm:%s', norm) 139 | self.optimizer.update() 140 | if self.process_idx == 0: 141 | logger.debug('update') 142 | 143 | self.sync_parameters() 144 | self.model.unchain_backward() 145 | 
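            # At this point the shared model has been updated and the local
            # model re-synchronized; clear the stored rollout history before
            # starting the next (at most t_max-step) segment.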
146 |             self.past_action_log_prob = {}
147 |             self.past_action_entropy = {}
148 |             self.past_states = {}
149 |             self.past_rewards = {}
150 |             self.past_values = {}
151 | 
152 |             self.t_start = self.t
153 | 
154 |         if not is_state_terminal:
155 |             self.past_states[self.t] = statevar
156 |             pout, vout = self.model.pi_and_v(statevar)
157 |             self.past_action_log_prob[self.t] = pout.sampled_actions_log_probs
158 |             self.past_action_entropy[self.t] = pout.entropy
159 |             self.past_values[self.t] = vout
160 |             self.t += 1
161 |             if self.process_idx == 0:
162 |                 logger.debug('t:%s entropy:%s, probs:%s',
163 |                              self.t, pout.entropy.data, pout.probs.data)
164 |             return pout.action_indices[0]
165 |         else:
166 |             self.model.reset_state()
167 |             return None
168 | 
169 |     def load_model(self, model_filename):
170 |         """Load a network model from a file
171 |         """
172 |         serializers.load_hdf5(model_filename, self.model)
173 |         # Propagate the loaded parameters to the globally shared model
174 |         copy_param.copy_param(target_link=self.shared_model,
175 |                               source_link=self.model)
176 |         opt_filename = model_filename + '.opt'
177 |         if os.path.exists(opt_filename):
178 |             serializers.load_hdf5(opt_filename, self.optimizer)
179 |         else:
180 |             print('WARNING: {0} was not found, so loaded only a model'.format(
181 |                 opt_filename))
182 | 
183 |     def save_model(self, model_filename):
184 |         """Save a network model to a file
185 |         """
186 |         serializers.save_hdf5(model_filename, self.model)
187 |         serializers.save_hdf5(model_filename + '.opt', self.optimizer)
188 | 
--------------------------------------------------------------------------------
/a3c_ale.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import multiprocessing as mp
4 | import os
5 | import sys
6 | import statistics
7 | import time
8 | 
9 | import chainer
10 | from chainer import links as L
11 | from chainer import functions as F
12 | import numpy as np
13 | 
14 | import policy
15 | import v_function
16 | import dqn_head
17 | import a3c
18 | import ale
19 | import random_seed
20 | import async
21 | import rmsprop_async
22 | from prepare_output_dir import prepare_output_dir
23 | from nonbias_weight_decay import NonbiasWeightDecay
24 | from init_like_torch import init_like_torch
25 | from dqn_phi import dqn_phi
26 | 
27 | 
28 | class A3CFF(chainer.ChainList, a3c.A3CModel):
29 | 
30 |     def __init__(self, n_actions):
31 |         self.head = dqn_head.NIPSDQNHead()
32 |         self.pi = policy.FCSoftmaxPolicy(
33 |             self.head.n_output_channels, n_actions)
34 |         self.v = v_function.FCVFunction(self.head.n_output_channels)
35 |         super().__init__(self.head, self.pi, self.v)
36 |         init_like_torch(self)
37 | 
38 |     def pi_and_v(self, state, keep_same_state=False):
39 |         out = self.head(state)
40 |         return self.pi(out), self.v(out)
41 | 
42 | 
43 | class A3CLSTM(chainer.ChainList, a3c.A3CModel):
44 | 
45 |     def __init__(self, n_actions):
46 |         self.head = dqn_head.NIPSDQNHead()
47 |         self.pi = policy.FCSoftmaxPolicy(
48 |             self.head.n_output_channels, n_actions)
49 |         self.v = v_function.FCVFunction(self.head.n_output_channels)
50 |         self.lstm = L.LSTM(self.head.n_output_channels,
51 |                            self.head.n_output_channels)
52 |         super().__init__(self.head, self.lstm, self.pi, self.v)
53 |         init_like_torch(self)
54 | 
55 |     def pi_and_v(self, state, keep_same_state=False):
56 |         out = self.head(state)
57 |         if keep_same_state:
58 |             prev_h, prev_c = self.lstm.h, self.lstm.c
59 |             out = self.lstm(out)
60 |             self.lstm.h, self.lstm.c = prev_h, prev_c
61 |         else:
62 |             out = self.lstm(out)
63 |         return self.pi(out), self.v(out)
64 | 
65 |     def 
reset_state(self): 66 | self.lstm.reset_state() 67 | 68 | def unchain_backward(self): 69 | self.lstm.h.unchain_backward() 70 | self.lstm.c.unchain_backward() 71 | 72 | 73 | def eval_performance(rom, p_func, n_runs): 74 | assert n_runs > 1, 'Computing stdev requires at least two runs' 75 | scores = [] 76 | for i in range(n_runs): 77 | env = ale.ALE(rom, treat_life_lost_as_terminal=False) 78 | test_r = 0 79 | while not env.is_terminal: 80 | s = chainer.Variable(np.expand_dims(dqn_phi(env.state), 0)) 81 | pout = p_func(s) 82 | a = pout.action_indices[0] 83 | test_r += env.receive_action(a) 84 | scores.append(test_r) 85 | print('test_{}:'.format(i), test_r) 86 | mean = statistics.mean(scores) 87 | median = statistics.median(scores) 88 | stdev = statistics.stdev(scores) 89 | return mean, median, stdev 90 | 91 | 92 | def train_loop(process_idx, counter, max_score, args, agent, env, start_time): 93 | try: 94 | 95 | total_r = 0 96 | episode_r = 0 97 | global_t = 0 98 | local_t = 0 99 | 100 | while True: 101 | 102 | # Get and increment the global counter 103 | with counter.get_lock(): 104 | counter.value += 1 105 | global_t = counter.value 106 | local_t += 1 107 | 108 | if global_t > args.steps: 109 | break 110 | 111 | agent.optimizer.lr = ( 112 | args.steps - global_t - 1) / args.steps * args.lr 113 | 114 | total_r += env.reward 115 | episode_r += env.reward 116 | 117 | action = agent.act(env.state, env.reward, env.is_terminal) 118 | 119 | if env.is_terminal: 120 | if process_idx == 0: 121 | print('{} global_t:{} local_t:{} lr:{} episode_r:{}'.format( 122 | args.outdir, global_t, local_t, agent.optimizer.lr, episode_r)) 123 | episode_r = 0 124 | env.initialize() 125 | else: 126 | env.receive_action(action) 127 | 128 | if global_t % args.eval_frequency == 0: 129 | # Evaluation 130 | 131 | # We must use a copy of the model because test runs can change 132 | # the hidden states of the model 133 | test_model = copy.deepcopy(agent.model) 134 | test_model.reset_state() 135 | 136 | def p_func(s): 137 | pout, _ = test_model.pi_and_v(s) 138 | test_model.unchain_backward() 139 | return pout 140 | mean, median, stdev = eval_performance( 141 | args.rom, p_func, args.eval_n_runs) 142 | with open(os.path.join(args.outdir, 'scores.txt'), 'a+') as f: 143 | elapsed = time.time() - start_time 144 | record = (global_t, elapsed, mean, median, stdev) 145 | print('\t'.join(str(x) for x in record), file=f) 146 | with max_score.get_lock(): 147 | if mean > max_score.value: 148 | # Save the best model so far 149 | print('The best score is updated {} -> {}'.format( 150 | max_score.value, mean)) 151 | filename = os.path.join( 152 | args.outdir, '{}.h5'.format(global_t)) 153 | agent.save_model(filename) 154 | print('Saved the current best model to {}'.format( 155 | filename)) 156 | max_score.value = mean 157 | 158 | except KeyboardInterrupt: 159 | if process_idx == 0: 160 | # Save the current model before being killed 161 | agent.save_model(os.path.join( 162 | args.outdir, '{}_keyboardinterrupt.h5'.format(global_t))) 163 | print('Saved the current model to {}'.format( 164 | args.outdir), file=sys.stderr) 165 | raise 166 | 167 | if global_t == args.steps + 1: 168 | # Save the final model 169 | agent.save_model( 170 | os.path.join(args.outdir, '{}_finish.h5'.format(args.steps))) 171 | print('Saved the final model to {}'.format(args.outdir)) 172 | 173 | 174 | def train_loop_with_profile(process_idx, counter, max_score, args, agent, env, 175 | start_time): 176 | import cProfile 177 | cmd = 'train_loop(process_idx, counter, 
max_score, args, agent, env, ' \ 178 | 'start_time)' 179 | cProfile.runctx(cmd, globals(), locals(), 180 | 'profile-{}.out'.format(os.getpid())) 181 | 182 | 183 | def main(): 184 | 185 | # Prevent numpy from using multiple threads 186 | os.environ['OMP_NUM_THREADS'] = '1' 187 | 188 | import logging 189 | logging.basicConfig(level=logging.DEBUG) 190 | 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('processes', type=int) 193 | parser.add_argument('rom', type=str) 194 | parser.add_argument('--seed', type=int, default=None) 195 | parser.add_argument('--outdir', type=str, default=None) 196 | parser.add_argument('--use-sdl', action='store_true') 197 | parser.add_argument('--t-max', type=int, default=5) 198 | parser.add_argument('--beta', type=float, default=1e-2) 199 | parser.add_argument('--profile', action='store_true') 200 | parser.add_argument('--steps', type=int, default=8 * 10 ** 7) 201 | parser.add_argument('--lr', type=float, default=7e-4) 202 | parser.add_argument('--eval-frequency', type=int, default=10 ** 6) 203 | parser.add_argument('--eval-n-runs', type=int, default=10) 204 | parser.add_argument('--weight-decay', type=float, default=0.0) 205 | parser.add_argument('--use-lstm', action='store_true') 206 | parser.set_defaults(use_sdl=False) 207 | parser.set_defaults(use_lstm=False) 208 | args = parser.parse_args() 209 | 210 | if args.seed is not None: 211 | random_seed.set_random_seed(args.seed) 212 | 213 | args.outdir = prepare_output_dir(args, args.outdir) 214 | 215 | print('Output files are saved in {}'.format(args.outdir)) 216 | 217 | n_actions = ale.ALE(args.rom).number_of_actions 218 | 219 | def model_opt(): 220 | if args.use_lstm: 221 | model = A3CLSTM(n_actions) 222 | else: 223 | model = A3CFF(n_actions) 224 | opt = rmsprop_async.RMSpropAsync(lr=7e-4, eps=1e-1, alpha=0.99) 225 | opt.setup(model) 226 | opt.add_hook(chainer.optimizer.GradientClipping(40)) 227 | if args.weight_decay > 0: 228 | opt.add_hook(NonbiasWeightDecay(args.weight_decay)) 229 | return model, opt 230 | 231 | model, opt = model_opt() 232 | 233 | shared_params = async.share_params_as_shared_arrays(model) 234 | shared_states = async.share_states_as_shared_arrays(opt) 235 | 236 | max_score = mp.Value('f', np.finfo(np.float32).min) 237 | counter = mp.Value('l', 0) 238 | start_time = time.time() 239 | 240 | # Write a header line first 241 | with open(os.path.join(args.outdir, 'scores.txt'), 'a+') as f: 242 | column_names = ('steps', 'elapsed', 'mean', 'median', 'stdev') 243 | print('\t'.join(column_names), file=f) 244 | 245 | def run_func(process_idx): 246 | env = ale.ALE(args.rom, use_sdl=args.use_sdl) 247 | model, opt = model_opt() 248 | async.set_shared_params(model, shared_params) 249 | async.set_shared_states(opt, shared_states) 250 | 251 | agent = a3c.A3C(model, opt, args.t_max, 0.99, beta=args.beta, 252 | process_idx=process_idx, phi=dqn_phi) 253 | 254 | if args.profile: 255 | train_loop_with_profile(process_idx, counter, max_score, 256 | args, agent, env, start_time) 257 | else: 258 | train_loop(process_idx, counter, max_score, 259 | args, agent, env, start_time) 260 | 261 | async.run_async(args.processes, run_func) 262 | 263 | 264 | if __name__ == '__main__': 265 | main() 266 | -------------------------------------------------------------------------------- /ale.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | from ale_python_interface import ALEInterface 7 | import cv2 
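# OpenCV (cv2) is used in current_screen() below to resize ALE frames to 84x84.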
8 | 
9 | import environment
10 | 
11 | 
12 | class ALE(environment.EpisodicEnvironment):
13 |     """Arcade Learning Environment.
14 |     """
15 | 
16 |     def __init__(self, rom_filename, seed=None, use_sdl=False, n_last_screens=4,
17 |                  frame_skip=4, treat_life_lost_as_terminal=True,
18 |                  crop_or_scale='scale', max_start_nullops=30,
19 |                  record_screen_dir=None):
20 |         self.n_last_screens = n_last_screens
21 |         self.treat_life_lost_as_terminal = treat_life_lost_as_terminal
22 |         self.crop_or_scale = crop_or_scale
23 |         self.max_start_nullops = max_start_nullops
24 | 
25 |         ale = ALEInterface()
26 |         if seed is not None:
27 |             assert seed >= 0 and seed < 2 ** 16, \
28 |                 "ALE's random seed must be represented by unsigned int"
29 |         else:
30 |             # Use numpy's random state
31 |             seed = np.random.randint(0, 2 ** 16)
32 |         ale.setInt(b'random_seed', seed)
33 |         ale.setFloat(b'repeat_action_probability', 0.0)
34 |         ale.setBool(b'color_averaging', False)
35 |         if record_screen_dir is not None:
36 |             ale.setString(b'record_screen_dir', str.encode(record_screen_dir))
37 |         self.frame_skip = frame_skip
38 |         if use_sdl:
39 |             if 'DISPLAY' not in os.environ:
40 |                 raise RuntimeError(
41 |                     'Please set DISPLAY environment variable for use_sdl=True')
42 |             # SDL settings below are from the ALE python example
43 |             if sys.platform == 'darwin':
44 |                 import pygame
45 |                 pygame.init()
46 |                 ale.setBool(b'sound', False)  # Sound doesn't work on OSX
47 |             elif sys.platform.startswith('linux'):
48 |                 ale.setBool(b'sound', True)
49 |             ale.setBool(b'display_screen', True)
50 |         ale.loadROM(str.encode(rom_filename))
51 | 
52 |         assert ale.getFrameNumber() == 0
53 | 
54 | 
55 |         self.ale = ale
56 |         self.legal_actions = ale.getMinimalActionSet()
57 |         self.initialize()
58 | 
59 |     def current_screen(self):
60 |         # Max of two consecutive frames
61 |         assert self.last_raw_screen is not None
62 |         rgb_img = np.maximum(self.ale.getScreenRGB(), self.last_raw_screen)
63 |         # Make sure the last raw screen is used only once
64 |         self.last_raw_screen = None
65 |         assert rgb_img.shape == (210, 160, 3)
66 |         # RGB -> Luminance (Rec. 709 coefficients)
67 |         img = rgb_img[:, :, 0] * 0.2126 + rgb_img[:, :, 1] * \
68 |             0.7152 + rgb_img[:, :, 2] * 0.0722
69 |         img = img.astype(np.uint8)
70 |         if img.shape == (250, 160):
71 |             raise RuntimeError("This ROM is for PAL. 
Please use ROMs for NTSC") 72 | assert img.shape == (210, 160) 73 | if self.crop_or_scale == 'crop': 74 | # Shrink (210, 160) -> (110, 84) 75 | img = cv2.resize(img, (84, 110), 76 | interpolation=cv2.INTER_LINEAR) 77 | assert img.shape == (110, 84) 78 | # Crop (110, 84) -> (84, 84) 79 | unused_height = 110 - 84 80 | bottom_crop = 8 81 | top_crop = unused_height - bottom_crop 82 | img = img[top_crop: 110 - bottom_crop, :] 83 | elif self.crop_or_scale == 'scale': 84 | img = cv2.resize(img, (84, 84), 85 | interpolation=cv2.INTER_LINEAR) 86 | else: 87 | raise RuntimeError('crop_or_scale must be either crop or scale') 88 | assert img.shape == (84, 84) 89 | return img 90 | 91 | @property 92 | def state(self): 93 | assert len(self.last_screens) == 4 94 | return list(self.last_screens) 95 | 96 | @property 97 | def is_terminal(self): 98 | if self.treat_life_lost_as_terminal: 99 | return self.lives_lost or self.ale.game_over() 100 | else: 101 | return self.ale.game_over() 102 | 103 | @property 104 | def reward(self): 105 | return self._reward 106 | 107 | @property 108 | def number_of_actions(self): 109 | return len(self.legal_actions) 110 | 111 | def receive_action(self, action): 112 | assert not self.is_terminal 113 | 114 | rewards = [] 115 | for i in range(4): 116 | 117 | # Last screeen must be stored before executing the 4th action 118 | if i == 3: 119 | self.last_raw_screen = self.ale.getScreenRGB() 120 | 121 | rewards.append(self.ale.act(self.legal_actions[action])) 122 | 123 | # Check if lives are lost 124 | if self.lives > self.ale.lives(): 125 | self.lives_lost = True 126 | else: 127 | self.lives_lost = False 128 | self.lives = self.ale.lives() 129 | 130 | if self.is_terminal: 131 | break 132 | 133 | # We must have last screen here unless it's terminal 134 | if not self.is_terminal: 135 | self.last_screens.append(self.current_screen()) 136 | 137 | self._reward = sum(rewards) 138 | 139 | return self._reward 140 | 141 | def initialize(self): 142 | 143 | if self.ale.game_over(): 144 | self.ale.reset_game() 145 | 146 | if self.max_start_nullops > 0: 147 | n_nullops = np.random.randint(0, self.max_start_nullops + 1) 148 | for _ in range(n_nullops): 149 | self.ale.act(0) 150 | 151 | self._reward = 0 152 | 153 | self.last_raw_screen = self.ale.getScreenRGB() 154 | 155 | self.last_screens = collections.deque( 156 | [np.zeros((84, 84), dtype=np.uint8)] * 3 + 157 | [self.current_screen()], 158 | maxlen=self.n_last_screens) 159 | 160 | self.lives_lost = False 161 | self.lives = self.ale.lives() 162 | -------------------------------------------------------------------------------- /async.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | import random 4 | 5 | import chainer 6 | import numpy as np 7 | 8 | import random_seed 9 | 10 | 11 | def set_shared_params(a, b): 12 | """ 13 | Args: 14 | a (chainer.Link): link whose params are to be replaced 15 | b (dict): dict that consists of (param_name, multiprocessing.Array) 16 | """ 17 | assert isinstance(a, chainer.Link) 18 | for param_name, param in a.namedparams(): 19 | if param_name in b: 20 | shared_param = b[param_name] 21 | param.data = np.frombuffer( 22 | shared_param, dtype=param.data.dtype).reshape(param.data.shape) 23 | 24 | 25 | def set_shared_states(a, b): 26 | assert isinstance(a, chainer.Optimizer) 27 | assert hasattr(a, 'target'), 'Optimizer.setup must be called first' 28 | for state_name, shared_state in b.items(): 29 | for param_name, param in 
shared_state.items(): 30 | old_param = a._states[state_name][param_name] 31 | a._states[state_name][param_name] = np.frombuffer( 32 | param, 33 | dtype=old_param.dtype).reshape(old_param.shape) 34 | 35 | 36 | def extract_params_as_shared_arrays(link): 37 | assert isinstance(link, chainer.Link) 38 | shared_arrays = {} 39 | for param_name, param in link.namedparams(): 40 | shared_arrays[param_name] = mp.RawArray('f', param.data.ravel()) 41 | return shared_arrays 42 | 43 | 44 | def share_params_as_shared_arrays(link): 45 | shared_arrays = extract_params_as_shared_arrays(link) 46 | set_shared_params(link, shared_arrays) 47 | return shared_arrays 48 | 49 | 50 | def share_states_as_shared_arrays(link): 51 | shared_arrays = extract_states_as_shared_arrays(link) 52 | set_shared_states(link, shared_arrays) 53 | return shared_arrays 54 | 55 | 56 | def extract_states_as_shared_arrays(optimizer): 57 | assert isinstance(optimizer, chainer.Optimizer) 58 | assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first' 59 | shared_arrays = {} 60 | for state_name, state in optimizer._states.items(): 61 | shared_arrays[state_name] = {} 62 | for param_name, param in state.items(): 63 | shared_arrays[state_name][ 64 | param_name] = mp.RawArray('f', param.ravel()) 65 | return shared_arrays 66 | 67 | 68 | def run_async(n_process, run_func): 69 | """Run experiments asynchronously. 70 | 71 | Args: 72 | n_process (int): number of processes 73 | run_func: function that will be run in parallel 74 | """ 75 | 76 | processes = [] 77 | 78 | def set_seed_and_run(process_idx, run_func): 79 | random_seed.set_random_seed(np.random.randint(0, 2 ** 32)) 80 | run_func(process_idx) 81 | 82 | for process_idx in range(n_process): 83 | processes.append(mp.Process(target=set_seed_and_run, args=( 84 | process_idx, run_func))) 85 | 86 | for p in processes: 87 | p.start() 88 | 89 | for p in processes: 90 | p.join() 91 | -------------------------------------------------------------------------------- /copy_param.py: -------------------------------------------------------------------------------- 1 | def copy_param(target_link, source_link): 2 | """Copy parameters of a link to another link. 3 | """ 4 | target_params = dict(target_link.namedparams()) 5 | for param_name, param in source_link.namedparams(): 6 | target_params[param_name].data[:] = param.data 7 | 8 | 9 | def copy_grad(target_link, source_link): 10 | """Copy gradients of a link to another link. 
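    Both links are assumed to have parameters with identical names and shapes.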
11 | """ 12 | target_params = dict(target_link.namedparams()) 13 | for param_name, param in source_link.namedparams(): 14 | target_params[param_name].grad[:] = param.grad 15 | -------------------------------------------------------------------------------- /demo_a3c_ale.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import chainer 6 | from chainer import serializers 7 | 8 | import ale 9 | import random_seed 10 | from dqn_phi import dqn_phi 11 | from a3c_ale import A3CFF 12 | from a3c_ale import A3CLSTM 13 | 14 | 15 | def eval_performance(rom, model, deterministic=False, use_sdl=False, 16 | record_screen_dir=None): 17 | env = ale.ALE(rom, treat_life_lost_as_terminal=False, use_sdl=use_sdl, 18 | record_screen_dir=record_screen_dir) 19 | model.reset_state() 20 | test_r = 0 21 | while not env.is_terminal: 22 | s = chainer.Variable(np.expand_dims(dqn_phi(env.state), 0)) 23 | pout = model.pi_and_v(s)[0] 24 | model.unchain_backward() 25 | if deterministic: 26 | a = pout.most_probable_actions[0] 27 | else: 28 | a = pout.action_indices[0] 29 | test_r += env.receive_action(a) 30 | return test_r 31 | 32 | 33 | def main(): 34 | 35 | import logging 36 | logging.basicConfig(level=logging.DEBUG) 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('rom', type=str) 40 | parser.add_argument('model', type=str) 41 | parser.add_argument('--seed', type=int, default=0) 42 | parser.add_argument('--use-sdl', action='store_true') 43 | parser.add_argument('--n-runs', type=int, default=10) 44 | parser.add_argument('--deterministic', action='store_true') 45 | parser.add_argument('--record-screen-dir', type=str, default=None) 46 | parser.add_argument('--use-lstm', action='store_true') 47 | parser.set_defaults(use_sdl=False) 48 | parser.set_defaults(use_lstm=False) 49 | parser.set_defaults(deterministic=False) 50 | args = parser.parse_args() 51 | 52 | random_seed.set_random_seed(args.seed) 53 | 54 | n_actions = ale.ALE(args.rom).number_of_actions 55 | 56 | # Load an A3C-DQN model 57 | if args.use_lstm: 58 | model = A3CLSTM(n_actions) 59 | else: 60 | model = A3CFF(n_actions) 61 | serializers.load_hdf5(args.model, model) 62 | 63 | scores = [] 64 | for i in range(args.n_runs): 65 | episode_record_dir = None 66 | if args.record_screen_dir is not None: 67 | episode_record_dir = os.path.join(args.record_screen_dir, str(i)) 68 | os.makedirs(episode_record_dir) 69 | score = eval_performance( 70 | args.rom, model, deterministic=args.deterministic, 71 | use_sdl=args.use_sdl, record_screen_dir=episode_record_dir) 72 | print('Run {}: {}'.format(i, score)) 73 | scores.append(score) 74 | print('Average: {}'.format(sum(scores) / args.n_runs)) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /demo_a3c_doom.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import chainer 4 | from chainer import serializers 5 | import numpy as np 6 | 7 | import random_seed 8 | import doom_env 9 | from train_a3c_doom import phi, A3CFF, A3CLSTM 10 | 11 | 12 | def eval_single_run(env, model, phi, deterministic=False): 13 | model.reset_state() 14 | test_r = 0 15 | obs = env.reset() 16 | done = False 17 | while not done: 18 | s = chainer.Variable(np.expand_dims(phi(obs), 0)) 19 | pout = model.pi_and_v(s)[0] 20 | model.unchain_backward() 21 | if deterministic: 22 | a = 
pout.most_probable_actions[0] 23 | else: 24 | a = pout.action_indices[0] 25 | obs, r, done, info = env.step(a) 26 | test_r += r 27 | return test_r 28 | 29 | 30 | def eval_single_random_run(env): 31 | test_r = 0 32 | obs = env.reset() 33 | done = False 34 | while not done: 35 | a = np.random.randint(env.n_actions) 36 | obs, r, done, info = env.step(a) 37 | test_r += r 38 | return test_r 39 | 40 | 41 | def main(): 42 | import logging 43 | logging.basicConfig(level=logging.DEBUG) 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('model', type=str) 47 | parser.add_argument('--seed', type=int, default=0) 48 | parser.add_argument('--sleep', type=float, default=0) 49 | parser.add_argument('--scenario', type=str, default='basic') 50 | parser.add_argument('--n-runs', type=int, default=10) 51 | parser.add_argument('--use-lstm', action='store_true') 52 | parser.add_argument('--window-visible', action='store_true') 53 | parser.add_argument('--deterministic', action='store_true') 54 | parser.add_argument('--random', action='store_true') 55 | parser.set_defaults(window_visible=False) 56 | parser.set_defaults(use_lstm=False) 57 | parser.set_defaults(deterministic=False) 58 | parser.set_defaults(random=False) 59 | args = parser.parse_args() 60 | 61 | random_seed.set_random_seed(args.seed) 62 | 63 | n_actions = doom_env.DoomEnv( 64 | window_visible=False, scenario=args.scenario).n_actions 65 | 66 | if not args.random: 67 | if args.use_lstm: 68 | model = A3CLSTM(n_actions) 69 | else: 70 | model = A3CFF(n_actions) 71 | serializers.load_hdf5(args.model, model) 72 | 73 | scores = [] 74 | env = doom_env.DoomEnv(window_visible=args.window_visible, 75 | scenario=args.scenario, 76 | sleep=args.sleep) 77 | for i in range(args.n_runs): 78 | if args.random: 79 | score = eval_single_random_run(env) 80 | else: 81 | score = eval_single_run( 82 | env, model, phi, deterministic=args.deterministic) 83 | print('Run {}: {}'.format(i, score)) 84 | scores.append(score) 85 | print('Average: {}'.format(sum(scores) / args.n_runs)) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /doom_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import numpy as np 6 | 7 | 8 | class DoomEnv(object): 9 | 10 | def __init__(self, vizdoom_dir=os.path.expanduser('~/ViZDoom'), 11 | window_visible=True, scenario='basic', skipcount=10, 12 | resolution_width=640, sleep=0.0, seed=None): 13 | 14 | self.skipcount = skipcount 15 | self.sleep = sleep 16 | 17 | sys.path.append(os.path.join(vizdoom_dir, "examples/python")) 18 | from vizdoom import DoomGame 19 | from vizdoom import ScreenFormat 20 | from vizdoom import ScreenResolution 21 | 22 | game = DoomGame() 23 | 24 | if seed is not None: 25 | assert seed >= 0 and seed < 2 ** 16, \ 26 | "ViZDoom's random seed must be represented by unsigned int" 27 | else: 28 | # Use numpy's random state 29 | seed = np.random.randint(0, 2 ** 16) 30 | game.set_seed(seed) 31 | 32 | # Load a config file 33 | game.load_config(os.path.join( 34 | vizdoom_dir, "examples", 'config', scenario + '.cfg')) 35 | 36 | # Replace default relative paths with actual paths 37 | game.set_vizdoom_path(os.path.join(vizdoom_dir, "bin/vizdoom")) 38 | game.set_doom_game_path( 39 | os.path.join(vizdoom_dir, 'scenarios/freedoom2.wad')) 40 | game.set_doom_scenario_path( 41 | os.path.join(vizdoom_dir, 'scenarios', scenario + '.wad')) 42 | 43 | # Set 
screen settings 44 | resolutions = {640: ScreenResolution.RES_640X480, 45 | 320: ScreenResolution.RES_320X240, 46 | 160: ScreenResolution.RES_160X120} 47 | game.set_screen_resolution(resolutions[resolution_width]) 48 | game.set_screen_format(ScreenFormat.RGB24) 49 | game.set_window_visible(window_visible) 50 | game.set_sound_enabled(window_visible) 51 | 52 | game.init() 53 | self.game = game 54 | 55 | # Use one-hot actions 56 | self.n_actions = game.get_available_buttons_size() 57 | self.actions = [] 58 | for i in range(self.n_actions): 59 | self.actions.append([i == j for j in range(self.n_actions)]) 60 | 61 | def reset(self): 62 | self.game.new_episode() 63 | return self.game.get_state() 64 | 65 | def step(self, action): 66 | r = self.game.make_action(self.actions[action], self.skipcount) 67 | r /= 100 68 | time.sleep(self.sleep * self.skipcount) 69 | return self.game.get_state(), r, self.game.is_episode_finished(), None 70 | -------------------------------------------------------------------------------- /dqn_head.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import functions as F 3 | from chainer import links as L 4 | 5 | 6 | class NatureDQNHead(chainer.ChainList): 7 | """DQN's head (Nature version)""" 8 | 9 | def __init__(self, n_input_channels=4, n_output_channels=512, 10 | activation=F.relu, bias=0.1): 11 | self.n_input_channels = n_input_channels 12 | self.activation = activation 13 | self.n_output_channels = n_output_channels 14 | 15 | layers = [ 16 | L.Convolution2D(n_input_channels, 32, 8, stride=4, bias=bias), 17 | L.Convolution2D(32, 64, 4, stride=2, bias=bias), 18 | L.Convolution2D(64, 64, 3, stride=1, bias=bias), 19 | L.Linear(3136, n_output_channels, bias=bias), 20 | ] 21 | 22 | super(NatureDQNHead, self).__init__(*layers) 23 | 24 | def __call__(self, state): 25 | h = state 26 | for layer in self: 27 | h = self.activation(layer(h)) 28 | return h 29 | 30 | 31 | class NIPSDQNHead(chainer.ChainList): 32 | """DQN's head (NIPS workshop version)""" 33 | 34 | def __init__(self, n_input_channels=4, n_output_channels=256, 35 | activation=F.relu, bias=0.1): 36 | self.n_input_channels = n_input_channels 37 | self.activation = activation 38 | self.n_output_channels = n_output_channels 39 | 40 | layers = [ 41 | L.Convolution2D(n_input_channels, 16, 8, stride=4, bias=bias), 42 | L.Convolution2D(16, 32, 4, stride=2, bias=bias), 43 | L.Linear(2592, n_output_channels, bias=bias), 44 | ] 45 | 46 | super(NIPSDQNHead, self).__init__(*layers) 47 | 48 | def __call__(self, state): 49 | h = state 50 | for layer in self: 51 | h = self.activation(layer(h)) 52 | return h 53 | -------------------------------------------------------------------------------- /dqn_phi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def dqn_phi(screens): 5 | """Phi (feature extractor) of DQN for ALE 6 | Args: 7 | screens: List of N screen objects. Each screen object must be 8 | numpy.ndarray whose dtype is numpy.uint8. 
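            In this repository these are the last four 84x84 grayscale screens
            returned by ale.ALE.state.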
9 | Returns: 10 | numpy.ndarray 11 | """ 12 | assert len(screens) == 4 13 | assert screens[0].dtype == np.uint8 14 | raw_values = np.asarray(screens, dtype=np.float32) 15 | # [0,255] -> [0, 1] 16 | raw_values /= 255.0 17 | return raw_values 18 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | class Environment(object): 2 | """RL learning environment 3 | """ 4 | 5 | @property 6 | def state(self): 7 | pass 8 | 9 | @property 10 | def reward(self): 11 | pass 12 | 13 | def receive_action(self, action): 14 | pass 15 | 16 | class EpisodicEnvironment(Environment): 17 | 18 | def initialize(self): 19 | """ 20 | Initialize the internal state 21 | """ 22 | pass 23 | 24 | @property 25 | def is_terminal(self): 26 | pass 27 | 28 | -------------------------------------------------------------------------------- /init_like_torch.py: -------------------------------------------------------------------------------- 1 | from chainer import links as L 2 | import numpy as np 3 | 4 | 5 | def init_like_torch(link): 6 | # Mimic torch's default parameter initialization 7 | # TODO(muupan): Use chainer's initializers when it is merged 8 | for l in link.links(): 9 | if isinstance(l, L.Linear): 10 | out_channels, in_channels = l.W.data.shape 11 | stdv = 1 / np.sqrt(in_channels) 12 | l.W.data[:] = np.random.uniform(-stdv, stdv, size=l.W.data.shape) 13 | if l.b is not None: 14 | l.b.data[:] = np.random.uniform(-stdv, stdv, 15 | size=l.b.data.shape) 16 | elif isinstance(l, L.Convolution2D): 17 | out_channels, in_channels, kh, kw = l.W.data.shape 18 | stdv = 1 / np.sqrt(in_channels * kh * kw) 19 | l.W.data[:] = np.random.uniform(-stdv, stdv, size=l.W.data.shape) 20 | if l.b is not None: 21 | l.b.data[:] = np.random.uniform(-stdv, stdv, 22 | size=l.b.data.shape) 23 | -------------------------------------------------------------------------------- /nonbias_weight_decay.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | 3 | 4 | class NonbiasWeightDecay(object): 5 | 6 | """Optimizer hook function for weight decay regularization. 
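    Unlike chainer.optimizer.WeightDecay, parameters named 'b' or whose names
    end with '/b' (i.e. biases) are excluded from the decay.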
7 | 8 | """ 9 | name = 'NonbiasWeightDecay' 10 | 11 | def __init__(self, rate): 12 | self.rate = rate 13 | 14 | def __call__(self, opt): 15 | if cuda.available: 16 | kernel = cuda.elementwise( 17 | 'T p, T decay', 'T g', 'g += decay * p', 'weight_decay') 18 | 19 | rate = self.rate 20 | for name, param in opt.target.namedparams(): 21 | if name == 'b' or name.endswith('/b'): 22 | continue 23 | p, g = param.data, param.grad 24 | with cuda.get_device(p) as dev: 25 | if int(dev) == -1: 26 | g += rate * p 27 | else: 28 | kernel(p, rate, g) 29 | -------------------------------------------------------------------------------- /plot_scores.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('scores', type=str, help='specify path of scores.txt') 10 | parser.add_argument('--title', type=str, default=None) 11 | args = parser.parse_args() 12 | 13 | scores = pd.read_csv(args.scores, delimiter='\t') 14 | for col in ['mean', 'median']: 15 | plt.plot(scores['steps'], scores[col], label=col) 16 | if args.title is not None: 17 | plt.title(args.title) 18 | plt.xlabel('steps') 19 | plt.ylabel('score') 20 | plt.legend(loc='best') 21 | fig_fname = args.scores + '.png' 22 | plt.savefig(fig_fname) 23 | print('Saved a figure as {}'.format(fig_fname)) 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | logger = getLogger(__name__) 3 | 4 | import chainer 5 | from chainer import functions as F 6 | from chainer import links as L 7 | 8 | import policy_output 9 | 10 | 11 | class Policy(object): 12 | """Abstract policy class.""" 13 | 14 | def __call__(self, state): 15 | raise NotImplementedError 16 | 17 | 18 | class SoftmaxPolicy(Policy): 19 | """Abstract softmax policy class.""" 20 | 21 | def compute_logits(self, state): 22 | """ 23 | Returns: 24 | ~chainer.Variable: logits of actions 25 | """ 26 | raise NotImplementedError 27 | 28 | def __call__(self, state): 29 | return policy_output.SoftmaxPolicyOutput(self.compute_logits(state)) 30 | 31 | 32 | class FCSoftmaxPolicy(chainer.ChainList, SoftmaxPolicy): 33 | """Softmax policy that consists of FC layers and rectifiers""" 34 | 35 | def __init__(self, n_input_channels, n_actions, 36 | n_hidden_layers=0, n_hidden_channels=None): 37 | self.n_input_channels = n_input_channels 38 | self.n_actions = n_actions 39 | self.n_hidden_layers = n_hidden_layers 40 | self.n_hidden_channels = n_hidden_channels 41 | 42 | layers = [] 43 | if n_hidden_layers > 0: 44 | layers.append(L.Linear(n_input_channels, n_hidden_channels)) 45 | for i in range(n_hidden_layers - 1): 46 | layers.append(L.Linear(n_hidden_channels, n_hidden_channels)) 47 | layers.append(L.Linear(n_hidden_channels, n_actions)) 48 | else: 49 | layers.append(L.Linear(n_input_channels, n_actions)) 50 | 51 | super(FCSoftmaxPolicy, self).__init__(*layers) 52 | 53 | def compute_logits(self, state): 54 | h = state 55 | for layer in self[:-1]: 56 | h = F.relu(layer(h)) 57 | h = self[-1](h) 58 | return h 59 | 60 | 61 | class GaussianPolicy(Policy): 62 | """Abstract Gaussian policy class. 
63 | """ 64 | pass 65 | -------------------------------------------------------------------------------- /policy_output.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import functions as F 3 | from cached_property import cached_property 4 | import numpy as np 5 | 6 | 7 | class PolicyOutput(object): 8 | """Struct that holds policy output and subproducts.""" 9 | pass 10 | 11 | 12 | def _sample_discrete_actions(batch_probs): 13 | """Sample a batch of actions from a batch of action probabilities. 14 | 15 | Args: 16 | batch_probs (ndarray): batch of action probabilities BxA 17 | Returns: 18 | List consisting of sampled actions 19 | """ 20 | action_indices = [] 21 | 22 | # Subtract a tiny value from probabilities in order to avoid 23 | # "ValueError: sum(pvals[:-1]) > 1.0" in numpy.multinomial 24 | batch_probs = batch_probs - np.finfo(np.float32).epsneg 25 | 26 | for i in range(batch_probs.shape[0]): 27 | histogram = np.random.multinomial(1, batch_probs[i]) 28 | action_indices.append(int(np.nonzero(histogram)[0])) 29 | return action_indices 30 | 31 | 32 | class SoftmaxPolicyOutput(PolicyOutput): 33 | 34 | def __init__(self, logits): 35 | self.logits = logits 36 | 37 | @cached_property 38 | def most_probable_actions(self): 39 | return np.argmax(self.probs.data, axis=1) 40 | 41 | @cached_property 42 | def probs(self): 43 | return F.softmax(self.logits) 44 | 45 | @cached_property 46 | def log_probs(self): 47 | return F.log_softmax(self.logits) 48 | 49 | @cached_property 50 | def action_indices(self): 51 | return _sample_discrete_actions(self.probs.data) 52 | 53 | @cached_property 54 | def sampled_actions_log_probs(self): 55 | return F.select_item( 56 | self.log_probs, 57 | chainer.Variable(np.asarray(self.action_indices, dtype=np.int32))) 58 | 59 | @cached_property 60 | def entropy(self): 61 | return - F.sum(self.probs * self.log_probs, axis=1) 62 | -------------------------------------------------------------------------------- /prepare_output_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import json 4 | import subprocess 5 | 6 | 7 | def prepare_output_dir(args, user_specified_dir=None): 8 | """Prepare output directory. 9 | 10 | An output directory is created if it does not exist. 
Then the following 11 | infomation is saved into the directory: 12 | args.txt: command-line arguments 13 | git-status.txt: result of `git status` 14 | git-log.txt: result of `git log` 15 | git-diff.txt: result of `git diff` 16 | 17 | Args: 18 | args: dict that describes command-line arguments 19 | user_specified_dir: directory path 20 | """ 21 | if user_specified_dir is not None: 22 | if os.path.exists(user_specified_dir): 23 | if not os.path.isdir(user_specified_dir): 24 | raise RuntimeError( 25 | '{} is not a directory'.format(user_specified_dir)) 26 | else: 27 | os.makedirs(user_specified_dir) 28 | outdir = user_specified_dir 29 | else: 30 | outdir = tempfile.mkdtemp() 31 | 32 | # Save all the arguments 33 | with open(os.path.join(outdir, 'args.txt'), 'w') as f: 34 | f.write(json.dumps(vars(args))) 35 | 36 | # Save `git status` 37 | with open(os.path.join(outdir, 'git-status.txt'), 'w') as f: 38 | f.write(subprocess.getoutput('git status')) 39 | 40 | # Save `git log` 41 | with open(os.path.join(outdir, 'git-log.txt'), 'w') as f: 42 | f.write(subprocess.getoutput('git log')) 43 | 44 | # Save `git diff` 45 | with open(os.path.join(outdir, 'git-diff.txt'), 'w') as f: 46 | f.write(subprocess.getoutput('git diff')) 47 | 48 | return outdir 49 | -------------------------------------------------------------------------------- /random_seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def set_random_seed(seed): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | -------------------------------------------------------------------------------- /rmsprop_async.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from chainer import cuda 4 | from chainer import optimizer 5 | 6 | 7 | class RMSpropAsync(optimizer.GradientMethod): 8 | 9 | """RMSprop for asynchronous methods. 
10 | 11 | The only difference from chainer.optimizers.RMSprop in that the epsilon is 12 | outside the square root.""" 13 | 14 | def __init__(self, lr=0.01, alpha=0.99, eps=1e-8): 15 | self.lr = lr 16 | self.alpha = alpha 17 | self.eps = eps 18 | 19 | def init_state(self, param, state): 20 | xp = cuda.get_array_module(param.data) 21 | state['ms'] = xp.zeros_like(param.data) 22 | 23 | def update_one_cpu(self, param, state): 24 | ms = state['ms'] 25 | grad = param.grad 26 | 27 | ms *= self.alpha 28 | ms += (1 - self.alpha) * grad * grad 29 | param.data -= self.lr * grad / numpy.sqrt(ms + self.eps) 30 | 31 | def update_one_gpu(self, param, state): 32 | cuda.elementwise( 33 | 'T grad, T lr, T alpha, T eps', 34 | 'T param, T ms', 35 | '''ms = alpha * ms + (1 - alpha) * grad * grad; 36 | param -= lr * grad / sqrt(ms + eps);''', 37 | 'rmsprop')(param.grad, self.lr, self.alpha, self.eps, 38 | param.data, state['ms']) 39 | -------------------------------------------------------------------------------- /run_a3c.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import multiprocessing as mp 4 | import os 5 | import sys 6 | import statistics 7 | import time 8 | 9 | import chainer 10 | from chainer import links as L 11 | from chainer import functions as F 12 | import cv2 13 | import numpy as np 14 | 15 | import a3c 16 | import random_seed 17 | import async 18 | from prepare_output_dir import prepare_output_dir 19 | 20 | 21 | def eval_performance(process_idx, make_env, model, phi, n_runs): 22 | assert n_runs > 1, 'Computing stdev requires at least two runs' 23 | scores = [] 24 | for i in range(n_runs): 25 | model.reset_state() 26 | env = make_env(process_idx, test=True) 27 | obs = env.reset() 28 | done = False 29 | test_r = 0 30 | while not done: 31 | s = chainer.Variable(np.expand_dims(phi(obs), 0)) 32 | pout, _ = model.pi_and_v(s) 33 | a = pout.action_indices[0] 34 | obs, r, done, info = env.step(a) 35 | test_r += r 36 | scores.append(test_r) 37 | print('test_{}:'.format(i), test_r) 38 | mean = statistics.mean(scores) 39 | median = statistics.median(scores) 40 | stdev = statistics.stdev(scores) 41 | return mean, median, stdev 42 | 43 | 44 | def train_loop(process_idx, counter, make_env, max_score, args, agent, env, 45 | start_time, outdir): 46 | try: 47 | 48 | total_r = 0 49 | episode_r = 0 50 | global_t = 0 51 | local_t = 0 52 | obs = env.reset() 53 | r = 0 54 | done = False 55 | 56 | while True: 57 | 58 | # Get and increment the global counter 59 | with counter.get_lock(): 60 | counter.value += 1 61 | global_t = counter.value 62 | local_t += 1 63 | 64 | if global_t > args.steps: 65 | break 66 | 67 | agent.optimizer.lr = ( 68 | args.steps - global_t - 1) / args.steps * args.lr 69 | 70 | total_r += r 71 | episode_r += r 72 | 73 | a = agent.act(obs, r, done) 74 | 75 | if done: 76 | if process_idx == 0: 77 | print('{} global_t:{} local_t:{} lr:{} r:{}'.format( 78 | outdir, global_t, local_t, agent.optimizer.lr, 79 | episode_r)) 80 | episode_r = 0 81 | obs = env.reset() 82 | r = 0 83 | done = False 84 | else: 85 | obs, r, done, info = env.step(a) 86 | 87 | if global_t % args.eval_frequency == 0: 88 | # Evaluation 89 | 90 | # We must use a copy of the model because test runs can change 91 | # the hidden states of the model 92 | test_model = copy.deepcopy(agent.model) 93 | test_model.reset_state() 94 | 95 | mean, median, stdev = eval_performance( 96 | process_idx, make_env, test_model, agent.phi, 97 | args.eval_n_runs) 98 | with 
open(os.path.join(outdir, 'scores.txt'), 'a+') as f: 99 | elapsed = time.time() - start_time 100 | record = (global_t, elapsed, mean, median, stdev) 101 | print('\t'.join(str(x) for x in record), file=f) 102 | with max_score.get_lock(): 103 | if mean > max_score.value: 104 | # Save the best model so far 105 | print('The best score is updated {} -> {}'.format( 106 | max_score.value, mean)) 107 | filename = os.path.join( 108 | outdir, '{}.h5'.format(global_t)) 109 | agent.save_model(filename) 110 | print('Saved the current best model to {}'.format( 111 | filename)) 112 | max_score.value = mean 113 | 114 | except KeyboardInterrupt: 115 | if process_idx == 0: 116 | # Save the current model before being killed 117 | agent.save_model(os.path.join( 118 | outdir, '{}_keyboardinterrupt.h5'.format(global_t))) 119 | print('Saved the current model to {}'.format( 120 | outdir), file=sys.stderr) 121 | raise 122 | 123 | if global_t == args.steps + 1: 124 | # Save the final model 125 | agent.save_model( 126 | os.path.join(args.outdir, '{}_finish.h5'.format(args.steps))) 127 | print('Saved the final model to {}'.format(args.outdir)) 128 | 129 | 130 | def train_loop_with_profile(process_idx, counter, make_env, max_score, args, 131 | agent, env, start_time, outdir): 132 | import cProfile 133 | cmd = 'train_loop(process_idx, counter, make_env, max_score, args, ' \ 134 | 'agent, env, start_time)' 135 | cProfile.runctx(cmd, globals(), locals(), 136 | 'profile-{}.out'.format(os.getpid())) 137 | 138 | 139 | def run_a3c(processes, make_env, model_opt, phi, t_max=1, beta=1e-2, 140 | profile=False, steps=8 * 10 ** 7, eval_frequency=10 ** 6, 141 | eval_n_runs=10, args={}): 142 | 143 | # Prevent numpy from using multiple threads 144 | os.environ['OMP_NUM_THREADS'] = '1' 145 | 146 | outdir = prepare_output_dir(args, None) 147 | 148 | print('Output files are saved in {}'.format(outdir)) 149 | 150 | n_actions = 20 * 20 151 | 152 | model, opt = model_opt() 153 | 154 | shared_params = async.share_params_as_shared_arrays(model) 155 | shared_states = async.share_states_as_shared_arrays(opt) 156 | 157 | max_score = mp.Value('f', np.finfo(np.float32).min) 158 | counter = mp.Value('l', 0) 159 | start_time = time.time() 160 | 161 | # Write a header line first 162 | with open(os.path.join(outdir, 'scores.txt'), 'a+') as f: 163 | column_names = ('steps', 'elapsed', 'mean', 'median', 'stdev') 164 | print('\t'.join(column_names), file=f) 165 | 166 | def run_func(process_idx): 167 | env = make_env(process_idx, test=False) 168 | model, opt = model_opt() 169 | async.set_shared_params(model, shared_params) 170 | async.set_shared_states(opt, shared_states) 171 | 172 | agent = a3c.A3C(model, opt, t_max, 0.99, beta=beta, 173 | process_idx=process_idx, phi=phi) 174 | 175 | if profile: 176 | train_loop_with_profile(process_idx, counter, make_env, max_score, 177 | args, agent, env, start_time, 178 | outdir=outdir) 179 | else: 180 | train_loop(process_idx, counter, make_env, max_score, 181 | args, agent, env, start_time, outdir=outdir) 182 | 183 | async.run_async(processes, run_func) 184 | -------------------------------------------------------------------------------- /train_a3c_doom.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | 4 | import chainer 5 | from chainer import links as L 6 | from chainer import functions as F 7 | import cv2 8 | import numpy as np 9 | 10 | import policy 11 | import v_function 12 | import dqn_head 13 | import a3c 14 | 
import random_seed 15 | import rmsprop_async 16 | from init_like_torch import init_like_torch 17 | import run_a3c 18 | import doom_env 19 | 20 | 21 | def phi(obs): 22 | resized = cv2.resize(obs.image_buffer, (84, 84)) 23 | return resized.transpose(2, 0, 1).astype(np.float32) / 255 24 | 25 | 26 | class A3CFF(chainer.ChainList, a3c.A3CModel): 27 | 28 | def __init__(self, n_actions): 29 | self.head = dqn_head.NIPSDQNHead(n_input_channels=3) 30 | self.pi = policy.FCSoftmaxPolicy( 31 | self.head.n_output_channels, n_actions) 32 | self.v = v_function.FCVFunction(self.head.n_output_channels) 33 | super().__init__(self.head, self.pi, self.v) 34 | init_like_torch(self) 35 | 36 | def pi_and_v(self, state, keep_same_state=False): 37 | out = self.head(state) 38 | return self.pi(out), self.v(out) 39 | 40 | 41 | class A3CLSTM(chainer.ChainList, a3c.A3CModel): 42 | 43 | def __init__(self, n_actions): 44 | self.head = dqn_head.NIPSDQNHead(n_input_channels=3) 45 | self.pi = policy.FCSoftmaxPolicy( 46 | self.head.n_output_channels, n_actions) 47 | self.v = v_function.FCVFunction(self.head.n_output_channels) 48 | self.lstm = L.LSTM(self.head.n_output_channels, 49 | self.head.n_output_channels) 50 | super().__init__(self.head, self.lstm, self.pi, self.v) 51 | init_like_torch(self) 52 | 53 | def pi_and_v(self, state, keep_same_state=False): 54 | out = self.head(state) 55 | if keep_same_state: 56 | prev_h, prev_c = self.lstm.h, self.lstm.c 57 | out = self.lstm(out) 58 | self.lstm.h, self.lstm.c = prev_h, prev_c 59 | else: 60 | out = self.lstm(out) 61 | return self.pi(out), self.v(out) 62 | 63 | def reset_state(self): 64 | self.lstm.reset_state() 65 | 66 | def unchain_backward(self): 67 | self.lstm.h.unchain_backward() 68 | self.lstm.c.unchain_backward() 69 | 70 | 71 | def main(): 72 | import logging 73 | logging.basicConfig(level=logging.DEBUG) 74 | 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('processes', type=int) 77 | parser.add_argument('--seed', type=int, default=None) 78 | parser.add_argument('--outdir', type=str, default=None) 79 | parser.add_argument('--scenario', type=str, default='basic') 80 | parser.add_argument('--t-max', type=int, default=5) 81 | parser.add_argument('--beta', type=float, default=1e-2) 82 | parser.add_argument('--profile', action='store_true') 83 | parser.add_argument('--steps', type=int, default=8 * 10 ** 7) 84 | parser.add_argument('--lr', type=float, default=7e-4) 85 | parser.add_argument('--eval-frequency', type=int, default=10 ** 5) 86 | parser.add_argument('--eval-n-runs', type=int, default=10) 87 | parser.add_argument('--use-lstm', action='store_true') 88 | parser.add_argument('--window-visible', action='store_true') 89 | parser.set_defaults(window_visible=False) 90 | parser.set_defaults(use_lstm=False) 91 | args = parser.parse_args() 92 | 93 | if args.seed is not None: 94 | random_seed.set_random_seed(args.seed) 95 | 96 | # Simultaneously launching multiple vizdoom processes makes program stuck, 97 | # so use the global lock 98 | env_lock = mp.Lock() 99 | 100 | def make_env(process_idx, test): 101 | with env_lock: 102 | return doom_env.DoomEnv(window_visible=args.window_visible, 103 | scenario=args.scenario) 104 | 105 | n_actions = 3 106 | 107 | def model_opt(): 108 | if args.use_lstm: 109 | model = A3CLSTM(n_actions) 110 | else: 111 | model = A3CFF(n_actions) 112 | opt = rmsprop_async.RMSpropAsync(lr=args.lr, eps=1e-1, alpha=0.99) 113 | opt.setup(model) 114 | opt.add_hook(chainer.optimizer.GradientClipping(40)) 115 | return model, opt 116 | 117 | 
117 |     run_a3c.run_a3c(args.processes, make_env, model_opt, phi, t_max=args.t_max,
118 |                     beta=args.beta, profile=args.profile, steps=args.steps,
119 |                     eval_frequency=args.eval_frequency,
120 |                     eval_n_runs=args.eval_n_runs, args=args)
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     main()
125 | 
--------------------------------------------------------------------------------
/trained_model/breakout_ff/80000000_finish.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muupan/async-rl/5c5844fe7d62055ef911c7abfa6200fb30f6db84/trained_model/breakout_ff/80000000_finish.h5
--------------------------------------------------------------------------------
/trained_model/breakout_ff/animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muupan/async-rl/5c5844fe7d62055ef911c7abfa6200fb30f6db84/trained_model/breakout_ff/animation.gif
--------------------------------------------------------------------------------
/trained_model/breakout_ff/scores.txt:
--------------------------------------------------------------------------------
1 | steps elapsed mean median stdev
2 | 1000000 799.1865322589874 1.2 1.5 1.1352924243950933
3 | 2000000 1575.7219505310059 1.5 0.0 2.273030282830976
4 | 3000000 2359.6249346733093 2.1 3.0 1.911950719959998
5 | 4000000 3109.272673845291 1.6 2.0 1.4298407059684812
6 | 5000000 3859.516502380371 2.0 2.0 2.0548046676563256
7 | 6000000 4602.4465317726135 1.8 1.0 1.9888578520235065
8 | 7000000 5348.210222244263 2.5 2.0 0.97182531580755
9 | 8000000 6110.1892149448395 3.5 3.5 2.173067468400883
10 | 9000000 6861.104922533035 5.5 4.5 4.503085362035432
11 | 10000000 7619.673486471176 13.0 12.0 4.2946995755750415
12 | 11000000 8378.648483276367 18.9 19.0 5.586690532964138
13 | 12000000 9167.731533527374 24.5 21.5 11.057928276932246
14 | 13000000 10027.934562683105 73.6 72.5 26.95757984521443
15 | 14000000 10791.431033372879 107.2 97.5 50.78888767349716
16 | 15000000 11552.96664738655 95.8 85.0 71.8374708785201
17 | 16000000 12357.831614017487 251.2 269.0 65.97945808271608
18 | 17000000 13152.097360610962 232.0 249.0 99.43395351242508
19 | 18000000 13918.015005588531 246.4 273.5 100.97216338069506
20 | 19000000 14744.782504558563 362.7 375.5 51.497680638171566
21 | 20000000 15498.439815044403 377.1 388.0 41.420472662145656
22 | 21000000 16318.818217515945 340.7 382.0 113.92204937295209
23 | 22000000 17040.97575569153 248.8 265.0 116.37468224088377
24 | 23000000 17841.953941345215 350.8 361.5 51.100771901106235
25 | 24000000 18606.147523880005 373.8 397.0 61.31847646146759
26 | 25000000 19452.10493350029 375.6 415.0 100.00577761087385
27 | 26000000 20231.817700862885 430.5 412.0 77.02849400636681
28 | 27000000 21134.255215406418 433.8 417.0 100.2505749498614
29 | 28000000 21763.75692296028 299.8 328.0 119.36014596356878
30 | 29000000 22626.56786084175 492.9 412.5 183.22023784384615
31 | 30000000 23399.90938782692 373.9 387.5 211.82091492579292
32 | 31000000 24072.50167989731 363.3 380.5 82.33542507464342
33 | 32000000 24852.809225320816 376.7 387.5 55.63581979664221
34 | 33000000 25615.238379001617 311.6 380.0 137.93412276236154
35 | 34000000 26598.047915697098 435.4 421.5 53.74879843618212
36 | 35000000 27231.892096042633 414.8 389.5 138.989847750754
37 | 36000000 27887.654467582703 324.3 399.5 157.67339661464771
38 | 37000000 28814.492196321487 402.5 419.0 49.80907993975038
39 | 38000000 29534.802098751068 376.4 422.0 109.01804132038573
40 | 39000000 30364.29454421997 435.0 428.0 104.94654724085865
41 | 40000000 31047.65421795845 412.3 419.5 18.1906569425076
42 | 41000000 31737.017671108246 389.5 394.0 65.46457905700694
43 | 42000000 32467.050805568695 409.3 423.5 48.087998040629174
44 | 43000000 33230.65163731575 344.2 414.5 133.11799277332872
45 | 44000000 34009.13779783249 457.3 425.5 115.42294206766502
46 | 45000000 34763.50703406334 410.3 427.5 40.996070272378276
47 | 46000000 35484.54351043701 414.1 414.5 12.269655432995844
48 | 47000000 36283.29621911049 459.1 426.5 145.52010170419754
49 | 48000000 37148.736817359924 536.3 446.5 175.72140197230132
50 | 49000000 37665.218223810196 413.3 417.0 13.216572088774676
51 | 50000000 38502.0075905323 453.7 421.5 133.0806355393434
52 | 51000000 39192.341354846954 468.3 430.5 124.23816733283786
53 | 52000000 40030.84110355377 415.3 422.5 17.888854382299364
54 | 53000000 40686.36572384834 415.9 418.0 73.99316785277469
55 | 54000000 41383.44816946983 448.7 408.0 133.6006071177157
56 | 57000000 43803.362179756165 422.1 427.5 50.61060494929233
57 | 58000000 44451.59215736389 423.0 424.5 17.10100711784088
58 | 59000000 45240.43768501282 409.0 427.5 46.401628323918885
59 | 60000000 46110.308336257935 456.9 421.0 139.6896957943252
60 | 61000000 46937.206632852554 486.3 416.5 196.8372424110844
61 | 62000000 47675.98979473114 419.1 422.5 37.495036708580216
62 | 63000000 48513.08631014824 426.1 412.0 33.69289407443521
63 | 64000000 49318.95145368576 459.1 427.5 151.07352882325583
64 | 65000000 50040.2566986084 466.0 429.5 99.69509070717129
65 | 66000000 50993.864820718765 528.6 472.0 226.0335275032347
66 | 67000000 51737.7411031723 463.6 430.0 127.71774435144954
67 | 68000000 52501.34305858612 422.3 428.0 68.30169999511156
68 | 69000000 53283.210359334946 412.5 422.5 122.31948150460562
69 | 70000000 54116.87993979454 459.0 430.5 146.60832172833847
70 | 71000000 54859.40275526047 423.4 424.5 27.162883172119678
71 | 72000000 55693.9102909565 437.6 424.5 45.792284648544594
72 | 73000000 56525.86289215088 447.0 434.5 49.416596402423345
73 | 74000000 57244.09581017494 469.1 432.0 151.229369575563
74 | 75000000 58063.31164121628 417.6 425.5 31.931175988365975
75 | 76000000 58799.664543151855 416.7 424.0 20.144202364176373
76 | 77000000 59633.22558784485 426.4 427.5 10.490207073477837
77 | 78000000 60394.60213589668 513.1 413.0 195.16229713298168
78 | 79000000 61230.48169326782 489.9 425.0 165.88714369849293
79 | 80000000 61839.64669966698 417.8 422.0 21.627143439052077
80 | 
--------------------------------------------------------------------------------
/trained_model/breakout_ff/scores.txt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muupan/async-rl/5c5844fe7d62055ef911c7abfa6200fb30f6db84/trained_model/breakout_ff/scores.txt.png
--------------------------------------------------------------------------------
/trained_model/space_invaders_lstm/80000000_finish.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muupan/async-rl/5c5844fe7d62055ef911c7abfa6200fb30f6db84/trained_model/space_invaders_lstm/80000000_finish.h5
--------------------------------------------------------------------------------
/trained_model/space_invaders_lstm/animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muupan/async-rl/5c5844fe7d62055ef911c7abfa6200fb30f6db84/trained_model/space_invaders_lstm/animation.gif
--------------------------------------------------------------------------------
/trained_model/space_invaders_lstm/scores.txt:
--------------------------------------------------------------------------------
1 | steps elapsed mean median stdev
2 | 1000000 1176.0026223659515 150.5 130.0 71.23708459940106
3 | 2000000 2279.606765270233 218.5 195.0 82.63204920341475
4 | 3000000 3405.7928512096405 203.0 210.0 32.076297929918425
5 | 4000000 4499.463195800781 187.0 180.0 57.4552966323481
6 | 5000000 5575.656061410904 259.0 235.0 136.459192109257
7 | 6000000 6694.54666686058 318.0 287.5 179.53953448877058
8 | 7000000 7828.923171281815 351.0 305.0 181.62537757091596
9 | 8000000 8827.748558998108 333.5 322.5 114.26212554181431
10 | 9000000 9897.704538345337 311.0 322.5 131.90484785969352
11 | 10000000 10987.551244974136 379.0 362.5 88.49984306326324
12 | 11000000 12111.16355252266 490.5 515.0 97.39524286808538
13 | 12000000 13139.165969610214 448.5 482.5 113.38356729849936
14 | 13000000 14198.187840223312 367.0 365.0 54.324130099902305
15 | 14000000 15247.866738319397 386.0 337.5 125.80408048496147
16 | 15000000 16321.860386371613 555.0 532.5 211.18712081942877
17 | 16000000 17356.161630392075 402.5 362.5 84.63155439905378
18 | 17000000 18419.353372573853 433.5 405.0 148.38107845828742
19 | 18000000 19517.881631851196 572.5 587.5 170.18372425117508
20 | 19000000 20565.545799016953 504.0 392.5 223.69125547901461
21 | 20000000 21651.335790395737 603.0 622.5 203.74548393086462
22 | 21000000 22765.650399684906 834.5 835.0 298.63997276542426
23 | 22000000 23780.1443631649 579.0 585.0 249.15189476836548
24 | 23000000 24812.381363153458 553.5 535.0 160.07029705726168
25 | 24000000 25943.478595733643 805.5 820.0 266.4420929374503
26 | 25000000 27022.40669798851 819.0 800.0 356.03370626950476
27 | 26000000 28154.23360300064 775.0 547.5 532.2123636294068
28 | 27000000 29255.90105676651 613.5 572.5 229.7226588736949
29 | 28000000 30400.700807094574 646.5 482.5 406.20910050531035
30 | 29000000 31572.145109176636 772.0 632.5 334.54944826338203
31 | 30000000 32690.86220598221 945.0 972.5 311.4839036326311
32 | 31000000 33758.00896549225 830.5 602.5 442.72797015272886
33 | 32000000 34861.88159298897 757.0 602.5 344.4415770489968
34 | 33000000 35916.9765996933 985.0 950.0 397.50611455814254
35 | 34000000 37010.65085887909 962.0 950.0 381.9118571956973
36 | 35000000 38139.09030032158 1014.5 900.0 490.8521501769482
37 | 36000000 39226.673714637756 976.5 607.5 536.9515289525158
38 | 37000000 40325.832810640335 1031.0 820.0 501.2639579658162
39 | 38000000 41391.046285390854 788.0 680.0 318.21725632376103
40 | 39000000 42567.62489628792 1064.0 1017.5 455.73384045222423
41 | 40000000 43663.877554655075 952.5 862.5 410.55687114508805
42 | 41000000 44881.854682445526 1406.5 1165.0 795.8889160345599
43 | 42000000 45882.31720352173 891.0 622.5 492.4586841102059
44 | 43000000 47044.14443182945 1340.0 1127.5 765.2740975914737
45 | 44000000 48123.50830435753 1205.5 1235.0 508.775490761888
46 | 45000000 49227.012467861176 1044.5 935.0 557.8751851644078
47 | 46000000 50345.83370089531 1240.5 1217.5 541.9278857314258
48 | 47000000 51417.79047751427 1269.0 1165.0 778.919479507064
49 | 48000000 52607.29307818413 1698.0 1797.5 543.8198231032039
50 | 49000000 53671.907301187515 1598.5 1710.0 466.9347444296211
51 | 50000000 54812.021178245544 1863.5 1720.0 510.8873000313605
52 | 51000000 55892.94242429733 1501.0 1432.5 484.23364424854066
53 | 52000000 57041.53133392334 2016.0 1927.5 670.7657315834394
54 | 53000000 58073.72801113129 1643.0 1585.0 496.4887824625156
55 | 54000000 59165.362850666046 1382.0 1422.5 681.1395680116602
56 | 55000000 60335.68020105362 1565.0 1645.0 475.94350738530113
57 | 56000000 61488.63900256157 2096.0 2015.0 365.1468502153924
58 | 57000000 62582.88352417946 1996.0 2062.5 504.8751220736558
59 | 58000000 63649.29801917076 2022.0 1895.0 680.4051080871683
60 | 59000000 64736.771109342575 1568.0 1610.0 177.81388522209895
61 | 60000000 65917.98326063156 2057.5 2150.0 778.325303313324
62 | 61000000 66951.23457431793 2082.5 1997.5 567.936078171557
63 | 62000000 68049.10905337334 1737.0 2025.0 718.3592416054797
64 | 63000000 69085.28632044792 1959.0 1940.0 356.23650071889665
65 | 64000000 70181.8262295723 1822.5 1817.5 756.8144128883616
66 | 65000000 71294.90933918953 2032.0 2122.5 683.3585361205996
67 | 66000000 72485.66170525551 2507.0 2705.0 515.5051460029817
68 | 67000000 73494.54617547989 1775.0 1837.5 858.5064550330029
69 | 68000000 74578.25437927246 2038.0 2462.5 922.8434320078352
70 | 69000000 75629.58755373955 2085.0 2070.0 778.1209560358996
71 | 70000000 76678.21772170067 1960.5 1587.5 915.0818117645123
72 | 71000000 77790.33036136627 2420.5 2567.5 455.76827689322806
73 | 72000000 78800.69010519981 1789.0 1760.0 604.4639498486792
74 | 73000000 79914.5376329422 2436.0 2805.0 706.0327345259724
75 | 74000000 80936.40157341957 1924.5 1935.0 821.6191669840447
76 | 75000000 81944.25549578667 1962.0 2177.5 869.6238011666628
77 | 76000000 83064.10873103142 2335.5 2555.0 695.3314557341221
78 | 77000000 84099.02462172508 2615.0 2905.0 847.9648839689321
79 | 78000000 85200.52754330635 2610.5 2797.5 532.5382093751061
80 | 79000000 86191.62391448021 1999.5 2075.0 919.0981388779377
81 | 80000000 87095.2680015564 2278.0 2462.5 637.1908836622055
82 | 
--------------------------------------------------------------------------------
/trained_model/space_invaders_lstm/scores.txt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muupan/async-rl/5c5844fe7d62055ef911c7abfa6200fb30f6db84/trained_model/space_invaders_lstm/scores.txt.png
--------------------------------------------------------------------------------
/v_function.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | from chainer import functions as F
3 | from chainer import links as L
4 | 
5 | 
6 | class VFunction(object):
7 |     pass
8 | 
9 | 
10 | class FCVFunction(chainer.ChainList, VFunction):
11 | 
12 |     def __init__(self, n_input_channels, n_hidden_layers=0,
13 |                  n_hidden_channels=None):
14 |         self.n_input_channels = n_input_channels
15 |         self.n_hidden_layers = n_hidden_layers
16 |         self.n_hidden_channels = n_hidden_channels
17 | 
18 |         layers = []
19 |         if n_hidden_layers > 0:
20 |             layers.append(L.Linear(n_input_channels, n_hidden_channels))
21 |             for i in range(n_hidden_layers - 1):
22 |                 layers.append(L.Linear(n_hidden_channels, n_hidden_channels))
23 |             layers.append(L.Linear(n_hidden_channels, 1))
24 |         else:
25 |             layers.append(L.Linear(n_input_channels, 1))
26 | 
27 |         super(FCVFunction, self).__init__(*layers)
28 | 
29 |     def __call__(self, state):
30 |         h = state
31 |         for layer in self[:-1]:
32 |             h = F.relu(layer(h))
33 |         h = self[-1](h)
34 |         return h
35 | 
--------------------------------------------------------------------------------
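
Editor's note: `FCVFunction` in `v_function.py` stacks `n_hidden_layers` fully connected layers, each followed by a ReLU, and ends with a single linear output unit; with `n_hidden_layers=0` it reduces to one linear layer from the input features to a scalar value. A minimal usage sketch follows. It assumes Chainer and NumPy are installed and that `v_function.py` is on the import path; the channel counts and batch shape are illustrative choices, not values taken from the repository.

```python
import numpy as np

from v_function import FCVFunction

# Value head with one hidden layer of 64 units on a 256-dimensional input.
# With the default n_hidden_layers=0 it would be a single Linear(256, 1).
v = FCVFunction(n_input_channels=256, n_hidden_layers=1, n_hidden_channels=64)

# Dummy batch of two feature vectors; Chainer links accept float32 ndarrays.
features = np.zeros((2, 256), dtype=np.float32)
values = v(features)  # chainer.Variable of shape (2, 1): one value per sample
print(values.data.shape)
```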
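Editor's note: the two `scores.txt` files above are tab-separated tables written by `run_a3c.py`, one row per evaluation with the global step count, elapsed seconds, and the mean/median/stdev of the evaluation returns. The repository includes `plot_scores.py` for plotting them; the snippet below is only an illustrative alternative, assuming pandas and matplotlib are available.

```python
import matplotlib.pyplot as plt
import pandas as pd

# scores.txt is tab-separated with the header: steps elapsed mean median stdev
scores = pd.read_csv('trained_model/breakout_ff/scores.txt', sep='\t')

plt.plot(scores['steps'], scores['mean'], label='mean')
plt.plot(scores['steps'], scores['median'], label='median')
plt.xlabel('training steps (global counter)')
plt.ylabel('evaluation score')
plt.legend()
plt.savefig('scores.png')
```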