├── DQN_nature.py ├── DQN_nips.py ├── README.md ├── ckpt ├── readme.txt └── readme~ ├── database.py ├── emulator.py ├── main.py ├── main_multithread.py ├── pretrained ├── nature_pretrained └── nips_pretrained └── roms └── breakout.bin /DQN_nature.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import cv2 4 | 5 | class DQN: 6 | def __init__(self,params,name): 7 | self.network_type = 'nature' 8 | self.params = params 9 | self.network_name = name 10 | self.x = tf.placeholder('float32',[None,84,84,4],name=self.network_name + '_x') 11 | self.q_t = tf.placeholder('float32',[None],name=self.network_name + '_q_t') 12 | self.actions = tf.placeholder("float32", [None, params['num_act']],name=self.network_name + '_actions') 13 | self.rewards = tf.placeholder("float32", [None],name=self.network_name + '_rewards') 14 | self.terminals = tf.placeholder("float32", [None],name=self.network_name + '_terminals') 15 | 16 | #conv1 17 | layer_name = 'conv1' ; size = 8 ; channels = 4 ; filters = 32 ; stride = 4 18 | self.w1 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 19 | self.b1 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases') 20 | self.c1 = tf.nn.conv2d(self.x, self.w1, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs') 21 | self.o1 = tf.nn.relu(tf.add(self.c1,self.b1),name=self.network_name + '_'+layer_name+'_activations') 22 | #self.n1 = tf.nn.lrn(self.o1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) 23 | 24 | #conv2 25 | layer_name = 'conv2' ; size = 4 ; channels = 32 ; filters = 64 ; stride = 2 26 | self.w2 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 27 | self.b2 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases') 28 | self.c2 = tf.nn.conv2d(self.o1, self.w2, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs') 29 | self.o2 = tf.nn.relu(tf.add(self.c2,self.b2),name=self.network_name + '_'+layer_name+'_activations') 30 | #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) 31 | 32 | #conv3 33 | layer_name = 'conv3' ; size = 3 ; channels = 64 ; filters = 64 ; stride = 1 34 | self.w3 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 35 | self.b3 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases') 36 | self.c3 = tf.nn.conv2d(self.o2, self.w3, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs') 37 | self.o3 = tf.nn.relu(tf.add(self.c3,self.b3),name=self.network_name + '_'+layer_name+'_activations') 38 | #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) 39 | 40 | #flat 41 | o3_shape = self.o3.get_shape().as_list() 42 | 43 | #fc3 44 | layer_name = 'fc4' ; hiddens = 512 ; dim = o3_shape[1]*o3_shape[2]*o3_shape[3] 45 | self.o3_flat = tf.reshape(self.o3, [-1,dim],name=self.network_name + '_'+layer_name+'_input_flat') 46 | self.w4 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 47 | self.b4 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases') 48 | self.ip4 = 
tf.add(tf.matmul(self.o3_flat,self.w4),self.b4,name=self.network_name + '_'+layer_name+'_ips') 49 | self.o4 = tf.nn.relu(self.ip4,name=self.network_name + '_'+layer_name+'_activations') 50 | 51 | #fc4 52 | layer_name = 'fc5' ; hiddens = params['num_act'] ; dim = 512 53 | self.w5 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 54 | self.b5 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases') 55 | self.y = tf.add(tf.matmul(self.o4,self.w5),self.b5,name=self.network_name + '_'+layer_name+'_outputs') 56 | 57 | #Q,Cost,Optimizer 58 | self.discount = tf.constant(self.params['discount']) 59 | self.yj = tf.add(self.rewards, tf.mul(1.0-self.terminals, tf.mul(self.discount, self.q_t))) 60 | self.Qxa = tf.mul(self.y,self.actions) 61 | self.Q_pred = tf.reduce_max(self.Qxa, reduction_indices=1) 62 | #self.yjr = tf.reshape(self.yj,(-1,1)) 63 | #self.yjtile = tf.concat(1,[self.yjr,self.yjr,self.yjr,self.yjr]) 64 | #self.yjax = tf.mul(self.yjtile,self.actions) 65 | 66 | #half = tf.constant(0.5) 67 | self.diff = tf.sub(self.yj, self.Q_pred) 68 | if self.params['clip_delta'] > 0 : 69 | self.quadratic_part = tf.minimum(tf.abs(self.diff), tf.constant(self.params['clip_delta'])) 70 | self.linear_part = tf.sub(tf.abs(self.diff),self.quadratic_part) 71 | self.diff_square = 0.5 * tf.pow(self.quadratic_part,2) + self.params['clip_delta']*self.linear_part 72 | 73 | 74 | else: 75 | self.diff_square = tf.mul(tf.constant(0.5),tf.pow(self.diff, 2)) 76 | 77 | if self.params['batch_accumulator'] == 'sum': 78 | self.cost = tf.reduce_sum(self.diff_square) 79 | else: 80 | self.cost = tf.reduce_mean(self.diff_square) 81 | 82 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 83 | self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step) 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /DQN_nips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import cv2 4 | 5 | class DQN: 6 | def __init__(self,params,name): 7 | self.network_type = 'nips' 8 | self.params = params 9 | self.network_name = name 10 | self.x = tf.placeholder('float32',[None,84,84,4],name=self.network_name + '_x') 11 | self.q_t = tf.placeholder('float32',[None],name=self.network_name + '_q_t') 12 | self.actions = tf.placeholder("float32", [None, params['num_act']],name=self.network_name + '_actions') 13 | self.rewards = tf.placeholder("float32", [None],name=self.network_name + '_rewards') 14 | self.terminals = tf.placeholder("float32", [None],name=self.network_name + '_terminals') 15 | 16 | #conv1 17 | layer_name = 'conv1' ; size = 8 ; channels = 4 ; filters = 16 ; stride = 4 18 | self.w1 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 19 | self.b1 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases') 20 | self.c1 = tf.nn.conv2d(self.x, self.w1, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs') 21 | self.o1 = tf.nn.relu(tf.add(self.c1,self.b1),name=self.network_name + '_'+layer_name+'_activations') 22 | #self.n1 = tf.nn.lrn(self.o1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) 23 | 24 | #conv2 25 | layer_name = 'conv2' ; 
size = 4 ; channels = 16 ; filters = 32 ; stride = 2 26 | self.w2 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 27 | self.b2 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases') 28 | self.c2 = tf.nn.conv2d(self.o1, self.w2, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs') 29 | self.o2 = tf.nn.relu(tf.add(self.c2,self.b2),name=self.network_name + '_'+layer_name+'_activations') 30 | #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) 31 | 32 | #flat 33 | o2_shape = self.o2.get_shape().as_list() 34 | 35 | #fc3 36 | layer_name = 'fc3' ; hiddens = 256 ; dim = o2_shape[1]*o2_shape[2]*o2_shape[3] 37 | self.o2_flat = tf.reshape(self.o2, [-1,dim],name=self.network_name + '_'+layer_name+'_input_flat') 38 | self.w3 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 39 | self.b3 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases') 40 | self.ip3 = tf.add(tf.matmul(self.o2_flat,self.w3),self.b3,name=self.network_name + '_'+layer_name+'_ips') 41 | self.o3 = tf.nn.relu(self.ip3,name=self.network_name + '_'+layer_name+'_activations') 42 | 43 | #fc4 44 | layer_name = 'fc4' ; hiddens = params['num_act'] ; dim = 256 45 | self.w4 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights') 46 | self.b4 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases') 47 | self.y = tf.add(tf.matmul(self.o3,self.w4),self.b4,name=self.network_name + '_'+layer_name+'_outputs') 48 | #dummy for nature network 49 | self.w5 = tf.Variable(tf.constant(1.0)) 50 | self.b5 = tf.Variable(tf.constant(1.0)) 51 | #Q,Cost,Optimizer 52 | self.discount = tf.constant(self.params['discount']) 53 | self.yj = tf.add(self.rewards, tf.mul(1.0-self.terminals, tf.mul(self.discount, self.q_t))) 54 | self.Qxa = tf.mul(self.y,self.actions) 55 | self.Q_pred = tf.reduce_max(self.Qxa, reduction_indices=1) 56 | #self.yjr = tf.reshape(self.yj,(-1,1)) 57 | #self.yjtile = tf.concat(1,[self.yjr,self.yjr,self.yjr,self.yjr]) 58 | #self.yjax = tf.mul(self.yjtile,self.actions) 59 | 60 | #half = tf.constant(0.5) 61 | self.diff = tf.sub(self.yj, self.Q_pred) 62 | 63 | if self.params['clip_delta'] > 0 : 64 | self.quadratic_part = tf.minimum(tf.abs(self.diff), tf.constant(self.params['clip_delta'])) 65 | self.linear_part = tf.abs(self.diff) - self.quadratic_part 66 | self.diff_square = 0.5 * tf.pow(self.quadratic_part,2) + self.params['clip_delta']*self.linear_part 67 | 68 | 69 | else: 70 | self.diff_square = tf.mul(tf.constant(0.5),tf.pow(self.diff, 2)) 71 | 72 | if self.params['batch_accumulator'] == 'sum': 73 | self.cost = tf.reduce_sum(self.diff_square) 74 | else: 75 | self.cost = tf.reduce_mean(self.diff_square) 76 | 77 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 78 | self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step) 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #Deep Q Learning for ATARI using Tensorflow 2 | 3 | (Version 1.0, Last updated :2016.03.17) 4 | 5 | ###1. 
Introduction 6 |  7 | This is a TensorFlow implementation of 'Playing Atari with Deep Reinforcement Learning'. 8 |  9 | It is a renewal of https://github.com/mrkulk/deepQN_tensorflow 10 |  11 | I used mrkulk's emulator interface and replay-memory code, and wrote the networks and the main module myself. 12 |  13 | Training takes about 1~3 days. 14 |  15 | I'm working on a multiprocessing version for faster training. 16 |  17 | You can also check my A3C and batch-A3C implementations: 18 |  19 | A3C : https://github.com/gliese581gg/A3C_tensorflow 20 |  21 | Batch-A3C : https://github.com/gliese581gg/batch-A3C_tensorflow 22 |  23 | ###2. Usage 24 |  25 | python main_multithread.py (args) 26 |  27 | where args : 28 |  29 | -weight (checkpoint file) : test a trained network or continue training from the given checkpoint (default : None) 30 | -network_type (nips or nature) : the nature version is more complex and needs more time to train, but performs better (default : nips) 31 | -visualize (y or n) : whether to show an OpenCV window with the game screen (default : y) 32 | -gpu_fraction (0.0~1.0) : fraction of GPU memory to use; roughly 1~1.5 GB is needed (default : 0.9) 33 | -db_size (integer) : size of the replay memory; a size of 1,000,000 takes about 8 GB (default : 1,000,000) 34 | -only_eval (y or n) : if set to y, run evaluation only, without training (default : n). A concrete example command is given in section 7 below. 35 |  36 | ###3. Testing with pretrained networks 37 |  38 | python main_multithread.py -network_type (nips or nature) -weight pretrained/(nips or nature)_pretrained -only_eval y 39 |  40 | ###4. Requirements: 41 |  42 | - Tensorflow 43 | - opencv2 44 | - Arcade Learning Environment ( https://github.com/mgbellemare/Arcade-Learning-Environment ) 45 |  46 | ###5. Video 47 |  48 | https://www.youtube.com/watch?v=GACcbfUaHwc 49 |  50 | ###6. Changelog 51 |  52 | -2016.03.17 : First upload!
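###7. Example

As an illustration of the arguments listed in section 2, a command like the following should train the nature network without the preview window, using about half of the GPU memory and a smaller replay memory (the values here are only illustrative, not tuned):

python main_multithread.py -network_type nature -visualize n -gpu_fraction 0.5 -db_size 500000

Checkpoint files are written to the ckpt directory during training and can be passed back in with -weight to continue training or to evaluate.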
53 | -------------------------------------------------------------------------------- /ckpt/readme.txt: -------------------------------------------------------------------------------- 1 | checkpoint files will be saved here 2 | -------------------------------------------------------------------------------- /ckpt/readme~: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gliese581gg/DQN_tensorflow/d7552c14a5d81712ecbe6365bb4289994e63a6a3/ckpt/readme~ -------------------------------------------------------------------------------- /database.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gc 3 | import time 4 | import cv2 5 | 6 | class database: 7 | def __init__(self, params): 8 | self.size = params['db_size'] 9 | self.img_scale = params['img_scale'] 10 | self.states = np.zeros([self.size,84,84],dtype='uint8') #image dimensions 11 | self.actions = np.zeros(self.size,dtype='float32') 12 | self.terminals = np.zeros(self.size,dtype='float32') 13 | self.rewards = np.zeros(self.size,dtype='float32') 14 | self.bat_size = params['batch'] 15 | self.bat_s = np.zeros([self.bat_size,84,84,4]) 16 | self.bat_a = np.zeros([self.bat_size]) 17 | self.bat_t = np.zeros([self.bat_size]) 18 | self.bat_n = np.zeros([self.bat_size,84,84,4]) 19 | self.bat_r = np.zeros([self.bat_size]) 20 | 21 | self.counter = 0 #keep track of next empty state 22 | self.flag = False 23 | return 24 | 25 | def get_batches(self): 26 | for i in range(self.bat_size): 27 | idx = 0 28 | while idx < 3 or (idx > self.counter-2 and idx < self.counter+3): 29 | idx = np.random.randint(3,self.get_size()-1) 30 | self.bat_s[i] = np.transpose(self.states[idx-3:idx+1,:,:],(1,2,0))/self.img_scale 31 | self.bat_n[i] = np.transpose(self.states[idx-2:idx+2,:,:],(1,2,0))/self.img_scale 32 | self.bat_a[i] = self.actions[idx] 33 | self.bat_t[i] = self.terminals[idx] 34 | self.bat_r[i] = self.rewards[idx] 35 | #self.bat_s[0] = np.transpose(self.states[10:14,:,:],(1,2,0))/self.img_scale 36 | #self.bat_n[0] = np.transpose(self.states[11:15,:,:],(1,2,0))/self.img_scale 37 | #self.bat_a[0] = self.actions[13] 38 | #self.bat_t[0] = self.terminals[13] 39 | #self.bat_r[0] = self.rewards[13] 40 | 41 | return self.bat_s,self.bat_a,self.bat_t,self.bat_n,self.bat_r 42 | 43 | def insert(self, prevstate_proc,reward,action,terminal): 44 | self.states[self.counter] = prevstate_proc 45 | self.rewards[self.counter] = reward 46 | self.actions[self.counter] = action 47 | self.terminals[self.counter] = terminal 48 | #update counter 49 | self.counter += 1 50 | if self.counter >= self.size: 51 | self.flag = True 52 | self.counter = 0 53 | return 54 | 55 | def get_size(self): 56 | if self.flag == False: 57 | return self.counter 58 | else: 59 | return self.size 60 | 61 | -------------------------------------------------------------------------------- /emulator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import sys 4 | from ale_python_interface import ALEInterface 5 | import cv2 6 | import time 7 | #import scipy.misc 8 | 9 | class emulator: 10 | def __init__(self, rom_name, vis,windowname='preview'): 11 | self.ale = ALEInterface() 12 | self.max_frames_per_episode = self.ale.getInt("max_num_frames_per_episode"); 13 | self.ale.setInt("random_seed",123) 14 | self.ale.setInt("frame_skip",4) 15 | self.ale.loadROM('roms/' + rom_name ) 16 | self.legal_actions = 
self.ale.getMinimalActionSet() 17 | self.action_map = dict() 18 | self.windowname = windowname 19 | for i in range(len(self.legal_actions)): 20 | self.action_map[self.legal_actions[i]] = i 21 | 22 | # print(self.legal_actions) 23 | self.screen_width,self.screen_height = self.ale.getScreenDims() 24 | print("width/height: " +str(self.screen_width) + "/" + str(self.screen_height)) 25 | self.vis = vis 26 | if vis: 27 | cv2.startWindowThread() 28 | cv2.namedWindow(self.windowname) 29 | 30 | def get_image(self): 31 | numpy_surface = np.zeros(self.screen_height*self.screen_width*3, dtype=np.uint8) 32 | self.ale.getScreenRGB(numpy_surface) 33 | image = np.reshape(numpy_surface, (self.screen_height, self.screen_width, 3)) 34 | return image 35 | 36 | def newGame(self): 37 | self.ale.reset_game() 38 | return self.get_image() 39 | 40 | def next(self, action_indx): 41 | reward = self.ale.act(action_indx) 42 | nextstate = self.get_image() 43 | # scipy.misc.imsave('test.png',nextstate) 44 | if self.vis: 45 | cv2.imshow(self.windowname,nextstate) 46 | return nextstate, reward, self.ale.game_over() 47 | 48 | 49 | 50 | if __name__ == "__main__": 51 | engine = emulator('breakout.bin',True) 52 | engine.next(0) 53 | time.sleep(5) 54 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from database import * 2 | from emulator import * 3 | import tensorflow as tf 4 | import numpy as np 5 | import time 6 | from ale_python_interface import ALEInterface 7 | import cv2 8 | from scipy import misc 9 | import gc #garbage colloector 10 | import thread 11 | 12 | gc.enable() 13 | 14 | params = { 15 | 'visualize' : True, 16 | 'network_type':'nips', 17 | 'ckpt_file':None, 18 | 'steps_per_epoch': 50000, 19 | 'num_epochs': 100, 20 | 'eval_freq':50000, 21 | 'steps_per_eval':10000, 22 | 'copy_freq' : 10000, 23 | 'disp_freq':10000, 24 | 'save_interval':10000, 25 | 'db_size': 1000000, 26 | 'batch': 32, 27 | 'num_act': 0, 28 | 'input_dims' : [210, 160, 3], 29 | 'input_dims_proc' : [84, 84, 4], 30 | 'learning_interval': 1, 31 | 'eps': 1.0, 32 | 'eps_step':1000000, 33 | 'eps_min' : 0.1, 34 | 'eps_eval' : 0.05, 35 | 'discount': 0.95, 36 | 'lr': 0.0002, 37 | 'rms_decay':0.99, 38 | 'rms_eps':1e-6, 39 | 'train_start':100, 40 | 'img_scale':255.0, 41 | 'clip_delta' : 0, #nature : 1 42 | 'gpu_fraction' : 0.25, 43 | 'batch_accumulator':'mean', 44 | 'record_eval' : True, 45 | 'only_eval' : 'n' 46 | } 47 | 48 | class deep_atari: 49 | def __init__(self,params): 50 | print 'Initializing Module...' 51 | self.params = params 52 | 53 | self.gpu_config = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.params['gpu_fraction'])) 54 | 55 | self.sess = tf.Session(config=self.gpu_config) 56 | self.DB = database(self.params) 57 | self.engine = emulator(rom_name='breakout.bin', vis=self.params['visualize'],windowname=self.params['network_type']+'_preview') 58 | self.params['num_act'] = len(self.engine.legal_actions) 59 | self.build_net() 60 | self.training = True 61 | 62 | def build_net(self): 63 | print 'Building QNet and targetnet...' 
64 | self.qnet = DQN(self.params,'qnet') 65 | self.targetnet = DQN(self.params,'targetnet') 66 | self.sess.run(tf.initialize_all_variables()) 67 | saver_dict = {'qw1':self.qnet.w1,'qb1':self.qnet.b1, 68 | 'qw2':self.qnet.w2,'qb2':self.qnet.b2, 69 | 'qw3':self.qnet.w3,'qb3':self.qnet.b3, 70 | 'qw4':self.qnet.w4,'qb4':self.qnet.b4, 71 | 'qw5':self.qnet.w5,'qb5':self.qnet.b5, 72 | 'tw1':self.targetnet.w1,'tb1':self.targetnet.b1, 73 | 'tw2':self.targetnet.w2,'tb2':self.targetnet.b2, 74 | 'tw3':self.targetnet.w3,'tb3':self.targetnet.b3, 75 | 'tw4':self.targetnet.w4,'tb4':self.targetnet.b4, 76 | 'tw5':self.targetnet.w5,'tb5':self.targetnet.b5, 77 | 'step':self.qnet.global_step} 78 | self.saver = tf.train.Saver(saver_dict) 79 | #self.saver = tf.train.Saver() 80 | self.cp_ops = [ 81 | self.targetnet.w1.assign(self.qnet.w1),self.targetnet.b1.assign(self.qnet.b1), 82 | self.targetnet.w2.assign(self.qnet.w2),self.targetnet.b2.assign(self.qnet.b2), 83 | self.targetnet.w3.assign(self.qnet.w3),self.targetnet.b3.assign(self.qnet.b3), 84 | self.targetnet.w4.assign(self.qnet.w4),self.targetnet.b4.assign(self.qnet.b4), 85 | self.targetnet.w5.assign(self.qnet.w5),self.targetnet.b5.assign(self.qnet.b5)] 86 | 87 | self.sess.run(self.cp_ops) 88 | 89 | if self.params['ckpt_file'] is not None: 90 | print 'loading checkpoint : ' + self.params['ckpt_file'] 91 | self.saver.restore(self.sess,self.params['ckpt_file']) 92 | temp_train_cnt = self.sess.run(self.qnet.global_step) 93 | temp_step = temp_train_cnt * self.params['learning_interval'] 94 | print 'Continue from' 95 | print ' -> Steps : ' + str(temp_step) 96 | print ' -> Minibatch update : ' + str(temp_train_cnt) 97 | 98 | 99 | def start(self): 100 | self.reset_game() 101 | self.step = 0 102 | self.reset_statistics('all') 103 | self.train_cnt = self.sess.run(self.qnet.global_step) 104 | 105 | if self.train_cnt > 0 : 106 | self.step = self.train_cnt * self.params['learning_interval'] 107 | try: 108 | self.log_train = open('log_training_'+self.params['network_type']+'.csv','a') 109 | except: 110 | self.log_train = open('log_training_'+self.params['network_type']+'.csv','w') 111 | self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 112 | 113 | try: 114 | self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','a') 115 | except: 116 | self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w') 117 | self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 118 | else: 119 | self.log_train = open('log_training_'+self.params['network_type']+'.csv','w') 120 | self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 121 | self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w') 122 | self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 123 | 124 | self.s = time.time() 125 | print self.params 126 | print 'Start training!' 
127 | print 'Collecting replay memory for ' + str(self.params['train_start']) + ' steps' 128 | 129 | while self.step < (self.params['steps_per_epoch'] * self.params['num_epochs'] * self.params['learning_interval'] + self.params['train_start']): 130 | if self.training : 131 | if self.DB.get_size() >= self.params['train_start'] : self.step += 1 ; self.steps_train += 1 132 | else : self.step_eval += 1 133 | if self.state_gray_old is not None and self.training: 134 | self.DB.insert(self.state_gray_old[26:110,:],self.reward_scaled,self.action_idx,self.terminal) 135 | 136 | if self.training and self.params['copy_freq'] > 0 and self.step % self.params['copy_freq'] == 0 and self.DB.get_size() > self.params['train_start']: 137 | print '&&& Copying Qnet to targetnet\n' 138 | self.sess.run(self.cp_ops) 139 | 140 | 141 | if self.training and self.step % self.params['learning_interval'] == 0 and self.DB.get_size() > self.params['train_start'] : 142 | bat_s,bat_a,bat_t,bat_n,bat_r = self.DB.get_batches() 143 | bat_a = self.get_onehot(bat_a) 144 | 145 | if self.params['copy_freq'] > 0 : 146 | feed_dict={self.targetnet.x: bat_n} 147 | q_t = self.sess.run(self.targetnet.y,feed_dict=feed_dict) 148 | else: 149 | feed_dict={self.qnet.x: bat_n} 150 | q_t = self.sess.run(self.qnet.y,feed_dict=feed_dict) 151 | 152 | q_t = np.amax(q_t,axis=1) 153 | 154 | feed_dict={self.qnet.x: bat_s, self.qnet.q_t: q_t, self.qnet.actions: bat_a, self.qnet.terminals:bat_t, self.qnet.rewards: bat_r} 155 | 156 | 157 | _,self.train_cnt,self.cost = self.sess.run([self.qnet.rmsprop,self.qnet.global_step,self.qnet.cost],feed_dict=feed_dict) 158 | 159 | 160 | self.total_cost_train += np.sqrt(self.cost) 161 | self.train_cnt_for_disp += 1 162 | 163 | if self.training : 164 | self.params['eps'] = max(self.params['eps_min'],1.0 - float(self.train_cnt * self.params['learning_interval'])/float(self.params['eps_step'])) 165 | else: 166 | self.params['eps'] = 0.05 167 | 168 | if self.DB.get_size() > self.params['train_start'] and self.step % self.params['save_interval'] == 0 and self.training: 169 | save_idx = self.train_cnt 170 | self.saver.save(self.sess,'ckpt/model_'+self.params['network_type']+'_'+str(save_idx)) 171 | sys.stdout.write('$$$ Model saved : %s\n\n' % ('ckpt/model_'+self.params['network_type']+'_'+str(save_idx))) 172 | sys.stdout.flush() 173 | 174 | if self.training and self.step > 0 and self.step % self.params['disp_freq'] == 0 and self.DB.get_size() > self.params['train_start'] : 175 | self.write_log_train() 176 | 177 | if self.training and self.step > 0 and self.step % self.params['eval_freq'] == 0 and self.DB.get_size() > self.params['train_start'] : 178 | 179 | self.reset_game() 180 | if self.step % self.params['steps_per_epoch'] == 0 : self.reset_statistics('all') 181 | else: self.reset_statistics('eval') 182 | self.training = False 183 | #TODO : add video recording 184 | continue 185 | if self.training and self.step > 0 and self.step % self.params['steps_per_epoch'] == 0 and self.DB.get_size() > self.params['train_start']: 186 | self.reset_game() 187 | self.reset_statistics('all') 188 | #self.training = False 189 | continue 190 | 191 | if not self.training and self.step_eval >= self.params['steps_per_eval'] : 192 | self.write_log_eval() 193 | self.reset_game() 194 | self.reset_statistics('eval') 195 | self.training = True 196 | continue 197 | 198 | 199 | if self.terminal : 200 | self.reset_game() 201 | if self.training : 202 | self.num_epi_train += 1 203 | self.total_reward_train += self.epi_reward_train 204 | 
self.epi_reward_train = 0 205 | else : 206 | self.num_epi_eval += 1 207 | self.total_reward_eval += self.epi_reward_eval 208 | self.epi_reward_eval = 0 209 | continue 210 | 211 | self.action_idx,self.action, self.maxQ = self.select_action(self.state_proc) 212 | self.state, self.reward, self.terminal = self.engine.next(self.action) 213 | self.reward_scaled = self.reward // max(1,abs(self.reward)) 214 | if self.training : self.epi_reward_train += self.reward ; self.total_Q_train += self.maxQ 215 | else : self.epi_reward_eval += self.reward ; self.total_Q_eval += self.maxQ 216 | 217 | self.state_gray_old = np.copy(self.state_gray) 218 | self.state_proc[:,:,0:3] = self.state_proc[:,:,1:4] 219 | self.state_resized = cv2.resize(self.state,(84,110)) 220 | self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY) 221 | self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale'] 222 | 223 | #TODO : add video recording 224 | 225 | def reset_game(self): 226 | self.state_proc = np.zeros((84,84,4)); self.action = -1; self.terminal = False; self.reward = 0 227 | self.state = self.engine.newGame() 228 | self.state_resized = cv2.resize(self.state,(84,110)) 229 | self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY) 230 | self.state_gray_old = None 231 | self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale'] 232 | 233 | def reset_statistics(self,mode): 234 | if mode == 'all': 235 | self.epi_reward_train = 0 236 | self.epi_Q_train = 0 237 | self.num_epi_train = 0 238 | self.total_reward_train = 0 239 | self.total_Q_train = 0 240 | self.total_cost_train = 0 241 | self.steps_train = 0 242 | self.train_cnt_for_disp = 0 243 | self.step_eval = 0 244 | self.epi_reward_eval = 0 245 | self.epi_Q_eval = 0 246 | self.num_epi_eval = 0 247 | self.total_reward_eval = 0 248 | self.total_Q_eval = 0 249 | 250 | 251 | def write_log_train(self): 252 | sys.stdout.write('### Training (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] )) 253 | 254 | sys.stdout.write(' Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f, Avg.loss : %.3f\n' % (self.num_epi_train,float(self.total_reward_train)/max(1,self.num_epi_train),float(self.total_Q_train)/max(1,self.steps_train),self.total_cost_train/max(1,self.train_cnt_for_disp))) 255 | sys.stdout.write(' Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s)) 256 | sys.stdout.flush() 257 | self.log_train.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',') 258 | self.log_train.write(str(float(self.total_reward_train)/max(1,self.num_epi_train)) +','+ str(float(self.total_Q_train)/max(1,self.steps_train)) +',') 259 | self.log_train.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n') 260 | self.log_train.flush() 261 | 262 | def write_log_eval(self): 263 | sys.stdout.write('@@@ Evaluation (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] )) 264 | sys.stdout.write(' Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f\n' % (self.num_epi_eval,float(self.total_reward_eval)/max(1,self.num_epi_eval),float(self.total_Q_eval)/max(1,self.params['steps_per_eval']))) 265 | sys.stdout.write(' Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s)) 266 | sys.stdout.flush() 267 | self.log_eval.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + 
',' + str(self.train_cnt) + ',') 268 | self.log_eval.write(str(float(self.total_reward_eval)/max(1,self.num_epi_eval)) +','+ str(float(self.total_Q_eval)/max(1,self.params['steps_per_eval'])) +',') 269 | self.log_eval.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n') 270 | self.log_eval.flush() 271 | 272 | def select_action(self,st): 273 | if np.random.rand() > self.params['eps']: 274 | #greedy with random tie-breaking 275 | Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0] 276 | a_winner = np.argwhere(Q_pred == np.amax(Q_pred)) 277 | if len(a_winner) > 1: 278 | act_idx = a_winner[np.random.randint(0, len(a_winner))][0] 279 | return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred) 280 | else: 281 | act_idx = a_winner[0][0] 282 | return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred) 283 | else: 284 | #random 285 | act_idx = np.random.randint(0,len(self.engine.legal_actions)) 286 | Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0] 287 | return act_idx,self.engine.legal_actions[act_idx], Q_pred[act_idx] 288 | 289 | def get_onehot(self,actions): 290 | actions_onehot = np.zeros((self.params['batch'], self.params['num_act'])) 291 | 292 | for i in range(self.params['batch']): 293 | actions_onehot[i,actions[i]] = 1 294 | return actions_onehot 295 | 296 | 297 | if __name__ == "__main__": 298 | dict_items = params.items() 299 | for i in range(1,len(sys.argv),2): 300 | if sys.argv[i] == '-weight' :params['ckpt_file'] = sys.argv[i+1] 301 | elif sys.argv[i] == '-network_type' :params['network_type'] = sys.argv[i+1] 302 | elif sys.argv[i] == '-visualize' : 303 | if sys.argv[i+1] == 'y' : params['visualize'] = True 304 | elif sys.argv[i+1] == 'n' : params['visualize'] = False 305 | else: 306 | print 'Invalid visualization argument!!! Available arguments are' 307 | print ' y or n' 308 | raise ValueError() 309 | elif sys.argv[i] == '-gpu_fraction' : params['gpu_fraction'] = float(sys.argv[i+1]) 310 | elif sys.argv[i] == '-db_size' : params['db_size'] = int(sys.argv[i+1]) 311 | elif sys.argv[i] == '-only_eval' : params['only_eval'] = sys.argv[i+1] 312 | else : 313 | print 'Invalid arguments!!! Available arguments are' 314 | print ' -weight (filename)' 315 | print ' -network_type (nips or nature)' 316 | print ' -visualize (y or n)' 317 | print ' -gpu_fraction (0.1~0.9)' 318 | print ' -db_size (integer)' 319 | raise ValueError() 320 | if params['network_type'] == 'nips': 321 | from DQN_nips import * 322 | elif params['network_type'] == 'nature': 323 | from DQN_nature import * 324 | params['steps_per_epoch']= 200000 325 | params['eval_freq'] = 100000 326 | params['steps_per_eval'] = 10000 327 | params['copy_freq'] = 10000 328 | params['disp_freq'] = 20000 329 | params['save_interval'] = 20000 330 | params['learning_interval'] = 1 331 | params['discount'] = 0.99 332 | params['lr'] = 0.00025 333 | params['rms_decay'] = 0.95 334 | params['rms_eps']=0.01 335 | params['clip_delta'] = 1.0 336 | params['train_start']=50000 337 | params['batch_accumulator'] = 'sum' 338 | params['eps_step'] = 1000000 339 | params['num_epochs'] = 250 340 | params['batch'] = 32 341 | else : 342 | print 'Invalid network type! Available network types are' 343 | print ' nips or nature' 344 | raise ValueError() 345 | 346 | if params['only_eval'] == 'y' : only_eval = True 347 | elif params['only_eval'] == 'n' : only_eval = False 348 | else : 349 | print 'Invalid only_eval option! 
Available options are' 350 | print ' y or n' 351 | raise ValueError() 352 | 353 | if only_eval: 354 | params['eval_freq'] = 1 355 | params['train_start'] = 100 356 | 357 | da = deep_atari(params) 358 | da.start() 359 | -------------------------------------------------------------------------------- /main_multithread.py: -------------------------------------------------------------------------------- 1 | from database import * 2 | from emulator import * 3 | import tensorflow as tf 4 | import numpy as np 5 | import time 6 | from ale_python_interface import ALEInterface 7 | import cv2 8 | from scipy import misc 9 | import gc #garbage colloector 10 | import thread 11 | 12 | gc.enable() 13 | 14 | params = { 15 | 'visualize' : True, 16 | 'network_type':'nips', 17 | 'ckpt_file':None, 18 | 'steps_per_epoch': 50000, 19 | 'num_epochs': 250, 20 | 'eval_freq':50000, 21 | 'steps_per_eval':10000, 22 | 'copy_freq' : 10000, 23 | 'disp_freq':10000, 24 | 'save_interval':10000, 25 | 'db_size': 1000000, 26 | 'batch': 32, 27 | 'num_act': 0, 28 | 'input_dims' : [210, 160, 3], 29 | 'input_dims_proc' : [84, 84, 4], 30 | 'learning_interval': 1, 31 | 'eps': 1.0, 32 | 'eps_step':1000000, 33 | 'eps_min' : 0.1, 34 | 'eps_eval' : 0.00, 35 | 'discount': 0.95, 36 | 'lr': 0.0002, 37 | 'rms_decay':0.99, 38 | 'rms_eps':1e-6, 39 | 'train_start':100, 40 | 'img_scale':255.0, 41 | 'clip_delta' : 0, #nature : 1 42 | 'gpu_fraction' : 0.9, 43 | 'batch_accumulator':'mean', 44 | #'num_threads' : 4, 45 | 'record_eval' : True, 46 | 'only_eval' : 'n' 47 | } 48 | 49 | class deep_atari: 50 | def __init__(self,params): 51 | print 'Initializing Module...' 52 | self.params = params 53 | 54 | self.gpu_config = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.params['gpu_fraction'])) 55 | 56 | self.sess = tf.Session(config=self.gpu_config) 57 | self.DB = database(self.params) 58 | self.engine = emulator(rom_name='breakout.bin', vis=self.params['visualize'],windowname=self.params['network_type']+'_preview') 59 | self.params['num_act'] = len(self.engine.legal_actions) 60 | self.build_net() 61 | self.training = True 62 | self.lock = thread.allocate_lock() 63 | 64 | def build_net(self): 65 | print 'Building QNet and targetnet...' 
66 | self.qnet = DQN(self.params,'qnet') 67 | self.targetnet = DQN(self.params,'targetnet') 68 | self.sess.run(tf.initialize_all_variables()) 69 | saver_dict = {'qw1':self.qnet.w1,'qb1':self.qnet.b1, 70 | 'qw2':self.qnet.w2,'qb2':self.qnet.b2, 71 | 'qw3':self.qnet.w3,'qb3':self.qnet.b3, 72 | 'qw4':self.qnet.w4,'qb4':self.qnet.b4, 73 | 'qw5':self.qnet.w5,'qb5':self.qnet.b5, 74 | 'tw1':self.targetnet.w1,'tb1':self.targetnet.b1, 75 | 'tw2':self.targetnet.w2,'tb2':self.targetnet.b2, 76 | 'tw3':self.targetnet.w3,'tb3':self.targetnet.b3, 77 | 'tw4':self.targetnet.w4,'tb4':self.targetnet.b4, 78 | 'tw5':self.targetnet.w5,'tb5':self.targetnet.b5, 79 | 'step':self.qnet.global_step} 80 | self.saver = tf.train.Saver(saver_dict) 81 | #self.saver = tf.train.Saver() 82 | self.cp_ops = [ 83 | self.targetnet.w1.assign(self.qnet.w1),self.targetnet.b1.assign(self.qnet.b1), 84 | self.targetnet.w2.assign(self.qnet.w2),self.targetnet.b2.assign(self.qnet.b2), 85 | self.targetnet.w3.assign(self.qnet.w3),self.targetnet.b3.assign(self.qnet.b3), 86 | self.targetnet.w4.assign(self.qnet.w4),self.targetnet.b4.assign(self.qnet.b4), 87 | self.targetnet.w5.assign(self.qnet.w5),self.targetnet.b5.assign(self.qnet.b5)] 88 | 89 | self.sess.run(self.cp_ops) 90 | 91 | if self.params['ckpt_file'] is not None: 92 | print 'loading checkpoint : ' + self.params['ckpt_file'] 93 | self.saver.restore(self.sess,self.params['ckpt_file']) 94 | temp_train_cnt = self.sess.run(self.qnet.global_step) 95 | temp_step = temp_train_cnt * self.params['learning_interval'] 96 | print 'Continue from' 97 | print ' -> Steps : ' + str(temp_step) 98 | print ' -> Minibatch update : ' + str(temp_train_cnt) 99 | 100 | def do_training(self,th_idx): 101 | #print 'Training thread ' + str(th_idx) + ' initiated' 102 | print 'Training thread initiated' 103 | while True: 104 | if self.training and self.step % self.params['learning_interval'] == 0 and self.DB.get_size() > self.params['train_start'] : 105 | bat_s,bat_a,bat_t,bat_n,bat_r = self.DB.get_batches() 106 | bat_a = self.get_onehot(bat_a) 107 | 108 | if self.params['copy_freq'] > 0 : 109 | feed_dict={self.targetnet.x: bat_n} 110 | self.lock.acquire() 111 | q_t = self.sess.run(self.targetnet.y,feed_dict=feed_dict) 112 | self.lock.release() 113 | else: 114 | feed_dict={self.qnet.x: bat_n} 115 | self.lock.acquire() 116 | q_t = self.sess.run(self.qnet.y,feed_dict=feed_dict) 117 | self.lock.release() 118 | q_t = np.amax(q_t,axis=1) 119 | #print str(th_idx) + '_qt_old : ' 120 | #print q_t 121 | 122 | feed_dict={self.qnet.x: bat_s, self.qnet.q_t: q_t, self.qnet.actions: bat_a, self.qnet.terminals:bat_t, self.qnet.rewards: bat_r} 123 | #print str(th_idx) + '_old : ' 124 | #print self.sess.run(self.qnet.b4) 125 | self.lock.acquire() 126 | _,self.train_cnt,self.cost = self.sess.run([self.qnet.rmsprop,self.qnet.global_step,self.qnet.cost],feed_dict=feed_dict) 127 | #print str(th_idx) + '_new : ' 128 | #print self.sess.run(self.qnet.b4) 129 | self.lock.release() 130 | self.total_cost_train += np.sqrt(self.cost) 131 | self.train_cnt_for_disp += 1 132 | 133 | 134 | 135 | def start(self): 136 | self.reset_game() 137 | self.step = 0 138 | self.reset_statistics('all') 139 | self.train_cnt = self.sess.run(self.qnet.global_step) 140 | 141 | if self.train_cnt > 0 : 142 | self.step = self.train_cnt * self.params['learning_interval'] 143 | try: 144 | self.log_train = open('log_training_'+self.params['network_type']+'.csv','a') 145 | except: 146 | self.log_train = open('log_training_'+self.params['network_type']+'.csv','w') 147 | 
self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 148 | 149 | try: 150 | self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','a') 151 | except: 152 | self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w') 153 | self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 154 | else: 155 | self.log_train = open('log_training_'+self.params['network_type']+'.csv','w') 156 | self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 157 | self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w') 158 | self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n') 159 | 160 | #for ii in range(self.params['num_threads']): 161 | thread.start_new_thread(self.do_training,(0,)) 162 | time.sleep(1.5) 163 | self.s = time.time() 164 | print self.params 165 | print 'Start training!' 166 | print 'Collecting replay memory for ' + str(self.params['train_start']) + ' steps' 167 | 168 | while self.step < (self.params['steps_per_epoch'] * self.params['num_epochs'] * self.params['learning_interval'] + self.params['train_start']): 169 | if self.training : 170 | if self.DB.get_size() >= self.params['train_start'] : self.step += 1 ; self.steps_train += 1 171 | else : self.step_eval += 1 172 | if self.state_gray_old is not None and self.training: 173 | self.DB.insert(self.state_gray_old[26:110,:],self.reward_scaled,self.action_idx,self.terminal) 174 | 175 | if self.training and self.params['copy_freq'] > 0 and self.step % self.params['copy_freq'] == 0 and self.DB.get_size() > self.params['train_start']: 176 | print '&&& Copying Qnet to targetnet\n' 177 | self.lock.acquire() 178 | self.sess.run(self.cp_ops) 179 | self.lock.release() 180 | 181 | if self.training : 182 | self.params['eps'] = max(self.params['eps_min'],1.0 - float(self.step)/float(self.params['eps_step'])) 183 | else: 184 | self.params['eps'] = 0.05 185 | 186 | if self.DB.get_size() > self.params['train_start'] and self.step % self.params['save_interval'] == 0 and self.training: 187 | save_idx = self.train_cnt 188 | self.lock.acquire() 189 | self.saver.save(self.sess,'ckpt/model_'+self.params['network_type']+'_'+str(save_idx)) 190 | self.lock.release() 191 | sys.stdout.write('$$$ Model saved : %s\n\n' % ('ckpt/model_'+self.params['network_type']+'_'+str(save_idx))) 192 | sys.stdout.flush() 193 | 194 | if self.training and self.step > 0 and self.step % self.params['disp_freq'] == 0 and self.DB.get_size() > self.params['train_start'] : 195 | self.write_log_train() 196 | 197 | if self.training and self.step > 0 and self.step % self.params['eval_freq'] == 0 and self.DB.get_size() > self.params['train_start'] : 198 | 199 | self.reset_game() 200 | if self.step % self.params['steps_per_epoch'] == 0 : self.reset_statistics('all') 201 | else: self.reset_statistics('eval') 202 | self.training = False 203 | #TODO : add video recording 204 | continue 205 | if self.training and self.step > 0 and self.step % self.params['steps_per_epoch'] == 0 and self.DB.get_size() > self.params['train_start']: 206 | self.reset_game() 207 | self.reset_statistics('all') 208 | #self.training = False 209 | continue 210 | 211 | if not self.training and self.step_eval >= self.params['steps_per_eval'] : 212 | self.write_log_eval() 213 | self.reset_game() 214 | self.reset_statistics('eval') 215 | self.training = True 216 | continue 217 | 218 | 219 | if self.terminal : 220 | self.reset_game() 221 | if self.training : 222 | self.num_epi_train += 1 223 | 
self.total_reward_train += self.epi_reward_train 224 | self.epi_reward_train = 0 225 | else : 226 | self.num_epi_eval += 1 227 | self.total_reward_eval += self.epi_reward_eval 228 | self.epi_reward_eval = 0 229 | continue 230 | 231 | self.action_idx,self.action, self.maxQ = self.select_action(self.state_proc) 232 | self.state, self.reward, self.terminal = self.engine.next(self.action) 233 | self.reward_scaled = self.reward // max(1,abs(self.reward)) 234 | if self.training : self.epi_reward_train += self.reward ; self.total_Q_train += self.maxQ 235 | else : self.epi_reward_eval += self.reward ; self.total_Q_eval += self.maxQ 236 | 237 | self.state_gray_old = np.copy(self.state_gray) 238 | self.state_proc[:,:,0:3] = self.state_proc[:,:,1:4] 239 | self.state_resized = cv2.resize(self.state,(84,110)) 240 | self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY) 241 | self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale'] 242 | 243 | #TODO : add video recording 244 | if self.params['only_eval'] == 'y': 245 | time.sleep(0.01) 246 | 247 | def reset_game(self): 248 | self.state_proc = np.zeros((84,84,4)); self.action = -1; self.terminal = False; self.reward = 0 249 | self.state = self.engine.newGame() 250 | self.state_resized = cv2.resize(self.state,(84,110)) 251 | self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY) 252 | self.state_gray_old = None 253 | self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale'] 254 | 255 | def reset_statistics(self,mode): 256 | if mode == 'all': 257 | self.epi_reward_train = 0 258 | self.epi_Q_train = 0 259 | self.num_epi_train = 0 260 | self.total_reward_train = 0 261 | self.total_Q_train = 0 262 | self.total_cost_train = 0 263 | self.steps_train = 0 264 | self.train_cnt_for_disp = 0 265 | self.step_eval = 0 266 | self.epi_reward_eval = 0 267 | self.epi_Q_eval = 0 268 | self.num_epi_eval = 0 269 | self.total_reward_eval = 0 270 | self.total_Q_eval = 0 271 | 272 | 273 | def write_log_train(self): 274 | sys.stdout.write('### Training (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] )) 275 | 276 | sys.stdout.write(' Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f, Avg.loss : %.3f\n' % (self.num_epi_train,float(self.total_reward_train)/max(1,self.num_epi_train),float(self.total_Q_train)/max(1,self.steps_train),self.total_cost_train/max(1,self.train_cnt_for_disp))) 277 | sys.stdout.write(' Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s)) 278 | sys.stdout.flush() 279 | self.log_train.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',') 280 | self.log_train.write(str(float(self.total_reward_train)/max(1,self.num_epi_train)) +','+ str(float(self.total_Q_train)/max(1,self.steps_train)) +',') 281 | self.log_train.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n') 282 | self.log_train.flush() 283 | 284 | def write_log_eval(self): 285 | sys.stdout.write('@@@ Evaluation (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] )) 286 | sys.stdout.write(' Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f\n' % (self.num_epi_eval,float(self.total_reward_eval)/max(1,self.num_epi_eval),float(self.total_Q_eval)/max(1,self.params['steps_per_eval']))) 287 | sys.stdout.write(' Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s)) 288 | 
sys.stdout.flush() 289 | self.log_eval.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',') 290 | self.log_eval.write(str(float(self.total_reward_eval)/max(1,self.num_epi_eval)) +','+ str(float(self.total_Q_eval)/max(1,self.params['steps_per_eval'])) +',') 291 | self.log_eval.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n') 292 | self.log_eval.flush() 293 | 294 | def select_action(self,st): 295 | if np.random.rand() > self.params['eps']: 296 | #greedy with random tie-breaking 297 | self.lock.acquire() 298 | Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0] 299 | self.lock.release() 300 | a_winner = np.argwhere(Q_pred == np.amax(Q_pred)) 301 | if len(a_winner) > 1: 302 | act_idx = a_winner[np.random.randint(0, len(a_winner))][0] 303 | return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred) 304 | else: 305 | act_idx = a_winner[0][0] 306 | return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred) 307 | else: 308 | #random 309 | act_idx = np.random.randint(0,len(self.engine.legal_actions)) 310 | self.lock.acquire() 311 | Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0] 312 | self.lock.release() 313 | return act_idx,self.engine.legal_actions[act_idx], Q_pred[act_idx] 314 | 315 | def get_onehot(self,actions): 316 | actions_onehot = np.zeros((self.params['batch'], self.params['num_act'])) 317 | 318 | for i in range(self.params['batch']): 319 | actions_onehot[i,actions[i]] = 1 320 | return actions_onehot 321 | 322 | 323 | if __name__ == "__main__": 324 | dict_items = params.items() 325 | for i in range(1,len(sys.argv),2): 326 | if sys.argv[i] == '-weight' :params['ckpt_file'] = sys.argv[i+1] 327 | elif sys.argv[i] == '-network_type' :params['network_type'] = sys.argv[i+1] 328 | elif sys.argv[i] == '-visualize' : 329 | if sys.argv[i+1] == 'y' : params['visualize'] = True 330 | elif sys.argv[i+1] == 'n' : params['visualize'] = False 331 | else: 332 | print 'Invalid visualization argument!!! Available arguments are' 333 | print ' y or n' 334 | raise ValueError() 335 | elif sys.argv[i] == '-gpu_fraction' : params['gpu_fraction'] = float(sys.argv[i+1]) 336 | elif sys.argv[i] == '-db_size' : params['db_size'] = int(sys.argv[i+1]) 337 | #elif sys.argv[i] == '-num_threads' : params['num_threads'] = int(sys.argv[i+1]) 338 | elif sys.argv[i] == '-only_eval' : params['only_eval'] = sys.argv[i+1] 339 | else : 340 | print 'Invalid arguments!!! 
Available arguments are' 341 | print ' -weight (filename)' 342 | print ' -network_type (nips or nature)' 343 | print ' -visualize (y or n)' 344 | print ' -gpu_fraction (0.1~0.9)' 345 | print ' -db_size (integer)' 346 | #print ' -num_threads (integer)' 347 | raise ValueError() 348 | if params['network_type'] == 'nips': 349 | from DQN_nips import * 350 | elif params['network_type'] == 'nature': 351 | from DQN_nature import * 352 | params['steps_per_epoch']= 200000 353 | params['eval_freq'] = 200000 354 | params['steps_per_eval'] = 10000 355 | params['copy_freq'] = 40000 356 | params['disp_freq'] = 20000 357 | params['save_interval'] = 20000 358 | params['learning_interval'] = 1 359 | params['discount'] = 0.99 360 | params['lr'] = 0.00025 361 | params['rms_decay'] = 0.95 362 | params['rms_eps']=0.01 363 | params['clip_delta'] = 1.0 364 | params['train_start']=10000 365 | params['batch_accumulator'] = 'sum' 366 | params['eps_step'] = 4000000 367 | params['num_epochs'] = 1000 368 | params['batch'] = 32 369 | else : 370 | print 'Invalid network type! Available network types are' 371 | print ' nips or nature' 372 | raise ValueError() 373 | 374 | if params['only_eval'] == 'y' : only_eval = True 375 | elif params['only_eval'] == 'n' : only_eval = False 376 | else : 377 | print 'Invalid only_eval option! Available options are' 378 | print ' y or n' 379 | raise ValueError() 380 | 381 | if only_eval: 382 | params['eval_freq'] = 1 383 | params['train_start'] = 100 384 | 385 | da = deep_atari(params) 386 | da.start() 387 | -------------------------------------------------------------------------------- /pretrained/nature_pretrained: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gliese581gg/DQN_tensorflow/d7552c14a5d81712ecbe6365bb4289994e63a6a3/pretrained/nature_pretrained -------------------------------------------------------------------------------- /pretrained/nips_pretrained: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gliese581gg/DQN_tensorflow/d7552c14a5d81712ecbe6365bb4289994e63a6a3/pretrained/nips_pretrained -------------------------------------------------------------------------------- /roms/breakout.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gliese581gg/DQN_tensorflow/d7552c14a5d81712ecbe6365bb4289994e63a6a3/roms/breakout.bin --------------------------------------------------------------------------------