├── .gitignore ├── Pong-v0_monitor ├── openaigym.episode_batch.0.3438.stats.json ├── openaigym.manifest.0.3438.manifest.json ├── openaigym.video.0.3438.video000000.meta.json ├── openaigym.video.0.3438.video000000.mp4 ├── openaigym.video.0.3438.video000001.meta.json └── openaigym.video.0.3438.video000001.mp4 ├── README.md ├── checkpoints └── Pong-v0.model ├── config.json ├── environment.py ├── gym_eval.py ├── logs ├── Pong-v0_log └── Pong-v0_mon_log ├── main.py ├── model.py ├── player_util.py ├── run_play.sh ├── run_train.sh ├── shared_optim.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | -------------------------------------------------------------------------------- /Pong-v0_monitor/openaigym.episode_batch.0.3438.stats.json: -------------------------------------------------------------------------------- 1 | {"initial_reset_timestamp": 1501897257.560941, "timestamps": [1501897419.795865, 1501897524.366334, 1501897667.15831], "episode_lengths": [9679, 6266, 8561], "episode_rewards": [-1.0, 11.0, -6.0], "episode_types": ["t", "t", "t", "t"]} -------------------------------------------------------------------------------- /Pong-v0_monitor/openaigym.manifest.0.3438.manifest.json: -------------------------------------------------------------------------------- 1 | {"stats": "openaigym.episode_batch.0.3438.stats.json", "videos": [["openaigym.video.0.3438.video000000.mp4", "openaigym.video.0.3438.video000000.meta.json"], ["openaigym.video.0.3438.video000001.mp4", "openaigym.video.0.3438.video000001.meta.json"]], "env_info": {"gym_version": "0.9.2", "env_id": "Pong-v0"}} -------------------------------------------------------------------------------- /Pong-v0_monitor/openaigym.video.0.3438.video000000.meta.json: -------------------------------------------------------------------------------- 1 | {"episode_id": 0, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 3.2.4 Copyright (c) 2000-2017 the FFmpeg developers\\nbuilt with Apple LLVM version 8.0.0 (clang-800.0.42.1)\\nconfiguration: --prefix=/usr/local/Cellar/ffmpeg/3.2.4 --enable-shared --enable-pthreads --enable-gpl --enable-version3 --enable-hardcoded-tables --enable-avresample --cc=clang --host-cflags= --host-ldflags= --enable-libmp3lame --enable-libx264 --enable-libxvid --enable-opencl --disable-lzma --enable-vda\\nlibavutil 55. 34.101 / 55. 34.101\\nlibavcodec 57. 64.101 / 57. 64.101\\nlibavformat 57. 56.101 / 57. 56.101\\nlibavdevice 57. 1.100 / 57. 1.100\\nlibavfilter 6. 65.100 / 6. 65.100\\nlibavresample 3. 1. 0 / 3. 1. 0\\nlibswscale 4. 2.100 / 4. 2.100\\nlibswresample 2. 3.100 / 2. 3.100\\nlibpostproc 54. 1.100 / 54. 
1.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-r", "30", "-f", "rawvideo", "-s:v", "160x210", "-pix_fmt", "rgb24", "-i", "-", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "/Volumes/xs/CodeSpace/AISpace/rl_space/rl_atari_pytorch/Pong-v0_monitor/openaigym.video.0.3438.video000000.mp4"]}} -------------------------------------------------------------------------------- /Pong-v0_monitor/openaigym.video.0.3438.video000000.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lineCode/rl_atari_pytorch/48d19f414ec1641cd425679c034273af9d5e7199/Pong-v0_monitor/openaigym.video.0.3438.video000000.mp4 -------------------------------------------------------------------------------- /Pong-v0_monitor/openaigym.video.0.3438.video000001.meta.json: -------------------------------------------------------------------------------- 1 | {"episode_id": 1, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 3.2.4 Copyright (c) 2000-2017 the FFmpeg developers\\nbuilt with Apple LLVM version 8.0.0 (clang-800.0.42.1)\\nconfiguration: --prefix=/usr/local/Cellar/ffmpeg/3.2.4 --enable-shared --enable-pthreads --enable-gpl --enable-version3 --enable-hardcoded-tables --enable-avresample --cc=clang --host-cflags= --host-ldflags= --enable-libmp3lame --enable-libx264 --enable-libxvid --enable-opencl --disable-lzma --enable-vda\\nlibavutil 55. 34.101 / 55. 34.101\\nlibavcodec 57. 64.101 / 57. 64.101\\nlibavformat 57. 56.101 / 57. 56.101\\nlibavdevice 57. 1.100 / 57. 1.100\\nlibavfilter 6. 65.100 / 6. 65.100\\nlibavresample 3. 1. 0 / 3. 1. 0\\nlibswscale 4. 2.100 / 4. 2.100\\nlibswresample 2. 3.100 / 2. 3.100\\nlibpostproc 54. 1.100 / 54. 1.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-r", "30", "-f", "rawvideo", "-s:v", "160x210", "-pix_fmt", "rgb24", "-i", "-", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "/Volumes/xs/CodeSpace/AISpace/rl_space/rl_atari_pytorch/Pong-v0_monitor/openaigym.video.0.3438.video000001.mp4"]}} -------------------------------------------------------------------------------- /Pong-v0_monitor/openaigym.video.0.3438.video000001.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lineCode/rl_atari_pytorch/48d19f414ec1641cd425679c034273af9d5e7199/Pong-v0_monitor/openaigym.video.0.3438.video000001.mp4 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Atari Pytorch 2 | 3 | > A atari AI Player implement by pytorch play games 4 | 5 | ## Synopsis 6 | 7 | Reinforcement learning shows the most potential of AI in many area, however, to use reinforcement learning you must specific your environment which is somethings hard to build a environment for your problem. But gym let us has a very convenient way to explore rl algorithms. 8 | 9 | So here it is, using DDPG and LSTM to play atari, **and it is really effective!!**, as I can show in Pong-V0-moniter you can find the play progress in mp4. Our AI can really beat computer!! 10 | 11 | ## How to Play With 12 | 13 | OK, to play with it, simply run: 14 | ``` 15 | ./run_train.sh 16 | ``` 17 | This will train on Pong-V0 env, and save your model into `checkpoints/`. If you interrupted, next time it will continue train on last saved model. 
18 | 19 | And to play with your trained model, simply run: 20 | ``` 21 | ./run_play.sh 22 | ``` 23 | 24 | You can change the env in the .sh commands; many Atari envs are supported. 25 | 26 | ## Future 27 | 28 | This is a good exploration but not the end; later on I will explore reinforcement learning for the autonomous-driving problem and train an AI to really drive!! 29 | 30 | ## Contribute 31 | 32 | PRs that add trained models for more game envs are very welcome!! 33 | If you have any questions, you can find me via WeChat: `jintianiloveu`. 34 | -------------------------------------------------------------------------------- /checkpoints/Pong-v0.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lineCode/rl_atari_pytorch/48d19f414ec1641cd425679c034273af9d5e7199/checkpoints/Pong-v0.model -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "Default": { 3 | "crop1": 34, 4 | "crop2": 34, 5 | "dimension2": 80 6 | }, 7 | "Asteroids": { 8 | "crop1": 16, 9 | "crop2": 34, 10 | "dimension2": 94 11 | }, 12 | "BeamRider": { 13 | "crop1": 20, 14 | "crop2": 20, 15 | "dimension2": 80 16 | }, 17 | "Breakout": { 18 | "crop1": 34, 19 | "crop2": 34, 20 | "dimension2": 80 21 | }, 22 | "Centipede": { 23 | "crop1": 36, 24 | "crop2": 56, 25 | "dimension2": 90 26 | }, 27 | "MsPacman": { 28 | "crop1": 2, 29 | "crop2": 10, 30 | "dimension2": 84 31 | }, 32 | "Pong": { 33 | "crop1": 34, 34 | "crop2": 34, 35 | "dimension2": 80 36 | }, 37 | "Seaquest": { 38 | "crop1": 30, 39 | "crop2": 30, 40 | "dimension2": 80 41 | }, 42 | "SpaceInvaders": { 43 | "crop1": 8, 44 | "crop2": 36, 45 | "dimension2": 94 46 | }, 47 | "VideoPinball": { 48 | "crop1": 42, 49 | "crop2": 60, 50 | "dimension2": 89 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import gym 3 | import numpy as np 4 | from gym.spaces.box import Box 5 | from universe import vectorized 6 | from universe.wrappers import Unvectorize, Vectorize 7 | from skimage.color import rgb2gray 8 | import cv2 9 | 10 | 11 | def atari_env(env_id, env_conf): 12 | env = gym.make(env_id) 13 | if len(env.observation_space.shape) > 1: 14 | env = Vectorize(env) 15 | env = AtariRescale(env, env_conf) 16 | env = NormalizedEnv(env) 17 | env = Unvectorize(env) 18 | return env 19 | 20 | 21 | def _process_frame(frame, conf): 22 | frame = frame[conf["crop1"]:conf["crop2"] + 160, :160] 23 | frame = cv2.resize(rgb2gray(frame), (80, conf["dimension2"])) 24 | frame = cv2.resize(frame, (80, 80)) 25 | frame = np.reshape(frame, [1, 80, 80]) 26 | return frame 27 | 28 | 29 | class AtariRescale(vectorized.ObservationWrapper): 30 | def __init__(self, env, env_conf): 31 | super(AtariRescale, self).__init__(env) 32 | self.observation_space = Box(0.0, 1.0, [1, 80, 80]) 33 | self.conf = env_conf 34 | 35 | def _observation(self, observation_n): 36 | return [ 37 | _process_frame(observation, self.conf) 38 | for observation in observation_n 39 | ] 40 | 41 | 42 | class NormalizedEnv(vectorized.ObservationWrapper): 43 | def __init__(self, env=None): 44 | super(NormalizedEnv, self).__init__(env) 45 | self.state_mean = 0 46 | self.state_std = 0 47 | self.alpha = 0.9999 48 | self.num_steps = 0 49 | 50 | def
_observation(self, observation_n): 51 | for observation in observation_n: 52 | self.num_steps += 1 53 | self.state_mean = self.state_mean * self.alpha + \ 54 | observation.mean() * (1 - self.alpha) 55 | self.state_std = self.state_std * self.alpha + \ 56 | observation.std() * (1 - self.alpha) 57 | 58 | unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps)) 59 | unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps)) 60 | 61 | return [(observation - unbiased_mean) / (unbiased_std + 1e-8) 62 | for observation in observation_n] 63 | -------------------------------------------------------------------------------- /gym_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | import argparse 5 | import torch 6 | from environment import atari_env 7 | from utils import read_config, setup_logger 8 | from model import A3Clstm 9 | from player_util import Agent, player_act, player_start 10 | from torch.autograd import Variable 11 | import gym 12 | import logging 13 | 14 | parser = argparse.ArgumentParser(description='A3C_EVAL') 15 | parser.add_argument( 16 | '--env', 17 | default='Pong-v0', 18 | metavar='ENV', 19 | help='environment to train on (default: Pong-v0)') 20 | parser.add_argument( 21 | '--env-config', 22 | default='config.json', 23 | metavar='EC', 24 | help='environment to crop and resize info (default: config.json)') 25 | parser.add_argument( 26 | '--num-episodes', 27 | type=int, 28 | default=100, 29 | metavar='NE', 30 | help='how many episodes in evaluation (default: 100)') 31 | parser.add_argument( 32 | '--load-model-dir', 33 | default='checkpoints/', 34 | metavar='LMD', 35 | help='folder to load trained models from') 36 | parser.add_argument( 37 | '--log-dir', 38 | default='logs/', 39 | metavar='LG', 40 | help='folder to save logs') 41 | parser.add_argument( 42 | '--render', 43 | default=True, 44 | metavar='R', 45 | help='Watch game as it being played') 46 | parser.add_argument( 47 | '--render-freq', 48 | type=int, 49 | default=1, 50 | metavar='RF', 51 | help='Frequency to watch rendered game play') 52 | parser.add_argument( 53 | '--max-episode-length', 54 | type=int, 55 | default=100000, 56 | metavar='M', 57 | help='maximum length of an episode (default: 100000)') 58 | args = parser.parse_args() 59 | 60 | setup_json = read_config(args.env_config) 61 | env_conf = setup_json["Default"] 62 | for i in setup_json.keys(): 63 | if i in args.env: 64 | env_conf = setup_json[i] 65 | torch.set_default_tensor_type('torch.FloatTensor') 66 | 67 | saved_state_path = os.path.join(args.load_model_dir, args.env + '.model') 68 | saved_state = torch.load(saved_state_path, map_location=lambda storage, loc: storage) 69 | print('Loaded trained model from: {}'.format(saved_state_path)) 70 | 71 | log = {} 72 | setup_logger('{}_mon_log'.format(args.env), r'{0}{1}_mon_log'.format( 73 | args.log_dir, args.env)) 74 | log['{}_mon_log'.format(args.env)] = logging.getLogger( 75 | '{}_mon_log'.format(args.env)) 76 | 77 | env = atari_env("{}".format(args.env), env_conf) 78 | model = A3Clstm(env.observation_space.shape[0], env.action_space) 79 | 80 | num_tests = 0 81 | reward_total_sum = 0 82 | player = Agent(model, env, args, state=None) 83 | player.env = gym.wrappers.Monitor(player.env, "{}_monitor".format(args.env), force=True) 84 | player.model.eval() 85 | for i_episode in range(args.num_episodes): 86 | state = player.env.reset() 87 | player.state = 
torch.from_numpy(state).float() 88 | player.eps_len = 0 89 | reward_sum = 0 90 | while True: 91 | if args.render: 92 | if i_episode % args.render_freq == 0: 93 | player.env.render() 94 | if player.done: 95 | player.model.load_state_dict(saved_state) 96 | player.cx = Variable(torch.zeros(1, 512), volatile=True) 97 | player.hx = Variable(torch.zeros(1, 512), volatile=True) 98 | if player.starter: 99 | player = player_start(player, train=False) 100 | else: 101 | player.cx = Variable(player.cx.data, volatile=True) 102 | player.hx = Variable(player.hx.data, volatile=True) 103 | 104 | player, reward = player_act(player, train=False) 105 | reward_sum += reward 106 | 107 | if not player.done: 108 | if player.current_life > player.info['ale.lives']: 109 | player.flag = True 110 | player.current_life = player.info['ale.lives'] 111 | else: 112 | player.current_life = player.info['ale.lives'] 113 | player.flag = False 114 | if player.starter and player.flag: 115 | player = player_start(player, train=False) 116 | 117 | if player.done: 118 | num_tests += 1 119 | reward_total_sum += reward_sum 120 | reward_mean = reward_total_sum / num_tests 121 | log['{}_mon_log'.format(args.env)].info( 122 | "reward sum: {0}, reward mean: {1:.4f}".format( 123 | reward_sum, reward_mean)) 124 | 125 | break 126 | -------------------------------------------------------------------------------- /logs/Pong-v0_log: -------------------------------------------------------------------------------- 1 | 2017-08-04 22:16:30,378 : lr: 0.0001 2 | 2017-08-04 22:16:30,379 : gamma: 0.99 3 | 2017-08-04 22:16:30,380 : tau: 1.0 4 | 2017-08-04 22:16:30,380 : seed: 1 5 | 2017-08-04 22:16:30,380 : workers: 4 6 | 2017-08-04 22:16:30,380 : num_steps: 20 7 | 2017-08-04 22:16:30,381 : max_episode_length: 10000 8 | 2017-08-04 22:16:30,381 : env: Pong-v0 9 | 2017-08-04 22:16:30,381 : env_config: config.json 10 | 2017-08-04 22:16:30,382 : shared_optimizer: True 11 | 2017-08-04 22:16:30,382 : load: True 12 | 2017-08-04 22:16:30,382 : save_score_level: 20 13 | 2017-08-04 22:16:30,383 : optimizer: Adam 14 | 2017-08-04 22:16:30,383 : count_lives: False 15 | 2017-08-04 22:16:30,383 : load_model_dir: checkpoints/ 16 | 2017-08-04 22:16:30,384 : save_model_dir: checkpoints/ 17 | 2017-08-04 22:16:30,384 : log_dir: logs/ 18 | 2017-08-04 22:17:13,467 : Time 00h 00m 42s, episode reward -19.0, episode length 1328, reward mean -19.0000 19 | 2017-08-04 22:18:55,497 : Time 00h 02m 24s, episode reward -20.0, episode length 1246, reward mean -19.5000 20 | 2017-08-04 22:20:34,143 : Time 00h 04m 03s, episode reward -20.0, episode length 1153, reward mean -19.6667 21 | 2017-08-04 22:22:11,553 : Time 00h 05m 40s, episode reward -20.0, episode length 1120, reward mean -19.7500 22 | 2017-08-04 22:23:48,687 : Time 00h 07m 17s, episode reward -21.0, episode length 1109, reward mean -20.0000 23 | 2017-08-04 22:25:23,346 : Time 00h 08m 52s, episode reward -21.0, episode length 1045, reward mean -20.1667 24 | 2017-08-04 22:27:00,540 : Time 00h 10m 29s, episode reward -20.0, episode length 1116, reward mean -20.1429 25 | 2017-08-04 22:28:49,386 : Time 00h 12m 18s, episode reward -19.0, episode length 1463, reward mean -20.0000 26 | 2017-08-04 22:30:36,282 : Time 00h 14m 05s, episode reward -20.0, episode length 1408, reward mean -20.0000 27 | 2017-08-04 22:32:22,145 : Time 00h 15m 51s, episode reward -19.0, episode length 1373, reward mean -19.9000 28 | 2017-08-04 22:34:25,169 : Time 00h 17m 54s, episode reward -17.0, episode length 1902, reward mean -19.6364 29 | 2017-08-04 
22:36:53,405 : Time 00h 20m 22s, episode reward -17.0, episode length 2635, reward mean -19.4167 30 | 2017-08-04 22:38:30,021 : Time 00h 21m 59s, episode reward -20.0, episode length 1112, reward mean -19.4615 31 | 2017-08-04 22:40:21,303 : Time 00h 23m 50s, episode reward -18.0, episode length 1567, reward mean -19.3571 32 | 2017-08-04 22:42:30,722 : Time 00h 25m 59s, episode reward -17.0, episode length 2100, reward mean -19.2000 33 | 2017-08-04 22:44:33,288 : Time 00h 28m 02s, episode reward -17.0, episode length 1906, reward mean -19.0625 34 | 2017-08-04 22:46:14,205 : Time 00h 29m 43s, episode reward -20.0, episode length 1244, reward mean -19.1176 35 | 2017-08-04 22:48:49,481 : Time 00h 32m 18s, episode reward -14.0, episode length 2911, reward mean -18.8333 36 | 2017-08-04 22:50:23,561 : Time 00h 33m 52s, episode reward -21.0, episode length 1047, reward mean -18.9474 37 | 2017-08-04 22:51:58,133 : Time 00h 35m 27s, episode reward -21.0, episode length 1057, reward mean -19.0500 38 | 2017-08-04 22:53:42,657 : Time 00h 37m 11s, episode reward -21.0, episode length 1364, reward mean -19.1429 39 | 2017-08-04 22:55:31,743 : Time 00h 39m 01s, episode reward -19.0, episode length 1507, reward mean -19.1364 40 | 2017-08-04 22:58:03,003 : Time 00h 41m 32s, episode reward -16.0, episode length 2766, reward mean -19.0000 41 | 2017-08-04 23:00:23,326 : Time 00h 43m 52s, episode reward -16.0, episode length 2435, reward mean -18.8750 42 | 2017-08-04 23:02:17,178 : Time 00h 45m 46s, episode reward -19.0, episode length 1635, reward mean -18.8800 43 | 2017-08-04 23:05:12,007 : Time 00h 48m 41s, episode reward -6.0, episode length 3527, reward mean -18.3846 44 | 2017-08-04 23:08:23,186 : Time 00h 51m 52s, episode reward 3.0, episode length 4028, reward mean -17.5926 45 | 2017-08-04 23:10:59,810 : Time 00h 54m 29s, episode reward -11.0, episode length 2955, reward mean -17.3571 46 | 2017-08-04 23:13:44,007 : Time 00h 57m 13s, episode reward -10.0, episode length 3197, reward mean -17.1034 47 | 2017-08-04 23:16:09,393 : Time 00h 59m 38s, episode reward -16.0, episode length 2607, reward mean -17.0667 48 | 2017-08-04 23:18:59,264 : Time 01h 02m 28s, episode reward -8.0, episode length 3359, reward mean -16.7742 49 | 2017-08-04 23:22:17,316 : Time 01h 05m 46s, episode reward 1.0, episode length 4219, reward mean -16.2188 50 | 2017-08-04 23:25:17,142 : Time 01h 08m 46s, episode reward -6.0, episode length 3691, reward mean -15.9091 51 | 2017-08-04 23:27:59,966 : Time 01h 11m 29s, episode reward -10.0, episode length 3150, reward mean -15.7353 52 | 2017-08-04 23:30:52,462 : Time 01h 14m 21s, episode reward -7.0, episode length 3435, reward mean -15.4857 53 | 2017-08-04 23:33:18,452 : Time 01h 16m 47s, episode reward -17.0, episode length 2617, reward mean -15.5278 54 | 2017-08-04 23:36:02,643 : Time 01h 19m 31s, episode reward -11.0, episode length 3187, reward mean -15.4054 55 | 2017-08-04 23:38:40,216 : Time 01h 22m 09s, episode reward -10.0, episode length 2985, reward mean -15.2632 56 | 2017-08-04 23:41:22,135 : Time 01h 24m 51s, episode reward -15.0, episode length 3131, reward mean -15.2564 57 | 2017-08-04 23:43:55,587 : Time 01h 27m 24s, episode reward -16.0, episode length 2867, reward mean -15.2750 58 | 2017-08-04 23:46:38,680 : Time 01h 30m 07s, episode reward -11.0, episode length 3156, reward mean -15.1707 59 | 2017-08-04 23:49:23,909 : Time 01h 32m 53s, episode reward -10.0, episode length 3221, reward mean -15.0476 60 | 2017-08-04 23:51:46,405 : Time 01h 35m 15s, episode reward -16.0, 
episode length 2506, reward mean -15.0698 61 | 2017-08-04 23:54:32,115 : Time 01h 38m 01s, episode reward -12.0, episode length 3239, reward mean -15.0000 62 | 2017-08-04 23:57:15,573 : Time 01h 40m 44s, episode reward -12.0, episode length 3152, reward mean -14.9333 63 | 2017-08-05 00:00:22,136 : Time 01h 43m 51s, episode reward -5.0, episode length 3881, reward mean -14.7174 64 | 2017-08-05 00:03:18,217 : Time 01h 46m 47s, episode reward -7.0, episode length 3554, reward mean -14.5532 65 | 2017-08-05 00:06:22,353 : Time 01h 49m 51s, episode reward -7.0, episode length 3780, reward mean -14.3958 66 | 2017-08-05 00:09:14,818 : Time 01h 52m 44s, episode reward -10.0, episode length 3427, reward mean -14.3061 67 | 2017-08-05 00:12:10,796 : Time 01h 55m 40s, episode reward -10.0, episode length 3561, reward mean -14.2200 68 | 2017-08-05 00:14:48,435 : Time 01h 58m 17s, episode reward -15.0, episode length 3012, reward mean -14.2353 69 | 2017-08-05 00:17:26,465 : Time 02h 00m 55s, episode reward -13.0, episode length 2985, reward mean -14.2115 70 | 2017-08-05 00:20:28,180 : Time 02h 03m 57s, episode reward -11.0, episode length 3729, reward mean -14.1509 71 | 2017-08-05 00:23:12,182 : Time 02h 06m 41s, episode reward -15.0, episode length 3195, reward mean -14.1667 72 | 2017-08-05 00:26:01,467 : Time 02h 09m 30s, episode reward -13.0, episode length 3352, reward mean -14.1455 73 | 2017-08-05 00:28:48,939 : Time 02h 12m 18s, episode reward -12.0, episode length 3276, reward mean -14.1071 74 | 2017-08-05 00:31:32,845 : Time 02h 15m 02s, episode reward -13.0, episode length 3178, reward mean -14.0877 75 | 2017-08-05 00:34:23,542 : Time 02h 17m 52s, episode reward -11.0, episode length 3408, reward mean -14.0345 76 | 2017-08-05 00:37:05,093 : Time 02h 20m 34s, episode reward -10.0, episode length 3094, reward mean -13.9661 77 | 2017-08-05 00:40:11,115 : Time 02h 23m 40s, episode reward -11.0, episode length 3848, reward mean -13.9167 78 | 2017-08-05 00:43:21,023 : Time 02h 26m 50s, episode reward -10.0, episode length 3963, reward mean -13.8525 79 | 2017-08-05 00:46:20,667 : Time 02h 29m 49s, episode reward -11.0, episode length 3660, reward mean -13.8065 80 | 2017-08-05 00:48:56,210 : Time 02h 32m 25s, episode reward -13.0, episode length 2928, reward mean -13.7937 81 | 2017-08-05 00:51:33,898 : Time 02h 35m 03s, episode reward -14.0, episode length 3011, reward mean -13.7969 82 | 2017-08-05 00:54:09,268 : Time 02h 37m 38s, episode reward -15.0, episode length 2925, reward mean -13.8154 83 | 2017-08-05 00:56:45,414 : Time 02h 40m 14s, episode reward -17.0, episode length 2930, reward mean -13.8636 84 | 2017-08-05 00:59:39,650 : Time 02h 43m 08s, episode reward -12.0, episode length 3482, reward mean -13.8358 85 | 2017-08-05 01:02:30,749 : Time 02h 46m 00s, episode reward -15.0, episode length 3398, reward mean -13.8529 86 | 2017-08-05 01:05:16,680 : Time 02h 48m 45s, episode reward -14.0, episode length 3264, reward mean -13.8551 87 | 2017-08-05 01:06:50,189 : Time 02h 50m 19s, episode reward -21.0, episode length 1019, reward mean -13.9571 88 | 2017-08-05 01:08:53,583 : Time 02h 52m 22s, episode reward -20.0, episode length 1920, reward mean -14.0423 89 | 2017-08-05 01:11:58,258 : Time 02h 55m 27s, episode reward -10.0, episode length 3780, reward mean -13.9861 90 | 2017-08-05 01:14:51,324 : Time 02h 58m 20s, episode reward -16.0, episode length 3447, reward mean -14.0137 91 | 2017-08-05 01:19:05,539 : Time 03h 02m 34s, episode reward -4.0, episode length 5929, reward mean -13.8784 92 | 
2017-08-05 01:22:09,102 : Time 03h 05m 38s, episode reward -9.0, episode length 3783, reward mean -13.8133 93 | 2017-08-05 01:25:29,813 : Time 03h 08m 59s, episode reward -10.0, episode length 4328, reward mean -13.7632 94 | 2017-08-05 01:28:42,407 : Time 03h 12m 11s, episode reward -14.0, episode length 4052, reward mean -13.7662 95 | 2017-08-05 01:32:30,236 : Time 03h 15m 59s, episode reward -4.0, episode length 5125, reward mean -13.6410 96 | 2017-08-05 01:36:15,706 : Time 03h 19m 44s, episode reward -11.0, episode length 5067, reward mean -13.6076 97 | 2017-08-05 01:40:42,787 : Time 03h 24m 12s, episode reward -4.0, episode length 6357, reward mean -13.4875 98 | 2017-08-05 01:45:02,533 : Time 03h 28m 31s, episode reward -10.0, episode length 6096, reward mean -13.4444 99 | 2017-08-05 01:49:25,798 : Time 03h 32m 55s, episode reward -8.0, episode length 6225, reward mean -13.3780 100 | 2017-08-05 01:53:34,599 : Time 03h 37m 03s, episode reward -6.0, episode length 5782, reward mean -13.2892 101 | 2017-08-05 01:56:42,718 : Time 03h 40m 11s, episode reward -9.0, episode length 3930, reward mean -13.2381 102 | 2017-08-05 02:00:18,960 : Time 03h 43m 48s, episode reward -7.0, episode length 4764, reward mean -13.1647 103 | 2017-08-05 02:05:09,321 : Time 03h 48m 38s, episode reward -4.0, episode length 7067, reward mean -13.0581 104 | 2017-08-05 02:09:24,925 : Time 03h 52m 54s, episode reward -8.0, episode length 6003, reward mean -13.0000 105 | 2017-08-05 02:12:48,831 : Time 03h 56m 18s, episode reward -11.0, episode length 4399, reward mean -12.9773 106 | 2017-08-05 02:17:10,378 : Time 04h 00m 39s, episode reward -5.0, episode length 6171, reward mean -12.8876 107 | 2017-08-05 02:22:16,744 : Time 04h 05m 46s, episode reward -3.0, episode length 7533, reward mean -12.7778 108 | 2017-08-05 02:25:46,461 : Time 04h 09m 15s, episode reward -11.0, episode length 4572, reward mean -12.7582 109 | 2017-08-05 02:29:52,744 : Time 04h 13m 22s, episode reward -10.0, episode length 5665, reward mean -12.7283 110 | 2017-08-05 02:34:54,984 : Time 04h 18m 24s, episode reward -8.0, episode length 7371, reward mean -12.6774 111 | 2017-08-05 02:39:48,979 : Time 04h 23m 18s, episode reward -2.0, episode length 7128, reward mean -12.5638 112 | 2017-08-05 02:44:32,198 : Time 04h 28m 01s, episode reward -3.0, episode length 6854, reward mean -12.4632 113 | 2017-08-05 02:48:23,915 : Time 04h 31m 53s, episode reward -9.0, episode length 5263, reward mean -12.4271 114 | 2017-08-05 02:52:21,670 : Time 04h 35m 50s, episode reward -5.0, episode length 5458, reward mean -12.3505 115 | 2017-08-05 02:55:22,708 : Time 04h 38m 51s, episode reward -15.0, episode length 3692, reward mean -12.3776 116 | 2017-08-05 02:59:08,839 : Time 04h 42m 38s, episode reward -8.0, episode length 5083, reward mean -12.3333 117 | 2017-08-05 03:02:52,292 : Time 04h 46m 21s, episode reward -8.0, episode length 4972, reward mean -12.2900 118 | 2017-08-05 03:06:31,009 : Time 04h 50m 00s, episode reward -10.0, episode length 4801, reward mean -12.2673 119 | 2017-08-05 03:10:08,417 : Time 04h 53m 37s, episode reward -8.0, episode length 4802, reward mean -12.2255 120 | 2017-08-05 03:13:23,989 : Time 04h 56m 53s, episode reward -9.0, episode length 4113, reward mean -12.1942 121 | 2017-08-05 03:16:17,974 : Time 04h 59m 47s, episode reward -15.0, episode length 3474, reward mean -12.2212 122 | 2017-08-05 03:19:29,056 : Time 05h 02m 58s, episode reward -11.0, episode length 4009, reward mean -12.2095 123 | 2017-08-05 03:23:21,472 : Time 05h 06m 50s, 
episode reward -12.0, episode length 5269, reward mean -12.2075 124 | 2017-08-05 03:26:50,186 : Time 05h 10m 19s, episode reward -9.0, episode length 4525, reward mean -12.1776 125 | 2017-08-05 03:30:27,750 : Time 05h 13m 57s, episode reward -9.0, episode length 4781, reward mean -12.1481 126 | 2017-08-05 03:34:24,580 : Time 05h 17m 53s, episode reward -8.0, episode length 5419, reward mean -12.1101 127 | 2017-08-05 03:38:46,427 : Time 05h 22m 15s, episode reward -6.0, episode length 6165, reward mean -12.0545 128 | 2017-08-05 03:43:26,437 : Time 05h 26m 55s, episode reward 1.0, episode length 6703, reward mean -11.9369 129 | 2017-08-05 03:46:56,593 : Time 05h 30m 25s, episode reward -13.0, episode length 4600, reward mean -11.9464 130 | 2017-08-05 03:51:38,240 : Time 05h 35m 07s, episode reward -5.0, episode length 6779, reward mean -11.8850 131 | 2017-08-05 03:56:04,307 : Time 05h 39m 33s, episode reward -8.0, episode length 6263, reward mean -11.8509 132 | 2017-08-05 04:00:01,409 : Time 05h 43m 30s, episode reward -8.0, episode length 5404, reward mean -11.8174 133 | 2017-08-05 04:04:06,379 : Time 05h 47m 35s, episode reward -10.0, episode length 5613, reward mean -11.8017 134 | 2017-08-05 04:08:28,298 : Time 05h 51m 57s, episode reward -6.0, episode length 6110, reward mean -11.7521 135 | 2017-08-05 04:13:02,877 : Time 05h 56m 32s, episode reward -2.0, episode length 6484, reward mean -11.6695 136 | 2017-08-05 04:16:52,429 : Time 06h 00m 21s, episode reward -11.0, episode length 5175, reward mean -11.6639 137 | 2017-08-05 04:21:02,661 : Time 06h 04m 31s, episode reward -5.0, episode length 5802, reward mean -11.6083 138 | 2017-08-05 04:25:41,400 : Time 06h 09m 10s, episode reward -5.0, episode length 6645, reward mean -11.5537 139 | 2017-08-05 04:30:04,770 : Time 06h 13m 34s, episode reward -6.0, episode length 6168, reward mean -11.5082 140 | 2017-08-05 04:33:28,869 : Time 06h 16m 58s, episode reward -13.0, episode length 4382, reward mean -11.5203 141 | 2017-08-05 04:37:24,693 : Time 06h 20m 53s, episode reward -6.0, episode length 5359, reward mean -11.4758 142 | 2017-08-05 04:41:41,587 : Time 06h 25m 10s, episode reward -5.0, episode length 5952, reward mean -11.4240 143 | 2017-08-05 04:45:34,202 : Time 06h 29m 03s, episode reward -12.0, episode length 5226, reward mean -11.4286 144 | 2017-08-05 04:49:28,531 : Time 06h 32m 57s, episode reward -5.0, episode length 5295, reward mean -11.3780 145 | 2017-08-05 04:53:06,019 : Time 06h 36m 35s, episode reward -7.0, episode length 4762, reward mean -11.3438 146 | 2017-08-05 04:58:00,891 : Time 06h 41m 30s, episode reward 2.0, episode length 7133, reward mean -11.2403 147 | 2017-08-05 05:01:06,666 : Time 06h 44m 35s, episode reward -13.0, episode length 3818, reward mean -11.2538 148 | 2017-08-05 05:04:22,998 : Time 06h 47m 52s, episode reward -8.0, episode length 4143, reward mean -11.2290 149 | 2017-08-05 05:07:40,211 : Time 06h 51m 09s, episode reward -13.0, episode length 4145, reward mean -11.2424 150 | 2017-08-05 05:11:19,746 : Time 06h 54m 49s, episode reward -10.0, episode length 4824, reward mean -11.2331 151 | 2017-08-05 05:15:31,895 : Time 06h 59m 01s, episode reward -8.0, episode length 5825, reward mean -11.2090 152 | 2017-08-05 05:20:06,141 : Time 07h 03m 35s, episode reward 4.0, episode length 6501, reward mean -11.0963 153 | 2017-08-05 05:24:28,465 : Time 07h 07m 57s, episode reward -6.0, episode length 6141, reward mean -11.0588 154 | 2017-08-05 05:28:27,651 : Time 07h 11m 56s, episode reward -12.0, episode length 5445, 
reward mean -11.0657 155 | 2017-08-05 05:32:16,330 : Time 07h 15m 45s, episode reward -13.0, episode length 5085, reward mean -11.0797 156 | 2017-08-05 05:37:33,693 : Time 07h 21m 02s, episode reward 2.0, episode length 7786, reward mean -10.9856 157 | 2017-08-05 05:42:30,586 : Time 07h 25m 59s, episode reward 1.0, episode length 7190, reward mean -10.9000 158 | 2017-08-05 05:47:30,995 : Time 07h 31m 00s, episode reward -6.0, episode length 7245, reward mean -10.8652 159 | 2017-08-05 05:52:53,262 : Time 07h 36m 22s, episode reward 2.0, episode length 7972, reward mean -10.7746 160 | 2017-08-05 05:58:21,840 : Time 07h 41m 51s, episode reward 4.0, episode length 8112, reward mean -10.6713 161 | 2017-08-05 06:03:33,958 : Time 07h 47m 03s, episode reward -3.0, episode length 7626, reward mean -10.6181 162 | 2017-08-05 06:08:42,584 : Time 07h 52m 11s, episode reward -3.0, episode length 7531, reward mean -10.5655 163 | 2017-08-05 06:13:39,658 : Time 07h 57m 08s, episode reward -8.0, episode length 7160, reward mean -10.5479 164 | 2017-08-05 06:17:24,632 : Time 08h 00m 53s, episode reward -12.0, episode length 5000, reward mean -10.5578 165 | 2017-08-05 06:21:57,856 : Time 08h 05m 27s, episode reward -7.0, episode length 6426, reward mean -10.5338 166 | 2017-08-05 06:27:09,943 : Time 08h 10m 39s, episode reward -3.0, episode length 7624, reward mean -10.4832 167 | 2017-08-05 06:32:55,370 : Time 08h 16m 24s, episode reward 2.0, episode length 8635, reward mean -10.4000 168 | 2017-08-05 06:37:49,997 : Time 08h 21m 19s, episode reward -4.0, episode length 7127, reward mean -10.3576 169 | 2017-08-05 06:42:52,330 : Time 08h 26m 21s, episode reward -7.0, episode length 7331, reward mean -10.3355 170 | 2017-08-05 06:48:12,248 : Time 08h 31m 41s, episode reward -3.0, episode length 7883, reward mean -10.2876 171 | 2017-08-05 06:52:36,033 : Time 08h 36m 05s, episode reward -4.0, episode length 6188, reward mean -10.2468 172 | 2017-08-05 06:57:36,452 : Time 08h 41m 05s, episode reward 1.0, episode length 7314, reward mean -10.1742 173 | 2017-08-05 07:02:06,855 : Time 08h 45m 36s, episode reward -9.0, episode length 6380, reward mean -10.1667 174 | 2017-08-05 07:07:46,873 : Time 08h 51m 16s, episode reward -2.0, episode length 8516, reward mean -10.1146 175 | 2017-08-05 07:11:30,092 : Time 08h 54m 59s, episode reward -11.0, episode length 4929, reward mean -10.1203 176 | 2017-08-05 07:17:10,139 : Time 09h 00m 39s, episode reward 7.0, episode length 8417, reward mean -10.0126 177 | 2017-08-05 07:21:22,470 : Time 09h 04m 51s, episode reward 9.0, episode length 5833, reward mean -9.8938 178 | 2017-08-05 07:27:46,974 : Time 09h 11m 16s, episode reward -1.0, episode length 9790, reward mean -9.8385 179 | 2017-08-05 07:33:14,888 : Time 09h 16m 44s, episode reward -5.0, episode length 8072, reward mean -9.8086 180 | 2017-08-05 07:37:42,323 : Time 09h 21m 11s, episode reward -7.0, episode length 6283, reward mean -9.7914 181 | 2017-08-05 07:42:26,771 : Time 09h 25m 56s, episode reward -7.0, episode length 6810, reward mean -9.7744 182 | 2017-08-05 07:47:38,362 : Time 09h 31m 07s, episode reward -12.0, episode length 7598, reward mean -9.7879 183 | 2017-08-05 07:51:46,946 : Time 09h 35m 16s, episode reward -7.0, episode length 5663, reward mean -9.7711 184 | 2017-08-05 07:56:53,748 : Time 09h 40m 23s, episode reward -5.0, episode length 7447, reward mean -9.7425 185 | 2017-08-05 08:01:33,106 : Time 09h 45m 02s, episode reward -4.0, episode length 6599, reward mean -9.7083 186 | 2017-08-05 08:06:48,642 : Time 09h 
50m 17s, episode reward 3.0, episode length 7708, reward mean -9.6331 187 | 2017-08-05 08:11:34,266 : Time 09h 55m 03s, episode reward -9.0, episode length 6809, reward mean -9.6294 188 | 2017-08-05 08:17:05,950 : Time 10h 00m 35s, episode reward -2.0, episode length 8125, reward mean -9.5848 189 | 2017-08-05 08:22:31,641 : Time 10h 06m 00s, episode reward 5.0, episode length 8004, reward mean -9.5000 190 | 2017-08-05 08:27:29,026 : Time 10h 10m 58s, episode reward -2.0, episode length 7176, reward mean -9.4566 191 | 2017-08-05 08:34:00,667 : Time 10h 17m 29s, episode reward 3.0, episode length 10000, reward mean -9.3851 192 | 2017-08-05 08:39:42,507 : Time 10h 23m 11s, episode reward -5.0, episode length 8457, reward mean -9.3600 193 | 2017-08-05 08:45:42,279 : Time 10h 29m 11s, episode reward -3.0, episode length 9039, reward mean -9.3239 194 | 2017-08-05 08:52:15,048 : Time 10h 35m 44s, episode reward -1.0, episode length 10000, reward mean -9.2768 195 | 2017-08-05 08:58:36,432 : Time 10h 42m 05s, episode reward 1.0, episode length 9695, reward mean -9.2191 196 | 2017-08-05 09:04:15,345 : Time 10h 47m 44s, episode reward -4.0, episode length 8385, reward mean -9.1899 197 | 2017-08-05 09:09:10,069 : Time 10h 52m 39s, episode reward 5.0, episode length 7005, reward mean -9.1111 198 | 2017-08-05 09:15:28,119 : Time 10h 58m 57s, episode reward -2.0, episode length 9578, reward mean -9.0718 199 | 2017-08-05 09:20:58,122 : Time 11h 04m 27s, episode reward -5.0, episode length 8133, reward mean -9.0495 200 | 2017-08-05 09:27:28,734 : Time 11h 10m 58s, episode reward 1.0, episode length 10000, reward mean -8.9945 201 | 2017-08-05 09:33:27,430 : Time 11h 16m 56s, episode reward 3.0, episode length 9041, reward mean -8.9293 202 | 2017-08-05 09:38:36,598 : Time 11h 22m 05s, episode reward 7.0, episode length 7560, reward mean -8.8432 203 | -------------------------------------------------------------------------------- /logs/Pong-v0_mon_log: -------------------------------------------------------------------------------- 1 | 2017-08-05 09:43:39,796 : reward sum: -1.0, reward mean: -1.0000 2 | 2017-08-05 09:45:24,366 : reward sum: 11.0, reward mean: 5.0000 3 | 2017-08-05 09:47:47,158 : reward sum: -6.0, reward mean: 1.3333 4 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | import argparse 5 | import torch 6 | from torch.multiprocessing import Process 7 | from environment import atari_env 8 | from utils import read_config 9 | from model import A3Clstm 10 | from train import train 11 | from test import test 12 | from shared_optim import SharedRMSprop, SharedAdam, SharedLrSchedAdam 13 | import time 14 | 15 | parser = argparse.ArgumentParser(description='A3C') 16 | parser.add_argument( 17 | '--lr', 18 | type=float, 19 | default=0.0001, 20 | metavar='LR', 21 | help='learning rate (default: 0.0001)') 22 | parser.add_argument( 23 | '--gamma', 24 | type=float, 25 | default=0.99, 26 | metavar='G', 27 | help='discount factor for rewards (default: 0.99)') 28 | parser.add_argument( 29 | '--tau', 30 | type=float, 31 | default=1.00, 32 | metavar='T', 33 | help='parameter for GAE (default: 1.00)') 34 | parser.add_argument( 35 | '--seed', 36 | type=int, 37 | default=1, 38 | metavar='S', 39 | help='random seed (default: 1)') 40 | parser.add_argument( 41 | '--workers', 42 | type=int, 
43 | default=32, 44 | metavar='W', 45 | help='how many training processes to use (default: 32)') 46 | parser.add_argument( 47 | '--num-steps', 48 | type=int, 49 | default=20, 50 | metavar='NS', 51 | help='number of forward steps in A3C (default: 20)') 52 | parser.add_argument( 53 | '--max-episode-length', 54 | type=int, 55 | default=10000, 56 | metavar='M', 57 | help='maximum length of an episode (default: 10000)') 58 | parser.add_argument( 59 | '--env', 60 | default='Pong-v0', 61 | metavar='ENV', 62 | help='environment to train on (default: Pong-v0)') 63 | parser.add_argument( 64 | '--env-config', 65 | default='config.json', 66 | metavar='EC', 67 | help='environment to crop and resize info (default: config.json)') 68 | parser.add_argument( 69 | '--shared-optimizer', 70 | default=True, 71 | metavar='SO', 72 | help='use an optimizer without shared statistics.') 73 | parser.add_argument( 74 | '--load', 75 | default=True, 76 | metavar='L', 77 | help='load a trained model') 78 | parser.add_argument( 79 | '--save-score-level', 80 | type=int, 81 | default=20, 82 | metavar='SSL', 83 | help='reward score test evaluation must get higher than to save model') 84 | parser.add_argument( 85 | '--optimizer', 86 | default='Adam', 87 | metavar='OPT', 88 | help='shares optimizer choice of Adam or RMSprop') 89 | parser.add_argument( 90 | '--count-lives', 91 | default=False, 92 | metavar='CL', 93 | help='end of life is end of training episode.') 94 | parser.add_argument( 95 | '--load-model-dir', 96 | default='checkpoints/', 97 | metavar='LMD', 98 | help='folder to load trained models from') 99 | parser.add_argument( 100 | '--save-model-dir', 101 | default='checkpoints/', 102 | metavar='SMD', 103 | help='folder to save trained models') 104 | parser.add_argument( 105 | '--log-dir', 106 | default='logs/', 107 | metavar='LG', 108 | help='folder to save logs') 109 | 110 | # Based on 111 | # https://github.com/pytorch/examples/tree/master/mnist_hogwild 112 | # Training settings 113 | # Implemented multiprocessing using locks but was not beneficial. 
Hogwild 114 | # training was far superior 115 | 116 | if __name__ == '__main__': 117 | args = parser.parse_args() 118 | torch.set_default_tensor_type('torch.FloatTensor') 119 | torch.manual_seed(args.seed) 120 | 121 | setup_json = read_config(args.env_config) 122 | env_conf = setup_json["Default"] 123 | for i in setup_json.keys(): 124 | if i in args.env: 125 | env_conf = setup_json[i] 126 | env = atari_env(args.env, env_conf) 127 | 128 | if not os.path.exists(args.load_model_dir): 129 | os.makedirs(args.load_model_dir) 130 | if not os.path.exists(args.log_dir): 131 | os.makedirs(args.log_dir) 132 | 133 | saved_state_path = os.path.join(args.load_model_dir, args.env + '.model') 134 | shared_model = A3Clstm(env.observation_space.shape[0], env.action_space) 135 | if os.path.exists(saved_state_path) and args.load: 136 | saved_state = torch.load(saved_state_path, map_location=lambda storage, loc: storage) 137 | print('Loading previous model from: {}'.format(saved_state_path)) 138 | shared_model.load_state_dict(saved_state) 139 | shared_model.share_memory() 140 | 141 | if args.shared_optimizer: 142 | if args.optimizer == 'RMSprop': 143 | optimizer = SharedRMSprop(shared_model.parameters(), lr=args.lr) 144 | if args.optimizer == 'Adam': 145 | optimizer = SharedAdam(shared_model.parameters(), lr=args.lr) 146 | if args.optimizer == 'LrSchedAdam': 147 | optimizer = SharedLrSchedAdam(shared_model.parameters(), lr=args.lr) 148 | optimizer.share_memory() 149 | else: 150 | optimizer = None 151 | 152 | processes = [] 153 | 154 | p = Process(target=test, args=(args, shared_model, env_conf)) 155 | p.start() 156 | processes.append(p) 157 | time.sleep(0.1) 158 | for rank in range(0, args.workers): 159 | p = Process(target=train, args=(rank, args, shared_model, optimizer, env_conf)) 160 | p.start() 161 | processes.append(p) 162 | time.sleep(0.1) 163 | for p in processes: 164 | time.sleep(0.1) 165 | p.join() 166 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from utils import normalized_columns_initializer, weights_init 8 | 9 | 10 | class A3Clstm(torch.nn.Module): 11 | def __init__(self, num_inputs, action_space): 12 | super(A3Clstm, self).__init__() 13 | self.conv1 = nn.Conv2d(num_inputs, 32, 5, stride=1, padding=2) 14 | self.maxp1 = nn.MaxPool2d(2, 2) 15 | self.conv2 = nn.Conv2d(32, 32, 5, stride=1, padding=1) 16 | self.maxp2 = nn.MaxPool2d(2, 2) 17 | self.conv3 = nn.Conv2d(32, 64, 4, stride=1, padding=1) 18 | self.maxp3 = nn.MaxPool2d(2, 2) 19 | self.conv4 = nn.Conv2d(64, 64, 3, stride=1, padding=1) 20 | self.maxp4 = nn.MaxPool2d(2, 2) 21 | 22 | self.lstm = nn.LSTMCell(1024, 512) 23 | num_outputs = action_space.n 24 | self.critic_linear = nn.Linear(512, 1) 25 | self.actor_linear = nn.Linear(512, num_outputs) 26 | 27 | self.apply(weights_init) 28 | self.actor_linear.reset_parameters() 29 | self.critic_linear.reset_parameters() 30 | 31 | self.lstm.bias_ih.data.fill_(0) 32 | self.lstm.bias_hh.data.fill_(0) 33 | 34 | self.train() 35 | 36 | def forward(self, inputs): 37 | inputs, (hx, cx) = inputs 38 | x = F.relu(self.maxp1(self.conv1(inputs))) 39 | x = F.relu(self.maxp2(self.conv2(x))) 40 | x = F.relu(self.maxp3(self.conv3(x))) 41 | x = F.relu(self.maxp4(self.conv4(x))) 42 | 43 | x = 
x.view(x.size(0), -1) 44 | 45 | hx, cx = self.lstm(x, (hx, cx)) 46 | 47 | x = hx 48 | 49 | return self.critic_linear(x), self.actor_linear(x), (hx, cx) 50 | -------------------------------------------------------------------------------- /player_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class Agent(object): 8 | def __init__(self, model, env, args, state): 9 | self.model = model 10 | self.env = env 11 | self.current_life = 0 12 | self.state = state 13 | self.hx = None 14 | self.cx = None 15 | self.eps_len = 0 16 | self.args = args 17 | self.values = [] 18 | self.log_probs = [] 19 | self.rewards = [] 20 | self.entropies = [] 21 | self.done = True 22 | self.flag = False 23 | self.info = None 24 | self.starter = False or args.env[:8] == 'Breakout' 25 | 26 | 27 | def player_act(player, train): 28 | if train: 29 | value, logit, (player.hx, player.cx) = player.model( 30 | (Variable(player.state.unsqueeze(0)), (player.hx, player.cx))) 31 | else: 32 | value, logit, (player.hx, player.cx) = player.model((Variable( 33 | player.state.unsqueeze(0), volatile=True), (player.hx, player.cx))) 34 | prob = F.softmax(logit) 35 | action = prob.max(1)[1].data.numpy() 36 | state, reward, player.done, player.info = player.env.step(action[0]) 37 | player.state = torch.from_numpy(state).float() 38 | player.eps_len += 1 39 | player.done = player.done or player.eps_len >= player.args.max_episode_length 40 | return player, reward 41 | prob = F.softmax(logit) 42 | log_prob = F.log_softmax(logit) 43 | entropy = -(log_prob * prob).sum(1) 44 | player.entropies.append(entropy) 45 | action = prob.multinomial().data 46 | log_prob = log_prob.gather(1, Variable(action)) 47 | state, reward, player.done, player.info = player.env.step(action.numpy()) 48 | player.state = torch.from_numpy(state).float() 49 | player.eps_len += 1 50 | player.done = player.done or player.eps_len >= player.args.max_episode_length 51 | reward = max(min(reward, 1), -1) 52 | player.values.append(value) 53 | player.log_probs.append(log_prob) 54 | player.rewards.append(reward) 55 | return player 56 | 57 | 58 | def player_start(player, train): 59 | for i in range(3): 60 | player.flag = False 61 | if train: 62 | value, logit, (player.hx, player.cx) = player.model( 63 | (Variable(player.state.unsqueeze(0)), (player.hx, player.cx))) 64 | else: 65 | value, logit, (player.hx, player.cx) = player.model((Variable( 66 | player.state.unsqueeze(0), volatile=True), (player.hx, 67 | player.cx))) 68 | prob = F.softmax(logit) 69 | log_prob = F.log_softmax(logit) 70 | entropy = -(log_prob * prob).sum(1) 71 | player.entropies.append(entropy) 72 | action = prob.multinomial().data 73 | log_prob = log_prob.gather(1, Variable(action)) 74 | state, reward, player.done, player.info = player.env.step(1) 75 | player.state = torch.from_numpy(state).float() 76 | player.eps_len += 1 77 | player.done = player.done or player.eps_len >= player.args.max_episode_length 78 | if train: 79 | reward = max(min(reward, 1), -1) 80 | player.values.append(value) 81 | player.log_probs.append(log_prob) 82 | player.rewards.append(reward) 83 | if player.done: 84 | return player 85 | return player 86 | -------------------------------------------------------------------------------- /run_play.sh: -------------------------------------------------------------------------------- 1 | python3 gym_eval.py --env Pong-v0 
--render True -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | # Support Python3 Only 2 | python3 main.py --env Pong-v0 --workers 4 --save-model-dir 'checkpoints/' -------------------------------------------------------------------------------- /shared_optim.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | import torch 4 | import torch.optim as optim 5 | 6 | 7 | class SharedRMSprop(optim.RMSprop): 8 | """Implements RMSprop algorithm with shared states. 9 | """ 10 | 11 | def __init__(self, 12 | params, 13 | lr=7e-4, 14 | alpha=0.99, 15 | eps=0.1, 16 | weight_decay=0, 17 | momentum=0, 18 | centered=False): 19 | super(SharedRMSprop, self).__init__(params, lr, alpha, eps, 20 | weight_decay, momentum, centered) 21 | 22 | for group in self.param_groups: 23 | for p in group['params']: 24 | state = self.state[p] 25 | state['step'] = torch.zeros(1) 26 | state['grad_avg'] = p.data.new().resize_as_(p.data).zero_() 27 | state['square_avg'] = p.data.new().resize_as_(p.data).zero_() 28 | state['momentum_buffer'] = p.data.new().resize_as_( 29 | p.data).zero_() 30 | 31 | def share_memory(self): 32 | for group in self.param_groups: 33 | for p in group['params']: 34 | state = self.state[p] 35 | state['square_avg'].share_memory_() 36 | state['step'].share_memory_() 37 | state['grad_avg'].share_memory_() 38 | 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | Arguments: 42 | closure (callable, optional): A closure that reevaluates the model 43 | and returns the loss. 44 | """ 45 | loss = None 46 | if closure is not None: 47 | loss = closure() 48 | 49 | for group in self.param_groups: 50 | for p in group['params']: 51 | if p.grad is None: 52 | continue 53 | grad = p.grad.data 54 | state = self.state[p] 55 | 56 | square_avg = state['square_avg'] 57 | alpha = group['alpha'] 58 | 59 | state['step'] += 1 60 | 61 | if group['weight_decay'] != 0: 62 | grad = grad.add(group['weight_decay'], p.data) 63 | 64 | square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) 65 | 66 | if group['centered']: 67 | grad_avg = state['grad_avg'] 68 | grad_avg.mul_(alpha).add_(1 - alpha, grad) 69 | avg = square_avg.addcmul( 70 | -1, grad_avg, grad_avg).sqrt().add_(group['eps']) 71 | else: 72 | avg = square_avg.sqrt().add_(group['eps']) 73 | 74 | if group['momentum'] > 0: 75 | buf = state['momentum_buffer'] 76 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 77 | p.data.add_(-group['lr'], buf) 78 | else: 79 | p.data.addcdiv_(-group['lr'], grad, avg) 80 | 81 | return loss 82 | 83 | 84 | class SharedAdam(optim.Adam): 85 | """Implements Adam algorithm with shared states. 
86 | """ 87 | 88 | def __init__(self, 89 | params, 90 | lr=1e-3, 91 | betas=(0.9, 0.999), 92 | eps=1e-3, 93 | weight_decay=0): 94 | super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay) 95 | 96 | for group in self.param_groups: 97 | for p in group['params']: 98 | state = self.state[p] 99 | state['step'] = torch.zeros(1) 100 | state['exp_avg'] = p.data.new().resize_as_(p.data).zero_() 101 | state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() 102 | 103 | def share_memory(self): 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | state = self.state[p] 107 | state['step'].share_memory_() 108 | state['exp_avg'].share_memory_() 109 | state['exp_avg_sq'].share_memory_() 110 | 111 | def step(self, closure=None): 112 | """Performs a single optimization step. 113 | Arguments: 114 | closure (callable, optional): A closure that reevaluates the model 115 | and returns the loss. 116 | """ 117 | loss = None 118 | if closure is not None: 119 | loss = closure() 120 | 121 | for group in self.param_groups: 122 | for p in group['params']: 123 | if p.grad is None: 124 | continue 125 | grad = p.grad.data 126 | state = self.state[p] 127 | 128 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 129 | beta1, beta2 = group['betas'] 130 | 131 | state['step'] += 1 132 | 133 | if group['weight_decay'] != 0: 134 | grad = grad.add(group['weight_decay'], p.data) 135 | 136 | # Decay the first and second moment running average coefficient 137 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 138 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 139 | 140 | denom = exp_avg_sq.sqrt().add_(group['eps']) 141 | 142 | bias_correction1 = 1 - beta1**state['step'][0] 143 | bias_correction2 = 1 - beta2**state['step'][0] 144 | step_size = group['lr'] * \ 145 | math.sqrt(bias_correction2) / bias_correction1 146 | 147 | p.data.addcdiv_(-step_size, exp_avg, denom) 148 | 149 | return loss 150 | 151 | 152 | sample_lr = [ 153 | 0.0001, 0.00009, 0.00008, 0.00007, 0.00006, 0.00005, 0.00004, 0.00003, 154 | 0.00002, 0.00001, 0.000009, 0.000008, 0.000007, 0.000006, 0.000005, 155 | 0.000004, 0.000003, 0.000002, 0.000001 156 | ] 157 | 158 | 159 | class SharedLrSchedAdam(optim.Adam): 160 | """Implements Adam algorithm with shared states. 161 | """ 162 | 163 | def __init__(self, 164 | params, 165 | lr=1e-3, 166 | betas=(0.9, 0.999), 167 | eps=1e-3, 168 | weight_decay=0): 169 | super(SharedLrSchedAdam, self).__init__(params, lr, betas, eps, 170 | weight_decay) 171 | 172 | for group in self.param_groups: 173 | for p in group['params']: 174 | state = self.state[p] 175 | state['step'] = torch.zeros(1) 176 | state['exp_avg'] = p.data.new().resize_as_(p.data).zero_() 177 | state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() 178 | 179 | def share_memory(self): 180 | for group in self.param_groups: 181 | for p in group['params']: 182 | state = self.state[p] 183 | state['step'].share_memory_() 184 | state['exp_avg'].share_memory_() 185 | state['exp_avg_sq'].share_memory_() 186 | 187 | def step(self, closure=None): 188 | """Performs a single optimization step. 189 | Arguments: 190 | closure (callable, optional): A closure that reevaluates the model 191 | and returns the loss. 
192 | """ 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | lr = sample_lr[int(state['step'][0] // 40000000)] 198 | group['lr'] = lr 199 | 200 | for group in self.param_groups: 201 | for p in group['params']: 202 | if p.grad is None: 203 | continue 204 | grad = p.grad.data 205 | state = self.state[p] 206 | 207 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 208 | beta1, beta2 = group['betas'] 209 | 210 | state['step'] += 1 211 | 212 | if group['weight_decay'] != 0: 213 | grad = grad.add(group['weight_decay'], p.data) 214 | 215 | # Decay the first and second moment running average coefficient 216 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 217 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 218 | 219 | denom = exp_avg_sq.sqrt().add_(group['eps']) 220 | 221 | bias_correction1 = 1 - beta1**state['step'][0] 222 | bias_correction2 = 1 - beta2**state['step'][0] 223 | step_size = group['lr'] * \ 224 | math.sqrt(bias_correction2) / bias_correction1 225 | p.data.addcdiv_(-step_size, exp_avg, denom) 226 | 227 | return loss 228 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from environment import atari_env 4 | from utils import setup_logger 5 | from model import A3Clstm 6 | from player_util import Agent, player_act, player_start 7 | from torch.autograd import Variable 8 | import time 9 | import logging 10 | import os 11 | 12 | 13 | def test(args, shared_model, env_conf): 14 | log = {} 15 | setup_logger('{}_log'.format(args.env), r'{0}{1}_log'.format( 16 | args.log_dir, args.env)) 17 | log['{}_log'.format(args.env)] = logging.getLogger( 18 | '{}_log'.format(args.env)) 19 | d_args = vars(args) 20 | for k in d_args.keys(): 21 | log['{}_log'.format(args.env)].info('{0}: {1}'.format(k, d_args[k])) 22 | 23 | torch.manual_seed(args.seed) 24 | env = atari_env(args.env, env_conf) 25 | model = A3Clstm(env.observation_space.shape[0], env.action_space) 26 | 27 | state = env.reset() 28 | reward_sum = 0 29 | start_time = time.time() 30 | num_tests = 0 31 | reward_total_sum = 0 32 | player = Agent(model, env, args, state) 33 | player.state = torch.from_numpy(state).float() 34 | player.model.eval() 35 | while True: 36 | 37 | if player.done: 38 | player.model.load_state_dict(shared_model.state_dict()) 39 | player.cx = Variable(torch.zeros(1, 512), volatile=True) 40 | player.hx = Variable(torch.zeros(1, 512), volatile=True) 41 | if player.starter: 42 | player = player_start(player, train=False) 43 | else: 44 | player.cx = Variable(player.cx.data, volatile=True) 45 | player.hx = Variable(player.hx.data, volatile=True) 46 | 47 | player, reward = player_act(player, train=False) 48 | reward_sum += reward 49 | 50 | if not player.done: 51 | if player.current_life > player.info['ale.lives']: 52 | player.flag = True 53 | player.current_life = player.info['ale.lives'] 54 | else: 55 | player.current_life = player.info['ale.lives'] 56 | player.flag = False 57 | 58 | if player.starter and player.flag: 59 | player = player_start(player, train=False) 60 | 61 | if player.done: 62 | num_tests += 1 63 | player.current_life = 0 64 | player.flag = False 65 | reward_total_sum += reward_sum 66 | reward_mean = reward_total_sum / num_tests 67 | log['{}_log'.format(args.env)].info( 68 | "Time {0}, episode reward {1}, episode length {2}, reward mean {3:.4f}".format( 69 | time.strftime("%Hh %Mm %Ss", 
time.gmtime(time.time() - start_time)), 70 | reward_sum, player.eps_len, reward_mean)) 71 | 72 | player.model.load_state_dict(shared_model.state_dict()) 73 | 74 | state_to_save = player.model.state_dict() 75 | saved_state_path = os.path.join(args.load_model_dir, args.env + '.model') 76 | print('Model has been saved into {}'.format(saved_state_path)) 77 | torch.save(state_to_save, saved_state_path) 78 | 79 | reward_sum = 0 80 | player.eps_len = 0 81 | state = player.env.reset() 82 | time.sleep(60) 83 | player.state = torch.from_numpy(state).float() 84 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.optim as optim 4 | from environment import atari_env 5 | from utils import ensure_shared_grads 6 | from model import A3Clstm 7 | from player_util import Agent, player_act, player_start 8 | from torch.autograd import Variable 9 | 10 | 11 | def train(rank, args, shared_model, optimizer, env_conf): 12 | torch.manual_seed(args.seed + rank) 13 | 14 | env = atari_env(args.env, env_conf) 15 | model = A3Clstm(env.observation_space.shape[0], env.action_space) 16 | 17 | if optimizer is None: 18 | if args.optimizer == 'RMSprop': 19 | optimizer = optim.RMSprop(shared_model.parameters(), lr=args.lr) 20 | if args.optimizer == 'Adam': 21 | optimizer = optim.Adam(shared_model.parameters(), lr=args.lr) 22 | 23 | env.seed(args.seed + rank) 24 | state = env.reset() 25 | player = Agent(model, env, args, state) 26 | player.state = torch.from_numpy(state).float() 27 | player.model.train() 28 | epoch = 0 29 | while True: 30 | 31 | player.model.load_state_dict(shared_model.state_dict()) 32 | if player.done: 33 | player.cx = Variable(torch.zeros(1, 512)) 34 | player.hx = Variable(torch.zeros(1, 512)) 35 | if player.starter: 36 | player = player_start(player, train=True) 37 | else: 38 | player.cx = Variable(player.cx.data) 39 | player.hx = Variable(player.hx.data) 40 | 41 | for step in range(args.num_steps): 42 | 43 | player = player_act(player, train=True) 44 | 45 | if player.done: 46 | break 47 | 48 | if player.current_life > player.info['ale.lives']: 49 | player.flag = True 50 | player.current_life = player.info['ale.lives'] 51 | else: 52 | player.current_life = player.info['ale.lives'] 53 | player.flag = False 54 | if args.count_lives: 55 | if player.flag: 56 | player.done = True 57 | break 58 | 59 | if player.starter and player.flag: 60 | player = player_start(player, train=True) 61 | if player.done: 62 | break 63 | 64 | if player.done: 65 | player.eps_len = 0 66 | player.current_life = 0 67 | state = player.env.reset() 68 | player.state = torch.from_numpy(state).float() 69 | player.flag = False 70 | 71 | R = torch.zeros(1, 1) 72 | if not player.done: 73 | value, _, _ = player.model( 74 | (Variable(player.state.unsqueeze(0)), (player.hx, player.cx))) 75 | R = value.data 76 | 77 | player.values.append(Variable(R)) 78 | policy_loss = 0 79 | value_loss = 0 80 | R = Variable(R) 81 | gae = torch.zeros(1, 1) 82 | for i in reversed(range(len(player.rewards))): 83 | R = args.gamma * R + player.rewards[i] 84 | advantage = R - player.values[i] 85 | value_loss += 0.5 * advantage.pow(2) 86 | 87 | # Generalized Advantage Estimataion 88 | delta_t = player.rewards[i] + args.gamma * player.values[i + 1].data - player.values[i].data 89 | gae = gae * args.gamma * args.tau + delta_t 90 | 91 | policy_loss = policy_loss - player.log_probs[i] * 
Variable(gae) - 0.01 * player.entropies[i] 92 | 93 | optimizer.zero_grad() 94 | 95 | (policy_loss + value_loss).backward() 96 | 97 | ensure_shared_grads(player.model, shared_model) 98 | optimizer.step() 99 | player.values = [] 100 | player.log_probs = [] 101 | player.rewards = [] 102 | player.entropies = [] 103 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import torch 4 | import json 5 | import logging 6 | 7 | 8 | def setup_logger(logger_name, log_file, level=logging.INFO): 9 | l = logging.getLogger(logger_name) 10 | formatter = logging.Formatter('%(asctime)s : %(message)s') 11 | file_handler = logging.FileHandler(log_file, mode='w') 12 | file_handler.setFormatter(formatter) 13 | stream_handler = logging.StreamHandler() 14 | stream_handler.setFormatter(formatter) 15 | 16 | l.setLevel(level) 17 | l.addHandler(file_handler) 18 | l.addHandler(stream_handler) 19 | 20 | 21 | def read_config(file_path): 22 | """Read JSON config.""" 23 | json_object = json.load(open(file_path, 'r')) 24 | return json_object 25 | 26 | 27 | def normalized_columns_initializer(weights, std=1.0): 28 | out = torch.randn(weights.size()) 29 | out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True)) 30 | return out 31 | 32 | 33 | def ensure_shared_grads(model, shared_model): 34 | for param, shared_param in zip(model.parameters(), 35 | shared_model.parameters()): 36 | if shared_param.grad is not None: 37 | return 38 | shared_param._grad = param.grad 39 | 40 | 41 | def weights_init(m): 42 | class_name = m.__class__.__name__ 43 | if class_name.find('Conv') != -1: 44 | weight_shape = list(m.weight.data.size()) 45 | fan_in = np.prod(weight_shape[1:4]) 46 | fan_out = np.prod(weight_shape[2:4]) * weight_shape[0] 47 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 48 | m.weight.data.uniform_(-w_bound, w_bound) 49 | m.bias.data.fill_(0) 50 | elif class_name.find('Linear') != -1: 51 | weight_shape = list(m.weight.data.size()) 52 | fan_in = weight_shape[1] 53 | fan_out = weight_shape[0] 54 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 55 | m.weight.data.uniform_(-w_bound, w_bound) 56 | m.bias.data.fill_(0) 57 | --------------------------------------------------------------------------------
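For reference, here is a sketch of switching to another supported Atari game end to end. The flags are the ones defined in `main.py` and `gym_eval.py` above, and Breakout has an entry in `config.json`; the env name and values are illustrative, not prescriptive:
```
# train on a different env and save checkpoints/Breakout-v0.model
python3 main.py --env Breakout-v0 --workers 4 --save-model-dir 'checkpoints/'
# then record and evaluate the saved model with the gym monitor
python3 gym_eval.py --env Breakout-v0 --num-episodes 10 --render True --render-freq 5
```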