├── README.md ├── actor.py ├── control.py ├── environment.py ├── img ├── action.png ├── frame.png ├── loss.png ├── maxv.png ├── pika.png ├── polepoint │ └── polepoint.png ├── reward.png ├── reward2.png ├── total_reward.png └── total_reward2.png ├── learner.py ├── model.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Alphachu: Ape-x DQN implementation of Pikachu Volleyball 2 | ### [[Demo]](https://www.youtube.com/watch?v=vSkLegIUD98)[ [Paper]](https://arxiv.org/abs/1803.00933) 3 | Training agents to learn how to play Pikachu Volleyball. The architecture is based on Ape-x DQN from the [paper](https://arxiv.org/abs/1803.00933). The game ships as an exe file, which makes the whole problem much more complicated than other Atari games. I built a Python environment that takes screenshots of the game to provide as states and detects the start and end of each game. I used mss to take screenshots, cv2 to preprocess the images, pynput to press keys, and tensorboardX to record logs. I created a virtual monitor with Xvfb for each actor. To provide different key inputs to each monitor, the architecture had to be multi-process. A learner trains only on the GPU while many actors (assume 10) collect data from the virtual monitors. They communicate through files in the log directory. 4 | 5 | As it sounds, it is complicated. My method may seem primitive, but it was the only way I found to train Pikachu Volleyball. 6 | 7 | ![img](img/pika.png) 8 | 9 | ## Before start 10 | - Tested on Ubuntu and macOS. 11 | - Set the default log-directory and data-directory paths in actor.py and learner.py. 12 | 13 | ## Prerequisites 14 | - Install PyTorch from http://pytorch.org 15 | - Install the requirements (```pip install -r requirements.txt```) 16 | - Install Xvfb (```sudo apt-get install xvfb -y```) 17 | 18 | ## Creating Virtual Monitors with Xvfb 19 | Repeat this 10 times, once per actor and each with its own display number, to create the virtual monitors. 20 | ``` 21 | Xvfb :99 -ac -screen 0 1280x1024x24 > /dev/null & 22 | echo "export DISPLAY=:99" >> ~/.bashrc 23 | ``` 24 | 25 | ## Run learner 26 | Run the learner and copy the model timestamp with its configuration. 27 | ``` 28 | python learner.py --actor-num 10 29 | Learner: Model saved in /home/sungwonlyu/experiment/alphachu/180801225440_256_0.0001_4_84_129_32_1_30000_1500_10/model.pt 30 | ``` 31 | 32 | ## Run actors 33 | Run pika.exe and an actor in each virtual monitor. This also needs to be done 10 times, with varying epsilons. 34 | ``` 35 | DISPLAY=:99 wine pika.exe 36 | DISPLAY=:99 python actor.py --load-model 180801225440_256_0.0001_4_84_129_32_1_30000_1500_10 --epsilon 0.9 --wepsilon 0.9 37 | ``` 38 | 39 | ## Test 40 | To see the performance of the agent, set screen_size in environment.py to the region to capture, place the pika.exe window in that area, and start an actor with the trained model. 41 | ``` 42 | wine pika.exe 43 | python actor.py --load-model 180801225440_256_0.0001_4_84_129_32_1_30000_1500_10 --test 44 | ``` 45 | 46 | ## Result 47 | ### Demo 48 | You can find a demo on [YouTube](https://www.youtube.com/watch?v=vSkLegIUD98). 49 | 50 | ### Graphs 51 | 0.99-smoothed graphs for the first 7 days.
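The smoothing is presumably the TensorBoard scalar-smoothing weight set to 0.99, i.e. an exponential moving average over the raw values logged by the SummaryWriter. A minimal sketch of how such a curve can be reproduced from raw scalars; the `smooth` helper below is illustrative and not part of this repository:
```
# Illustrative only: TensorBoard-style 0.99 exponential moving average
# over raw scalar values (e.g. the per-round rewards logged by actor.py).
def smooth(values, weight=0.99):
    smoothed, last = [], values[0]
    for v in values:
        last = weight * last + (1 - weight) * v  # running EMA
        smoothed.append(last)
    return smoothed
```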
52 | #### Loss 53 | ![img](img/loss.png) 54 | 55 | #### Action 56 | ![img](img/action.png) 57 | 58 | #### Frame 59 | ![img](img/frame.png) 60 | 61 | #### Max Value 62 | ![img](img/maxv.png) 63 | 64 | #### Reward 65 | ![img](img/reward.png) 66 | 67 | #### Total reward 68 | My score - computer score (-15 ~ 15) 69 | ![img](img/total_reward.png) 70 | -------------------------------------------------------------------------------- /actor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from environment import Env 3 | import control as c 4 | from model import DQN 5 | import time 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch.utils.data 10 | import torch.optim as optim 11 | from collections import deque 12 | import random 13 | # import numpy as np 14 | from tensorboardX import SummaryWriter 15 | import os 16 | import gc 17 | 18 | parser = argparse.ArgumentParser(description='parser') 19 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 20 | parser.add_argument('--epochs', type=int, default=1000000, metavar='N', help='number of epochs to train (default: 10)') 21 | parser.add_argument('--simnum', type=int, default=0, metavar='N') 22 | parser.add_argument('--start-epoch', type=int, default=0, metavar='N') 23 | parser.add_argument('--load-model', type=str, default='000000000000', metavar='N', help='load previous model') 24 | parser.add_argument('--test', action='store_true', default=False) 25 | parser.add_argument('--save-data', action='store_true', default=False) 26 | parser.add_argument('--device', type=str, default="cpu", metavar='N') 27 | parser.add_argument('--log-directory', type=str, default='/home/sungwonlyu/experiment/alphachu/', metavar='N', help='log directory') 28 | parser.add_argument('--data-directory', type=str, default='/home/sungwonlyu/data/alphachu/', metavar='N', help='data directory') 29 | # parser.add_argument('--log-directory', type=str, default='/Users/SungwonLyu/experiment/alphachu/', metavar='N', help='log directory') 30 | # parser.add_argument('--data-directory', type=str, default='/Users/SungwonLyu/data/alphachu/', metavar='N', help='data directory') 31 | parser.add_argument('--history_size', type=int, default=4, metavar='N') 32 | parser.add_argument('--width', type=int, default=129, metavar='N') 33 | parser.add_argument('--height', type=int, default=84, metavar='N') 34 | parser.add_argument('--hidden-size', type=int, default=32, metavar='N') 35 | parser.add_argument('--epsilon', type=float, default=0.9, metavar='N') 36 | parser.add_argument('--wepsilon', type=float, default=0.9, metavar='N') 37 | parser.add_argument('--frame-time', type=float, default=0.2, metavar='N') 38 | parser.add_argument('--reward', type=float, default=1, metavar='N') 39 | parser.add_argument('--replay-size', type=int, default=3000, metavar='N') 40 | args = parser.parse_args() 41 | torch.manual_seed(args.seed) 42 | 43 | 44 | class Actor: 45 | def __init__(self): 46 | if args.device != 'cpu': 47 | torch.cuda.set_device(int(args.device)) 48 | self.device = torch.device('cuda:{}'.format(int(args.device))) 49 | else: 50 | self.device = torch.device('cpu') 51 | 52 | self.simnum = args.simnum 53 | self.history_size = args.history_size 54 | self.height = args.height 55 | self.width = args.width 56 | self.hidden_size = args.hidden_size 57 | if args.test: 58 | args.epsilon = 0 59 | args.wepsilon = 0 60 | self.epsilon = args.epsilon 61 | self.log = 
args.log_directory + args.load_model + '/' 62 | self.writer = SummaryWriter(self.log + str(self.simnum) + '/') 63 | 64 | self.dis = 0.99 65 | self.win = False 66 | self.jump = False 67 | self.ground_key_dict = {0: c.stay, 68 | 1: c.left, 69 | 2: c.right, 70 | 3: c.up, 71 | 4: c.left_p, 72 | 5: c.right_p} 73 | self.jump_key_dict = {0: c.stay, 74 | 1: c.left_p, 75 | 2: c.right_p, 76 | 3: c.up_p, 77 | 4: c.p, 78 | 5: c.down_p} 79 | self.key_dict = self.ground_key_dict 80 | self.action_size = len(self.key_dict) 81 | self.replay_memory = deque(maxlen=args.replay_size) 82 | self.priority = deque(maxlen=args.replay_size) 83 | self.mainDQN = DQN(self.history_size, self.hidden_size, self.action_size).to(self.device) 84 | self.start_epoch = self.load_checkpoint() 85 | 86 | def save_checkpoint(self, idx): 87 | checkpoint = {'simnum': self.simnum, 88 | 'epoch': idx + 1} 89 | torch.save(checkpoint, self.log + 'checkpoint{}.pt'.format(self.simnum)) 90 | print('Actor {}: Checkpoint saved in '.format(self.simnum), self.log + 'checkpoint{}.pt'.format(self.simnum)) 91 | 92 | def load_checkpoint(self): 93 | if os.path.isfile(self.log + 'checkpoint{}.pt'.format(self.simnum)): 94 | checkpoint = torch.load(self.log + 'checkpoint{}.pt'.format(self.simnum)) 95 | self.simnum = checkpoint['simnum'] 96 | print("Actor {}: loaded checkpoint ".format(self.simnum), '(epoch {})'.format(checkpoint['epoch']), self.log + 'checkpoint{}.pt'.format(self.simnum)) 97 | return checkpoint['epoch'] 98 | else: 99 | print("Actor {}: no checkpoint found at ".format(self.simnum), self.log + 'checkpoint{}.pt'.format(self.simnum)) 100 | return args.start_epoch 101 | 102 | def save_memory(self): 103 | if os.path.isfile(self.log + 'memory.pt'): 104 | try: 105 | memory = torch.load(self.log + 'memory{}.pt'.format(self.simnum)) 106 | memory['replay_memory'].extend(self.replay_memory) 107 | memory['priority'].extend(self.priority) 108 | torch.save(memory, self.log + 'memory{}.pt'.format(self.simnum)) 109 | self.replay_memory.clear() 110 | self.priority.clear() 111 | except: 112 | time.sleep(10) 113 | memory = torch.load(self.log + 'memory{}.pt'.format(self.simnum)) 114 | memory['replay_memory'].extend(self.replay_memory) 115 | memory['priority'].extend(self.priority) 116 | torch.save(memory, self.log + 'memory{}.pt'.format(self.simnum)) 117 | self.replay_memory.clear() 118 | self.priority.clear() 119 | else: 120 | memory = {'replay_memory': self.replay_memory, 121 | 'priority': self.priority} 122 | torch.save(memory, self.log + 'memory{}.pt'.format(self.simnum)) 123 | self.replay_memory.clear() 124 | self.priority.clear() 125 | 126 | print('Actor {}: Memory saved in '.format(self.simnum), self.log + 'memory{}.pt'.format(self.simnum)) 127 | 128 | def load_model(self): 129 | if os.path.isfile(self.log + 'model.pt'): 130 | if args.device == 'cpu': 131 | model_dict = torch.load(self.log + 'model.pt', map_location=lambda storage, loc: storage) 132 | else: 133 | model_dict = torch.load(self.log + 'model.pt') 134 | self.mainDQN.load_state_dict(model_dict['state_dict']) 135 | print('Actor {}: Model loaded from '.format(self.simnum), self.log + 'model.pt') 136 | 137 | else: 138 | print("Actor {}: no model found at '{}'".format(self.simnum, self.log + 'model.pt')) 139 | 140 | def history_init(self): 141 | history = torch.zeros([1, self.history_size, self.height, self.width]) 142 | return history 143 | 144 | def update_history(self, history, state): 145 | history = torch.cat([state, history[:, :self.history_size - 1]], 1) 146 | return history 147 | 148 
| def select_action(self, history): 149 | self.mainDQN.eval() 150 | history = history.to(self.device) 151 | qval = self.mainDQN(history) 152 | self.maxv, action = torch.max(qval, 1) 153 | sample = random.random() 154 | if not self.win: 155 | self.epsilon = args.epsilon 156 | else: 157 | self.epsilon = args.wepsilon 158 | if sample > self.epsilon: 159 | self.random = False 160 | action = action.item() 161 | else: 162 | self.random = True 163 | action = random.randrange(self.action_size) 164 | return action 165 | 166 | def control(self, jump): 167 | if not jump: 168 | self.key_dict = self.ground_key_dict 169 | elif jump: 170 | self.key_dict = self.jump_key_dict 171 | 172 | def main(self): 173 | c.release() 174 | self.load_model() 175 | env.set_standard() 176 | total_reward = 0 177 | set_end = False 178 | for idx in range(self.start_epoch, args.epochs + 1): 179 | reward = self.round(idx, set_end) 180 | self.writer.add_scalar('reward', reward, idx) 181 | total_reward += reward 182 | set_end = env.restart() 183 | if set_end: 184 | self.writer.add_scalar('total_reward', total_reward, idx) 185 | total_reward = 0 186 | self.win = False 187 | if not args.test: 188 | self.save_memory() 189 | self.load_model() 190 | self.save_checkpoint(idx) 191 | env.restart_set() 192 | self.writer.close() 193 | 194 | def round(self, round_num, set_end): 195 | print("Round {} Start".format(round_num)) 196 | if not set_end: 197 | time.sleep(env.warmup) 198 | else: 199 | time.sleep(env.start_warmup) 200 | history = self.history_init() 201 | action = 0 202 | next_action = 0 203 | frame = 0 204 | reward = 0 205 | estimate = 0 206 | end = False 207 | maxv = torch.zeros(0).to(self.device) 208 | actions = torch.zeros(0).to(self.device) 209 | start_time = time.time() 210 | while not end: 211 | round_time = time.time() - start_time 212 | sleep_time = args.frame_time - (round_time % args.frame_time) 213 | time.sleep(sleep_time) 214 | start_time = time.time() 215 | if round_time + sleep_time > args.frame_time: 216 | raise ValueError('Timing error') 217 | # print(round_time, sleep_time, round_time + sleep_time) 218 | if args.save_data: 219 | save_dir = args.data_directory + str(args.time_stamp) + '-' + str(round_num) + '-' + str(frame) + '-' + str(action) + '.png' 220 | else: 221 | save_dir = None 222 | state = env.preprocess_img(save_dir=save_dir) 223 | next_history = self.update_history(history, state) 224 | end, jump = env.check_end() 225 | if not end: 226 | next_action = self.select_action(next_history) 227 | self.control(jump) 228 | print(self.key_dict[next_action]) 229 | self.key_dict[next_action](args.frame_time) 230 | if not self.random: 231 | maxv = torch.cat([maxv, self.maxv]) 232 | actions = torch.cat([actions, torch.FloatTensor([action]).to(self.device)]) 233 | frame += 1 234 | priority = abs(self.dis * self.maxv.item() - estimate) 235 | estimate = self.maxv.item() 236 | else: 237 | c.release() 238 | if env.win: 239 | reward = args.reward 240 | else: 241 | reward = - args.reward 242 | priority = abs(reward - estimate) 243 | if not args.test: 244 | self.replay_memory.append((history, action, reward, next_history, end)) 245 | self.priority.append(priority) 246 | history = next_history 247 | action = next_action 248 | if frame > 2000: 249 | raise ValueError('Loop bug') 250 | if maxv.size()[0] > 0: 251 | self.writer.add_scalar('maxv', maxv.mean(), round_num) 252 | if actions.size()[0] > 0: 253 | self.writer.add_scalar('action', actions.mean(), round_num) 254 | self.writer.add_scalar('epsilon', self.epsilon, round_num) 
255 | self.writer.add_scalar('frame', frame, round_num) 256 | gc.collect() 257 | if env.win: 258 | print("Round {} Win: reward:{}, frame:{}".format(round_num, reward, frame)) 259 | self.win = True 260 | else: 261 | print("Round {} Lose: reward:{}, frame:{}".format(round_num, reward, frame)) 262 | self.win = False 263 | return reward 264 | 265 | 266 | if __name__ == "__main__": 267 | env = Env(args.height, args.width, args.frame_time) 268 | actor = Actor() 269 | actor.main() 270 | -------------------------------------------------------------------------------- /control.py: -------------------------------------------------------------------------------- 1 | from pynput.keyboard import Key, Controller 2 | import time 3 | 4 | keyboard = Controller() 5 | 6 | 7 | def release(): 8 | keyboard.release(Key.left) 9 | keyboard.release(Key.right) 10 | keyboard.release(Key.up) 11 | keyboard.release(Key.down) 12 | keyboard.release(Key.enter) 13 | 14 | def stay(press_time): 15 | keyboard.release(Key.right) 16 | keyboard.release(Key.left) 17 | 18 | def left(press_time): 19 | keyboard.release(Key.right) 20 | keyboard.press(Key.left) 21 | 22 | def right(press_time): 23 | keyboard.release(Key.left) 24 | keyboard.press(Key.right) 25 | 26 | def up(press_time): 27 | keyboard.press(Key.up) 28 | time.sleep(press_time / 10) 29 | keyboard.release(Key.up) 30 | 31 | def p(press_time): 32 | keyboard.release(Key.left) 33 | keyboard.release(Key.right) 34 | keyboard.press(Key.enter) 35 | time.sleep(press_time / 10) 36 | keyboard.release(Key.enter) 37 | 38 | def left_p(press_time): 39 | keyboard.release(Key.right) 40 | keyboard.press(Key.left) 41 | keyboard.press(Key.enter) 42 | time.sleep(press_time / 10) 43 | keyboard.release(Key.enter) 44 | 45 | def right_p(press_time): 46 | keyboard.release(Key.left) 47 | keyboard.press(Key.right) 48 | keyboard.press(Key.enter) 49 | time.sleep(press_time / 10) 50 | keyboard.release(Key.enter) 51 | 52 | def up_p(press_time): 53 | keyboard.release(Key.left) 54 | keyboard.release(Key.right) 55 | keyboard.press(Key.up) 56 | keyboard.press(Key.enter) 57 | time.sleep(press_time / 10) 58 | keyboard.release(Key.enter) 59 | keyboard.release(Key.up) 60 | 61 | def down_p(press_time): 62 | keyboard.release(Key.left) 63 | keyboard.release(Key.right) 64 | keyboard.press(Key.down) 65 | keyboard.press(Key.enter) 66 | time.sleep(press_time / 10) 67 | keyboard.release(Key.enter) 68 | keyboard.release(Key.down) 69 | 70 | 71 | 72 | 73 | # def release(): 74 | # keyboard.release(Key.left) 75 | # keyboard.release(Key.right) 76 | # keyboard.release(Key.up) 77 | # keyboard.release(Key.down) 78 | # keyboard.release(Key.enter) 79 | 80 | # def stay(press_time): 81 | # keyboard.release(Key.left) 82 | # keyboard.release(Key.right) 83 | 84 | # def left(press_time): 85 | # keyboard.release(Key.right) 86 | # keyboard.press(Key.left) 87 | 88 | # def right(press_time): 89 | # keyboard.release(Key.left) 90 | # keyboard.press(Key.right) 91 | 92 | # def up(press_time): 93 | # keyboard.press(Key.up) 94 | # time.sleep(press_time/2) 95 | # keyboard.release(Key.up) 96 | 97 | # def p(press_time): 98 | # keyboard.press(Key.enter) 99 | # time.sleep(press_time/2) 100 | # keyboard.release(Key.enter) 101 | 102 | # def p_down(press_time): 103 | # release() 104 | # keyboard.press(Key.down) 105 | # keyboard.press(Key.enter) 106 | # time.sleep(press_time/2) 107 | # keyboard.release(Key.down) 108 | # keyboard.release(Key.enter) 109 | -------------------------------------------------------------------------------- /environment.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from mss import mss 4 | import numpy as np 5 | import control as c 6 | # import os 7 | 8 | 9 | class Env: 10 | def __init__(self, height, width, frame_time): 11 | self.height = height 12 | self.width = width 13 | self.frame_time = frame_time 14 | self.observation_space = (283, 430) 15 | self.sct = mss() 16 | self.sct.compression_level = 9 17 | # self.screen_size = (410, 320, 900, 750) # hidden (1280x1024x24 middle) 18 | # self.screen_size = (100, 100, 570, 430) # one monitor 19 | self.screen_size = (50, 50, 500, 400) # one monitor 20 | self.lower_red = np.array([0, 200, 120]) 21 | self.upper_red = np.array([10, 255, 150]) 22 | self.lower_yellow = np.array([20, 120, 100]) 23 | self.upper_yellow = np.array([30, 255, 255]) 24 | # self.lower_red2 = np.array([0, 200, 60]) # white ball 25 | # self.upper_red2 = np.array([30, 255, 255]) # white ball 26 | self.start_warmup = 2.8 27 | self.warmup = 0.9 28 | self.pole = 0 29 | self.ground = 0 30 | self.threshold = 0.95 31 | self.polepoint = cv2.imread('img/polepoint/polepoint.png', 0) 32 | 33 | def preprocess_img(self, resize=True, save_dir=False): 34 | raw = self.sct.grab(self.screen_size) 35 | img = np.array(raw) 36 | img = cv2.resize(img, (450, 350)) 37 | if resize: 38 | img = img[self.ground - 180: self.ground + 100, self.pole - 215:self.pole + 215] 39 | img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 40 | img_red = cv2.inRange(img_hsv, self.lower_red, self.upper_red) 41 | img_yellow = cv2.inRange(img_hsv, self.lower_yellow, self.upper_yellow) 42 | self.mask = img_red + img_yellow 43 | cv2.imshow('img', self.mask) 44 | cv2.waitKey(1) 45 | del img 46 | del img_hsv 47 | if resize: 48 | state = cv2.resize(self.mask, (self.width, self.height)) 49 | if save_dir: 50 | cv2.imwrite(save_dir, state) 51 | state = state / float(255) * float(2) - 1 52 | state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0) 53 | return state 54 | 55 | def get_standard(self, first_=False, set_=False): 56 | if first_: 57 | self.preprocess_img(resize=False) 58 | gameset_match = cv2.matchTemplate(self.mask, self.polepoint, eval('cv2.TM_CCOEFF_NORMED')) 59 | if np.max(gameset_match) > self.threshold: 60 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(gameset_match) 61 | top_left = max_loc 62 | w, h = self.polepoint.shape[::-1] 63 | x, y = (top_left[0] + w // 2, top_left[1] + h // 2) 64 | self.pole = x 65 | self.ground = y 66 | return True 67 | else: 68 | return False 69 | 70 | else: 71 | self.preprocess_img(resize=True) 72 | gameset_match = cv2.matchTemplate(self.mask, self.polepoint, eval('cv2.TM_CCOEFF_NORMED')) 73 | if np.max(gameset_match) > self.threshold: 74 | return True 75 | else: 76 | return False 77 | 78 | def set_standard(self): 79 | ready = False 80 | print("Set standard") 81 | while not ready: 82 | ready = self.get_standard(first_=True) 83 | c.p(self.frame_time) 84 | print("Ready") 85 | 86 | def restart(self): 87 | start1 = False 88 | start2 = False 89 | restart = False 90 | wait = 0 91 | while not start1 or start2 or not restart: 92 | start1 = start2 93 | start2 = self.check_start() 94 | restart = self.get_standard() 95 | wait += 1 96 | if wait > 500: 97 | return True 98 | return False 99 | 100 | def check_start(self): 101 | return np.sum(self.mask) == 0 102 | 103 | def check_end(self): 104 | jump = self.mask[:220, 215:].sum() > 60000 105 | 106 | if np.sum(self.mask[-2:, :]) > 1000: 107 | if np.sum(self.mask[-2:, :215]) > np.sum(self.mask[-2:, 
215:]): 108 | self.win = True 109 | # if save_dir: 110 | # new_dir = save_dir[:-6] + '-1.png' 111 | # os.rename(save_dir, new_dir) 112 | return True, False 113 | else: 114 | self.win = False 115 | # if save_dir: 116 | # new_dir = save_dir[:-6] + '-2.png' 117 | # os.rename(save_dir, new_dir) 118 | return True, False 119 | else: 120 | return False, jump 121 | 122 | def restart_set(self): 123 | start1 = False 124 | start2 = False 125 | restart = False 126 | while not start1 or start2 or not restart: 127 | start1 = start2 128 | start2 = self.check_start() 129 | restart = self.get_standard() 130 | print("Start new set") 131 | c.release() 132 | c.p(self.frame_time) 133 | -------------------------------------------------------------------------------- /img/action.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/action.png -------------------------------------------------------------------------------- /img/frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/frame.png -------------------------------------------------------------------------------- /img/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/loss.png -------------------------------------------------------------------------------- /img/maxv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/maxv.png -------------------------------------------------------------------------------- /img/pika.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/pika.png -------------------------------------------------------------------------------- /img/polepoint/polepoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/polepoint/polepoint.png -------------------------------------------------------------------------------- /img/reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/reward.png -------------------------------------------------------------------------------- /img/reward2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/reward2.png -------------------------------------------------------------------------------- /img/total_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/total_reward.png -------------------------------------------------------------------------------- /img/total_reward2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyusungwon/apex_dqn_pytorch/66e8d396b5ae227c54504797764b76359a5c3728/img/total_reward2.png -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import control as c 3 | from model import DQN 4 | import datetime 5 | import time 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch.utils.data 10 | import torch.optim as optim 11 | import torch.cuda 12 | import torch.backends.cudnn as cudnn 13 | from collections import deque 14 | from tensorboardX import SummaryWriter 15 | import numpy as np 16 | import os 17 | import gc 18 | cudnn.benchmark = True 19 | 20 | parser = argparse.ArgumentParser(description='parser') 21 | parser.add_argument('--batch-size', type=int, default=256, metavar='N', 22 | help='input batch size for training (default: 256)') 23 | parser.add_argument('--lr', type=float, default=1e-4, metavar='N', 24 | help='learning rate (default: 1e-4)') 25 | parser.add_argument('--gpu', type=int, default=0, metavar='N', 26 | help='number of cuda') 27 | parser.add_argument('--no-cuda', action='store_true', default=False, 28 | help='enables CUDA training') 29 | parser.add_argument('--seed', type=int, default=1, metavar='S', 30 | help='random seed (default: 1)') 31 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 32 | help='how many batches to wait before logging training status') 33 | parser.add_argument('--time-stamp', type=str, default=datetime.datetime.now().strftime("%y%m%d%H%M%S"), metavar='N', 34 | help='time of the run(no modify)') 35 | parser.add_argument('--load-model', type=str, default='000000000000', metavar='N', 36 | help='load previous model') 37 | parser.add_argument('--start-epoch', type=int, default=0, metavar='N', 38 | help='start-epoch number') 39 | parser.add_argument('--log-directory', type=str, default='/home/sungwonlyu/experiment/alphachu/', metavar='N', 40 | help='log directory') 41 | parser.add_argument('--history_size', type=int, default=4, metavar='N') 42 | parser.add_argument('--width', type=int, default=129, metavar='N') 43 | parser.add_argument('--height', type=int, default=84, metavar='N') 44 | parser.add_argument('--hidden-size', type=int, default=32, metavar='N') 45 | parser.add_argument('--action-size', type=int, default=6, metavar='N') 46 | parser.add_argument('--reward', type=int, default=1, metavar='N') 47 | parser.add_argument('--replay-size', type=int, default=30000, metavar='N') 48 | parser.add_argument('--update-cycle', type=int, default=1500, metavar='N') 49 | parser.add_argument('--actor-num', type=int, default=10, metavar='N') 50 | args = parser.parse_args() 51 | torch.cuda.set_device(args.gpu) 52 | args.device = torch.device("cuda:{}".format(args.gpu) if not args.no_cuda and torch.cuda.is_available() else "cpu") 53 | torch.manual_seed(args.seed) 54 | 55 | config_list = [args.batch_size, args.lr, args.history_size, 56 | args.height, args.width, args.hidden_size, 57 | args.reward, args.replay_size, 58 | args.update_cycle, args.actor_num] 59 | config = "" 60 | for i in map(str, config_list): 61 | config = config + '_' + i 62 | print("Config:", config) 63 | 64 | 65 | class Learner(): 66 | def __init__(self): 67 | self.device = args.device 68 | self.batch_size = args.batch_size 69 | self.lr = args.lr 70 | self.history_size = 
args.history_size 71 | self.replay_size = args.replay_size 72 | self.width = args.width 73 | self.height = args.height 74 | self.hidden_size = args.hidden_size 75 | self.action_size = args.action_size 76 | self.update_cycle = args.update_cycle 77 | self.log_interval = args.log_interval 78 | self.actor_num = args.actor_num 79 | self.alpha = 0.7 80 | self.beta_init = 0.4 81 | self.beta = self.beta_init 82 | self.beta_increment = 1e-6 83 | self.e = 1e-6 84 | self.dis = 0.99 85 | self.start_epoch = 0 86 | self.mainDQN = DQN(self.history_size, self.hidden_size, self.action_size).to(self.device) 87 | self.targetDQN = DQN(self.history_size, self.hidden_size, self.action_size).to(self.device) 88 | self.update_target_model() 89 | self.optimizer = optim.Adam(self.mainDQN.parameters(), lr=args.lr) 90 | self.replay_memory = deque(maxlen=self.replay_size) 91 | self.priority = deque(maxlen=self.replay_size) 92 | 93 | if args.load_model != '000000000000': 94 | self.log = args.log_directory + args.load_model + '/' 95 | args.time_stamp = args.load_model[:12] 96 | args.start_epoch = self.load_model() 97 | self.log = args.log_directory + args.time_stamp + config + '/' 98 | self.writer = SummaryWriter(self.log) 99 | 100 | def update_target_model(self): 101 | self.targetDQN.load_state_dict(self.mainDQN.state_dict()) 102 | 103 | def save_model(self, train_epoch): 104 | model_dict = {'state_dict': self.mainDQN.state_dict(), 105 | 'optimizer_dict': self.optimizer.state_dict(), 106 | 'train_epoch': train_epoch} 107 | torch.save(model_dict, self.log + 'model.pt') 108 | print('Learner: Model saved in ', self.log + 'model.pt') 109 | 110 | def load_model(self): 111 | if os.path.isfile(self.log + 'model.pt'): 112 | model_dict = torch.load(self.log + 'model.pt') 113 | self.mainDQN.load_state_dict(model_dict['state_dict']) 114 | self.optimizer.load_state_dict(model_dict['optimizer_dict']) 115 | self.update_target_model() 116 | self.start_epoch = model_dict['train_epoch'] 117 | print("Learner: Model loaded from {}(epoch:{})".format(self.log + 'model.pt', str(self.start_epoch))) 118 | else: 119 | raise "=> Learner: no model found at '{}'".format(self.log + 'model.pt') 120 | 121 | def load_memory(self, simnum): 122 | if os.path.isfile(self.log + 'memory{}.pt'.format(simnum)): 123 | try: 124 | memory_dict = torch.load(self.log + 'memory{}.pt'.format(simnum)) 125 | self.replay_memory.extend(memory_dict['replay_memory']) 126 | self.priority.extend(memory_dict['priority']) 127 | print('Memory loaded from ', self.log + 'memory{}.pt'.format(simnum)) 128 | memory_dict['replay_memory'].clear() 129 | memory_dict['priority'].clear() 130 | torch.save(memory_dict, self.log + 'memory{}.pt'.format(simnum)) 131 | except: 132 | time.sleep(10) 133 | memory_dict = torch.load(self.log + 'memory{}.pt'.format(simnum)) 134 | self.replay_memory.extend(memory_dict['replay_memory']) 135 | self.priority.extend(memory_dict['priority']) 136 | print('Memory loaded from ', self.log + 'memory{}.pt'.format(simnum)) 137 | memory_dict['replay_memory'].clear() 138 | memory_dict['priority'].clear() 139 | torch.save(memory_dict, self.log + 'memory{}.pt'.format(simnum)) 140 | else: 141 | print("=> Learner: no memory found at ", self.log + 'memory{}.pt'.format(simnum)) 142 | 143 | def sample(self): 144 | priority = (np.array(self.priority) + self.e) ** self.alpha 145 | weight = (len(priority) * priority) ** -self.beta 146 | # weight = map(lambda x: x ** -self.beta, (len(priority) * priority)) 147 | weight /= weight.max() 148 | self.weight = torch.tensor(weight, 
dtype=torch.float) 149 | priority = torch.tensor(priority, dtype=torch.float) 150 | return torch.utils.data.sampler.WeightedRandomSampler(priority, self.batch_size, replacement=True) 151 | 152 | def main(self): 153 | train_epoch = self.start_epoch 154 | self.save_model(train_epoch) 155 | is_memory = False 156 | while len(self.replay_memory) < self.batch_size * 100: 157 | print("Memory not enough") 158 | for i in range(self.actor_num): 159 | is_memory = os.path.isfile(self.log + '/memory{}.pt'.format(i)) 160 | if is_memory: 161 | self.load_memory(i) 162 | time.sleep(1) 163 | while True: 164 | self.optimizer.zero_grad() 165 | self.mainDQN.train() 166 | self.targetDQN.eval() 167 | x_stack = torch.zeros(0, self.history_size, self.height, self.width).to(self.device) 168 | y_stack = torch.zeros(0, self.action_size).to(self.device) 169 | w = [] 170 | self.beta = min(1, self.beta_init + train_epoch * self.beta_increment) 171 | sample_idx = self.sample() 172 | for idx in sample_idx: 173 | history, action, reward, next_history, end = self.replay_memory[idx] 174 | history = history.to(self.device) 175 | next_history = next_history.to(self.device) 176 | Q = self.mainDQN(history) 177 | if end: 178 | tderror = reward - Q[0, action] 179 | Q[0, action] = reward 180 | else: 181 | qval = self.mainDQN(next_history) 182 | tderror = reward + self.dis * self.targetDQN(next_history)[0, torch.argmax(qval, 1)] - Q[0, action] 183 | Q[0, action] = reward + self.dis * self.targetDQN(next_history)[0, torch.argmax(qval, 1)] 184 | x_stack = torch.cat([x_stack, history.data], 0) 185 | y_stack = torch.cat([y_stack, Q.data], 0) 186 | w.append(self.weight[idx]) 187 | self.priority[idx] = tderror.abs().item() 188 | pred = self.mainDQN(x_stack) 189 | w = torch.tensor(w, dtype=torch.float, device=self.device) 190 | loss = torch.dot(F.smooth_l1_loss(pred, y_stack.detach(), reduce=False).sum(1), w.detach()) 191 | loss.backward() 192 | self.optimizer.step() 193 | loss /= self.batch_size 194 | self.writer.add_scalar('loss', loss.item(), train_epoch) 195 | train_epoch += 1 196 | gc.collect() 197 | if train_epoch % self.log_interval == 0: 198 | print('Train Epoch: {} \tLoss: {}'.format(train_epoch, loss.item())) 199 | self.writer.add_scalar('replay size', len(self.replay_memory), train_epoch) 200 | if (train_epoch // self.log_interval) % args.actor_num == 0: 201 | self.save_model(train_epoch) 202 | self.load_memory((train_epoch // self.log_interval) % args.actor_num) 203 | 204 | if train_epoch % self.update_cycle == 0: 205 | self.update_target_model() 206 | 207 | 208 | if __name__ == "__main__": 209 | learner = Learner() 210 | learner.main() 211 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DQN(nn.Module): 7 | def __init__(self, history_size, hidden_size, action_size): 8 | super(DQN, self).__init__() 9 | self.history_size = history_size 10 | self.hidden_size = hidden_size 11 | self.action_size = action_size 12 | self.conv1 = nn.Conv2d(self.history_size, self.hidden_size, 8, 4) 13 | self.conv2 = nn.Conv2d(self.hidden_size, self.hidden_size * 2, 4, 2) 14 | self.conv3 = nn.Conv2d(self.hidden_size * 2, self.hidden_size * 2, 3, 1) 15 | self.fc = 84 * self.hidden_size * 2 16 | self.vfc1 = nn.Linear(self.fc, 512) 17 | self.vfc2 = nn.Linear(512, 1) 18 | self.afc1 = nn.Linear(self.fc, 512) 19 | self.afc2 = nn.Linear(512, 
self.action_size) 20 | 21 | torch.nn.init.normal_(self.conv1.weight, 0, 0.02) 22 | torch.nn.init.normal_(self.conv2.weight, 0, 0.02) 23 | torch.nn.init.normal_(self.conv3.weight, 0, 0.02) 24 | torch.nn.init.normal_(self.vfc1.weight, 0, 0.02) 25 | torch.nn.init.normal_(self.vfc2.weight, 0, 0.02) 26 | torch.nn.init.normal_(self.afc1.weight, 0, 0.02) 27 | torch.nn.init.normal_(self.afc2.weight, 0, 0.02) 28 | 29 | def forward(self, x): 30 | x = F.relu(self.conv1(x)) 31 | x = F.relu(self.conv2(x)) 32 | x = F.relu(self.conv3(x)) 33 | x = x.view(-1, self.fc) 34 | a = F.relu(self.afc1(x)) 35 | a = self.afc2(a) 36 | av = torch.mean(a, 1, True) 37 | av = av.expand_as(a) 38 | v = F.relu(self.vfc1(x)) 39 | v = self.vfc2(v) 40 | v = v.expand_as(a) 41 | x = a - av + v 42 | return x 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-contrib-python 2 | mss 3 | numpy 4 | tensorboardX 5 | pynput --------------------------------------------------------------------------------
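As a closing sanity check, the dueling DQN defined in model.py is consistent with the default configuration used by actor.py and learner.py (history size 4, 84x129 frames, hidden size 32, 6 actions): the flattened feature size 84 * hidden_size * 2 equals the conv output size for an 84x129 input. A minimal sketch to verify the output shape, illustrative rather than a file in this repository:
```
# Illustrative shape check for the dueling DQN in model.py with the
# default configuration (history 4, hidden 32, 6 actions, 84x129 frames).
import torch
from model import DQN

net = DQN(4, 32, 6)                  # history_size, hidden_size, action_size
dummy = torch.zeros(1, 4, 84, 129)   # (batch, history, height, width)
q = net(dummy)                       # Q = advantage - mean(advantage) + value
print(q.shape)                       # torch.Size([1, 6])
```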