├── .gitignore
├── DQN.py
├── README.md
├── ckpt
│   └── readme.txt
├── database.py
├── emulator.py
├── main.py
└── roms
    └── breakout.bin

/.gitignore:
--------------------------------------------------------------------------------
.idea
.ipynb_checkpoints
results
MNIST*
*.pyc
--------------------------------------------------------------------------------
/DQN.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import cv2

# Note: written against the pre-1.0 TensorFlow API (tf.mul, tf.sub, tf.initialize_all_variables).

class DQN:
    def __init__(self, params):
        self.params = params
        self.network_name = 'qnet'
        self.sess = tf.Session()
        self.x = tf.placeholder('float', [None, 84, 84, 4], name=self.network_name + '_x')
        self.q_t = tf.placeholder('float', [None], name=self.network_name + '_q_t')
        self.actions = tf.placeholder('float', [None, params['num_act']], name=self.network_name + '_actions')
        self.rewards = tf.placeholder('float', [None], name=self.network_name + '_rewards')
        self.terminals = tf.placeholder('float', [None], name=self.network_name + '_terminals')

        # conv1
        layer_name = 'conv1'; size = 8; channels = 4; filters = 16; stride = 4
        self.w1 = tf.Variable(tf.random_normal([size, size, channels, filters], stddev=0.01), name=self.network_name + '_' + layer_name + '_weights')
        self.b1 = tf.Variable(tf.constant(0.1, shape=[filters]), name=self.network_name + '_' + layer_name + '_biases')
        self.c1 = tf.nn.conv2d(self.x, self.w1, strides=[1, stride, stride, 1], padding='SAME', name=self.network_name + '_' + layer_name + '_convs')
        self.o1 = tf.nn.relu(tf.add(self.c1, self.b1), name=self.network_name + '_' + layer_name + '_activations')
        #self.n1 = tf.nn.lrn(self.o1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        # conv2
        layer_name = 'conv2'; size = 4; channels = 16; filters = 32; stride = 2
        self.w2 = tf.Variable(tf.random_normal([size, size, channels, filters], stddev=0.01), name=self.network_name + '_' + layer_name + '_weights')
        self.b2 = tf.Variable(tf.constant(0.1, shape=[filters]), name=self.network_name + '_' + layer_name + '_biases')
        self.c2 = tf.nn.conv2d(self.o1, self.w2, strides=[1, stride, stride, 1], padding='SAME', name=self.network_name + '_' + layer_name + '_convs')
        self.o2 = tf.nn.relu(tf.add(self.c2, self.b2), name=self.network_name + '_' + layer_name + '_activations')
        #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        # flatten
        o2_shape = self.o2.get_shape().as_list()

        # fc3
        layer_name = 'fc3'; hiddens = 256; dim = o2_shape[1] * o2_shape[2] * o2_shape[3]
        self.o2_flat = tf.reshape(self.o2, [-1, dim], name=self.network_name + '_' + layer_name + '_input_flat')
        self.w3 = tf.Variable(tf.random_normal([dim, hiddens], stddev=0.01), name=self.network_name + '_' + layer_name + '_weights')
        self.b3 = tf.Variable(tf.constant(0.1, shape=[hiddens]), name=self.network_name + '_' + layer_name + '_biases')
        self.ip3 = tf.add(tf.matmul(self.o2_flat, self.w3), self.b3, name=self.network_name + '_' + layer_name + '_ips')
        self.o3 = tf.nn.relu(self.ip3, name=self.network_name + '_' + layer_name + '_activations')

        # fc4 (Q-value outputs, one per action)
        layer_name = 'fc4'; hiddens = params['num_act']; dim = 256
        self.w4 = tf.Variable(tf.random_normal([dim, hiddens], stddev=0.01), name=self.network_name + '_' + layer_name + '_weights')
        self.b4 = tf.Variable(tf.constant(0.1, shape=[hiddens]), name=self.network_name + '_' + layer_name + '_biases')
        self.y = tf.add(tf.matmul(self.o3, self.w4), self.b4, name=self.network_name + '_' + layer_name + '_outputs')

        # Q target, cost, optimizer
        self.discount = tf.constant(self.params['discount'])
        self.yj = tf.add(self.rewards, tf.mul(1.0 - self.terminals, tf.mul(self.discount, self.q_t)))
        self.Q_pred = tf.reduce_sum(tf.mul(self.y, self.actions), reduction_indices=1)
        #half = tf.constant(0.5)
        self.cost = tf.reduce_sum(tf.pow(tf.sub(self.yj, self.Q_pred), 2))
        if self.params['ckpt_file'] is not None:
            self.global_step = tf.Variable(int(self.params['ckpt_file'].split('_')[-1]), name='global_step', trainable=False)
        else:
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'], self.params['rms_decay'], 0.0, self.params['rms_eps']).minimize(self.cost, global_step=self.global_step)
        self.saver = tf.train.Saver()
        self.sess.run(tf.initialize_all_variables())
        if self.params['ckpt_file'] is not None:
            print 'loading checkpoint...'
            self.saver.restore(self.sess, self.params['ckpt_file'])

    def train(self, bat_s, bat_a, bat_t, bat_n, bat_r):
        # First pass: run the network on the next-state batch to get max_a' Q(s', a').
        feed_dict = {self.x: bat_n, self.q_t: np.zeros(bat_n.shape[0]), self.actions: bat_a, self.terminals: bat_t, self.rewards: bat_r}
        q_t = self.sess.run(self.y, feed_dict=feed_dict)
        q_t = np.amax(q_t, axis=1)
        # Second pass: minimize sum over the batch of (r + (1 - terminal) * discount * max_a' Q(s', a') - Q(s, a))^2.
        feed_dict = {self.x: bat_s, self.q_t: q_t, self.actions: bat_a, self.terminals: bat_t, self.rewards: bat_r}
        _, cnt, cost = self.sess.run([self.rmsprop, self.global_step, self.cost], feed_dict=feed_dict)
        return cnt, cost

    def save_ckpt(self, filename):
        self.saver.save(self.sess, filename)
--------------------------------------------------------------------------------
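A minimal NumPy sketch (not part of the repository) of the regression target that DQN.train() computes above: the first pass takes the max over the next-state Q-values, and the second pass fits Q(s, a) to r + (1 - terminal) * discount * max_a' Q(s', a'). All numbers below are made-up batch values for illustration.

import numpy as np

# Hypothetical network outputs on a next-state batch, shape (batch, num_act).
q_next = np.array([[0.2, 0.5, 0.1, 0.3],
                   [0.0, 0.1, 0.4, 0.2]])
rewards = np.array([1.0, 0.0])
terminals = np.array([0.0, 1.0])   # 1.0 marks the end of an episode
discount = 0.95

q_t = np.amax(q_next, axis=1)                       # max_a' Q(s', a'), as in train()'s first pass
yj = rewards + (1.0 - terminals) * discount * q_t   # same quantity as self.yj in the graph

# Hypothetical Q(s, a) for the actions actually taken.
q_pred = np.array([0.45, 0.05])
cost = np.sum((yj - q_pred) ** 2)                   # same form as self.cost (summed squared error)
print(cost)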
/README.md:
--------------------------------------------------------------------------------
# Deep Q-Learning for Atari using TensorFlow

Usage: `python main.py [ckpt_file]`

This version is still being tested.

If a `MemoryError` arises, reduce the value of `db_size` in main.py.

If anyone runs an experiment with `db_size = 1000000`, please let me know the results.


Requirements

1. TensorFlow (the code uses the pre-1.0 API, e.g. `tf.mul`, `tf.sub`, `tf.initialize_all_variables`) and Python 2
2. OpenCV (`cv2` Python bindings)
3. Arcade Learning Environment ( https://github.com/mgbellemare/Arcade-Learning-Environment )
--------------------------------------------------------------------------------
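As a rough sanity check on the `MemoryError` note above (an illustration, not from the repository): database.py stores each frame as an 84x84 float64 array, so the replay memory alone needs roughly db_size * 84 * 84 * 8 bytes.

# Illustrative back-of-the-envelope estimate of database.py's frame buffer size.
db_size = 1000000
bytes_per_frame = 84 * 84 * 8              # float64 -> 8 bytes per pixel = 56,448 bytes
total_gb = db_size * bytes_per_frame / 1e9
print(total_gb)                            # ~56 GB at db_size = 1,000,000; ~5.6 GB at 100,000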
/ckpt/readme.txt:
--------------------------------------------------------------------------------
This directory is where the checkpoint files will be saved.
--------------------------------------------------------------------------------
/database.py:
--------------------------------------------------------------------------------
import numpy as np
import gc
import time

class database:
    def __init__(self, size, input_dims):
        # create database with input_dims as the list of input dimensions
        self.size = size
        self.states = np.zeros([self.size, 84, 84], dtype='float')  # preprocessed frame dimensions
        self.actions = np.zeros(self.size, dtype='float')
        self.terminals = np.zeros(self.size, dtype='float')
        #self.nextstates = np.zeros([self.size,input_dims[0],input_dims[1],input_dims[2]],dtype='float')
        self.rewards = np.zeros(self.size, dtype='float')

        self.counter = 0  # index of the next empty slot
        self.batch_counter = 0
        self.rand_idxs = np.arange(3, 300)
        self.flag = False
        return

    def get_four(self, idx):
        # Stack the 4 frames ending at idx as the state, and the 4 frames ending at idx+1 as the next state.
        four_s = np.zeros([84, 84, 4])
        four_n = np.zeros([84, 84, 4])
        for i in range(0, 4):
            four_s[:, :, i] = self.states[idx - 3 + i]
            four_n[:, :, i] = self.states[idx - 2 + i]

        return four_s, self.actions[idx], self.terminals[idx], four_n, self.rewards[idx]

    def get_batches(self, bat_size):
        bat_s = np.zeros([bat_size, 84, 84, 4])
        bat_a = np.zeros([bat_size])
        bat_t = np.zeros([bat_size])
        bat_n = np.zeros([bat_size, 84, 84, 4])
        bat_r = np.zeros([bat_size])
        ss = time.time()
        for i in range(bat_size):
            if self.batch_counter >= len(self.rand_idxs) - bat_size:
                self.rand_idxs = np.arange(3, self.get_size() - 1)
                np.random.shuffle(self.rand_idxs)
                self.batch_counter = 0
            s, a, t, n, r = self.get_four(self.rand_idxs[self.batch_counter])
            bat_s[i] = s; bat_a[i] = a; bat_t[i] = t; bat_n[i] = n; bat_r[i] = r
            self.batch_counter += 1

        e3 = time.time() - ss
        return bat_s, bat_a, bat_t, bat_n, bat_r

    def insert(self, prevstate_proc, reward, action, terminal):
        self.states[self.counter] = prevstate_proc
        #self.nextstates[self.counter] = newstate_proc
        self.rewards[self.counter] = reward
        self.actions[self.counter] = action
        self.terminals[self.counter] = terminal

        # update counter, wrapping around once the buffer is full
        self.counter += 1
        if self.counter >= self.size:
            self.flag = True
            self.counter = 0
        return

    def get_size(self):
        if not self.flag:
            return self.counter
        else:
            return self.size
--------------------------------------------------------------------------------
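A minimal usage sketch of the replay database above (illustrative only: the small size and random frames are made up; the real sizes and preprocessed frames come from main.py):

import numpy as np
from database import database

# Tiny replay memory filled with stand-in frames.
DB = database(size=1000, input_dims=[84, 84, 4])

for step in range(200):
    frame = np.random.rand(84, 84)                      # stand-in for a preprocessed game frame
    DB.insert(frame, reward=0.0, action=0.0, terminal=0.0)

# Sample a training batch: stacked 4-frame states, actions, terminals, next states, rewards.
bat_s, bat_a, bat_t, bat_n, bat_r = DB.get_batches(32)
print(bat_s.shape)                                      # (32, 84, 84, 4)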
/emulator.py:
--------------------------------------------------------------------------------
import numpy as np
import copy
import sys
from ale_python_interface import ALEInterface
import cv2
import time
#import scipy.misc

class emulator:
    def __init__(self, rom_name, vis):
        self.ale = ALEInterface()
        self.max_frames_per_episode = self.ale.getInt("max_num_frames_per_episode")
        self.ale.setInt("random_seed", 123)
        self.ale.setInt("frame_skip", 4)
        self.ale.loadROM('roms/' + rom_name)
        self.legal_actions = self.ale.getMinimalActionSet()
        self.action_map = dict()
        for i in range(len(self.legal_actions)):
            self.action_map[self.legal_actions[i]] = i

        # print(self.legal_actions)
        self.screen_width, self.screen_height = self.ale.getScreenDims()
        print("width/height: " + str(self.screen_width) + "/" + str(self.screen_height))
        self.vis = vis
        if vis:
            cv2.startWindowThread()
            cv2.namedWindow("preview")

    def get_image(self):
        numpy_surface = np.zeros(self.screen_height * self.screen_width * 3, dtype=np.uint8)
        self.ale.getScreenRGB(numpy_surface)
        image = np.reshape(numpy_surface, (self.screen_height, self.screen_width, 3))
        return image

    def newGame(self):
        self.ale.reset_game()
        return self.get_image()

    def next(self, action_indx):
        reward = self.ale.act(action_indx)
        nextstate = self.get_image()
        # scipy.misc.imsave('test.png',nextstate)
        if self.vis:
            cv2.imshow('preview', nextstate)
        return nextstate, reward, self.ale.game_over()


if __name__ == "__main__":
    engine = emulator('breakout.bin', True)
    engine.next(0)
    time.sleep(5)
--------------------------------------------------------------------------------
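The `action_map` built in `emulator.__init__` is the bridge between raw ALE action ids and network output columns: `select_action` in main.py returns ids from `legal_actions`, and `get_onehot` maps them back to column indices via `action_map`. A small sketch of that mapping (the concrete ids below are hypothetical; the real set comes from `getMinimalActionSet()`):

# Illustrative only: how emulator.action_map relates ALE action ids to one-hot columns.
legal_actions = [0, 1, 3, 4]                               # hypothetical minimal action set
action_map = {a: i for i, a in enumerate(legal_actions)}   # same mapping as the loop in emulator.__init__

chosen = legal_actions[2]       # what select_action() returns: an ALE action id (here 3)
onehot_col = action_map[chosen] # what get_onehot() uses: the corresponding column index (here 2)
print(chosen, onehot_col)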
/main.py:
--------------------------------------------------------------------------------
from DQN import *
from database import *
from emulator import *
import tensorflow as tf
import numpy as np
import sys
import time
from ale_python_interface import ALEInterface
import cv2
from scipy import misc
import gc  # garbage collector

gc.enable()

params = {
    'ckpt_file': None,
    'num_episodes': 250000,
    'rms_decay': 0.99,
    'rms_eps': 1e-6,
    'db_size': 1000000,
    'batch': 32,
    'num_act': 0,
    'input_dims': [210, 160, 3],
    'input_dims_proc': [84, 84, 4],
    'episode_max_length': 100000,
    'learning_interval': 1,
    'eps': 1.0,
    'eps_step': 1000000,
    'discount': 0.95,
    'lr': 0.0002,
    'save_interval': 20000,
    'train_start': 100,
    'eval_mode': False
}

class deep_atari:
    def __init__(self, params):
        print 'Initializing Module...'
        self.params = params
        self.sess = tf.Session()
        self.DB = database(self.params['db_size'], self.params['input_dims_proc'])
        self.engine = emulator(rom_name='breakout.bin', vis=True)
        self.params['num_act'] = len(self.engine.legal_actions)
        self.build_nets()
        self.Q_global = 0
        self.cost_disp = 0

    def build_nets(self):
        print 'Building QNet and Targetnet...'
        self.qnet = DQN(self.params)

    def start(self):
        print 'Start training...'
        cnt = self.qnet.sess.run(self.qnet.global_step)
        print 'Global step = ' + str(cnt)
        local_cnt = 0
        s = time.time()
        for numeps in range(self.params['num_episodes']):
            self.Q_global = 0
            state_proc = np.zeros((84, 84, 4)); state_proc_old = None; action = None; terminal = None; delay = 0
            state = self.engine.newGame()
            # preprocess: resize to 84x110, convert to grayscale, crop the 84x84 playing area, scale to [0, 1]
            state_resized = cv2.resize(state, (84, 110))
            state_gray = cv2.cvtColor(state_resized, cv2.COLOR_BGR2GRAY)
            state_proc[:, :, 3] = state_gray[26:110, :] / 255.0
            total_reward_ep = 0
            for maxl in range(self.params['episode_max_length']):
                if state_proc_old is not None:
                    self.DB.insert(state_proc_old[:, :, 3], reward, action, terminal)
                action = self.perceive(state_proc, terminal)
                if action is None:  # TODO - check [terminal condition]
                    break
                if local_cnt > self.params['train_start'] and local_cnt % self.params['learning_interval'] == 0:
                    bat_s, bat_a, bat_t, bat_n, bat_r = self.DB.get_batches(self.params['batch'])
                    bat_a = self.get_onehot(bat_a)
                    cnt, self.cost_disp = self.qnet.train(bat_s, bat_a, bat_t, bat_n, bat_r)
                if local_cnt > self.params['train_start'] and local_cnt % self.params['save_interval'] == 0:
                    self.qnet.save_ckpt('ckpt/model_' + str(cnt))
                    print 'Model saved'

                state_proc_old = np.copy(state_proc)
                state, reward, terminal = self.engine.next(action)  # IMP: newstate contains terminal info
                state_resized = cv2.resize(state, (84, 110))
                state_gray = cv2.cvtColor(state_resized, cv2.COLOR_BGR2GRAY)
                # shift the 4-frame window and append the newest preprocessed frame
                state_proc[:, :, 0:3] = state_proc[:, :, 1:4]
                state_proc[:, :, 3] = state_gray[26:110, :] / 255.0
                total_reward_ep = total_reward_ep + reward
                local_cnt += 1
                #params['eps'] = 0.05
                self.params['eps'] = max(0.1, 1.0 - float(cnt) / float(self.params['eps_step']))
                #self.params['eps'] = 0.00001

            sys.stdout.write("Epi: %d | frame: %d | train_step: %d | time: %f | reward: %f | eps: %f " % (numeps, local_cnt, cnt, time.time() - s, total_reward_ep, self.params['eps']))
            sys.stdout.write("| max_Q: %f\n" % (self.Q_global))
            #sys.stdout.write("%f, %f, %f, %f, %f\n" % (self.t_e[0],self.t_e[1],self.t_e[2],self.t_e[3],self.t_e[4]))
            sys.stdout.flush()

    def select_action(self, state):
        if np.random.rand() > self.params['eps']:
            # greedy with random tie-breaking
            Q_pred = self.qnet.sess.run(self.qnet.y, feed_dict={self.qnet.x: np.reshape(state, (1, 84, 84, 4)), self.qnet.q_t: np.zeros(1), self.qnet.actions: np.zeros((1, self.params['num_act'])), self.qnet.terminals: np.zeros(1), self.qnet.rewards: np.zeros(1)})[0]  # TODO check
            self.Q_global = max(self.Q_global, np.amax(Q_pred))
            a_winner = np.argwhere(Q_pred == np.amax(Q_pred))
            if len(a_winner) > 1:
                return self.engine.legal_actions[a_winner[np.random.randint(0, len(a_winner))][0]]
            else:
                return self.engine.legal_actions[a_winner[0][0]]
        else:
            # random action
            return self.engine.legal_actions[np.random.randint(0, len(self.engine.legal_actions))]

    def perceive(self, newstate, terminal):
        if not terminal:
            action = self.select_action(newstate)
            return action

    def get_onehot(self, actions):
        actions_onehot = np.zeros((self.params['batch'], self.params['num_act']))
        for i in range(len(actions)):
            actions_onehot[i][self.engine.action_map[int(actions[i])]] = 1
        return actions_onehot


if __name__ == "__main__":
    if len(sys.argv) > 1:
        params['ckpt_file'] = sys.argv[1]
    da = deep_atari(params)
    da.start()
--------------------------------------------------------------------------------
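The exploration rate in start() is annealed linearly with the global training step: eps = max(0.1, 1.0 - cnt / eps_step). A quick worked example of that schedule (illustrative step counts only):

# Illustrative only: the linear epsilon-annealing schedule used in deep_atari.start().
eps_step = 1000000
for cnt in [0, 250000, 500000, 1000000, 2000000]:
    eps = max(0.1, 1.0 - float(cnt) / float(eps_step))
    print(cnt, eps)   # decays linearly from 1.0 and is clamped at the 0.1 floor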
/roms/breakout.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrkulk/deepQN_tensorflow/77a3747d9f5b8c2a0039bdc5c5658f1df2cb8fa5/roms/breakout.bin
--------------------------------------------------------------------------------