├── .DS_Store ├── README.md ├── OmokModel.ckpt.meta ├── OmokModel.ckpt.index ├── OmokModelDeep.ckpt.index ├── OmokModelDeep.ckpt.meta ├── OmokModel.ckpt.data-00000-of-00001 ├── OmokModelDeep.ckpt.data-00000-of-00001 ├── OmokPlayDeep.py ├── OmokPlay.py ├── OmokTrain.py └── OmokTrainDeep.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | For a detailed explanation, see the link below 2 | 3 | -> http://aidev.co.kr/deeplearning/1056 4 | -------------------------------------------------------------------------------- /OmokModel.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/OmokModel.ckpt.meta -------------------------------------------------------------------------------- /OmokModel.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/OmokModel.ckpt.index -------------------------------------------------------------------------------- /OmokModelDeep.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/OmokModelDeep.ckpt.index -------------------------------------------------------------------------------- /OmokModelDeep.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/OmokModelDeep.ckpt.meta -------------------------------------------------------------------------------- /OmokModel.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/OmokModel.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /OmokModelDeep.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseasw/OmokQLearning/HEAD/OmokModelDeep.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /OmokPlayDeep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | from OmokTrainDeep import OmokEnvironment 6 | import tensorflow as tf 7 | import numpy as np 8 | import random 9 | import math 10 | import os 11 | import sys 12 | import time 13 | 14 | 15 | 16 | #------------------------------------------------------------ 17 | # 변수 설정 18 | #------------------------------------------------------------ 19 | STONE_NONE = 0 20 | STONE_PLAYER1 = 1 21 | STONE_PLAYER2 = 2 22 | 23 | gridSize = 8 24 | #------------------------------------------------------------ 25 | 26 | 27 | 28 | #------------------------------------------------------------ 29 | # 화면 출력 함수 30 | #------------------------------------------------------------ 31 | def showBoard(env): 32 | for y in xrange(gridSize): 33 | for x in xrange(gridSize): 34 | if( env.state[y * gridSize + x] == STONE_PLAYER1 ): 35 | sys.stdout.write('O') 36 | elif( env.state[y * gridSize + x] == STONE_PLAYER2 ): 37 | sys.stdout.write('X') 38 | else: 39 |
sys.stdout.write('.') 40 | sys.stdout.write('\n') 41 | sys.stdout.write('\n') 42 | 43 | 44 | #------------------------------------------------------------ 45 | 46 | 47 | 48 | #------------------------------------------------------------ 49 | # 게임 플레이 함수 50 | #------------------------------------------------------------ 51 | def playGame(env, sess): 52 | 53 | env.reset() 54 | 55 | gameOver = False 56 | currentPlayer = STONE_PLAYER1 57 | 58 | while( gameOver != True ): 59 | action = - 9999 60 | 61 | if( currentPlayer == STONE_PLAYER1 ): 62 | currentState = env.getState() 63 | else: 64 | currentState = env.getStateInverse() 65 | 66 | action = env.getAction(sess, currentState) 67 | nextState, reward, gameOver = env.act(currentPlayer, action) 68 | 69 | showBoard(env) 70 | time.sleep(3) 71 | 72 | if( currentPlayer == STONE_PLAYER1 ): 73 | currentPlayer = STONE_PLAYER2 74 | else: 75 | currentPlayer = STONE_PLAYER1 76 | #------------------------------------------------------------ 77 | 78 | 79 | 80 | #------------------------------------------------------------ 81 | # 메인 함수 82 | #------------------------------------------------------------ 83 | def main(_): 84 | 85 | # 텐서플로우 초기화 86 | sess = tf.Session() 87 | sess.run(tf.global_variables_initializer()) 88 | 89 | # 세이브 설정 90 | saver = tf.train.Saver() 91 | 92 | # 모델 로드 93 | if( os.path.isfile(os.getcwd() + "/OmokModelDeep.ckpt.index") == True ): 94 | saver.restore(sess, os.getcwd() + "/OmokModelDeep.ckpt") 95 | print('Saved model is loaded!') 96 | 97 | # 환경 인스턴스 생성 98 | env = OmokEnvironment(gridSize) 99 | 100 | # 게임 플레이 101 | playGame(env, sess) 102 | 103 | # 세션 종료 104 | sess.close() 105 | #------------------------------------------------------------ 106 | 107 | 108 | 109 | #------------------------------------------------------------ 110 | # 메인 함수 실행 111 | #------------------------------------------------------------ 112 | if __name__ == '__main__': 113 | tf.app.run() 114 | #------------------------------------------------------------ 115 | 116 | -------------------------------------------------------------------------------- /OmokPlay.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | from OmokTrain import OmokEnvironment, X, W1, b1, input_layer, W2, b2, hidden_layer, W3, b3, output_layer, Y, cost, optimizer 6 | import tensorflow as tf 7 | import numpy as np 8 | import random 9 | import math 10 | import os 11 | import sys 12 | import time 13 | 14 | 15 | 16 | #------------------------------------------------------------ 17 | # 변수 설정 18 | #------------------------------------------------------------ 19 | STONE_NONE = 0 20 | STONE_PLAYER1 = 1 21 | STONE_PLAYER2 = 2 22 | 23 | gridSize = 10 24 | #------------------------------------------------------------ 25 | 26 | 27 | 28 | #------------------------------------------------------------ 29 | # 화면 출력 함수 30 | #------------------------------------------------------------ 31 | def showBoard(env): 32 | for y in xrange(gridSize): 33 | for x in xrange(gridSize): 34 | if( env.state[y * gridSize + x] == STONE_PLAYER1 ): 35 | sys.stdout.write('O') 36 | elif( env.state[y * gridSize + x] == STONE_PLAYER2 ): 37 | sys.stdout.write('X') 38 | else: 39 | sys.stdout.write('.') 40 | sys.stdout.write('\n') 41 | sys.stdout.write('\n') 42 | 43 | 44 | #------------------------------------------------------------ 45 | 46 | 47 | 48 | #------------------------------------------------------------ 49 | # 게임 플레이 함수 50 | 
#------------------------------------------------------------ 51 | def playGame(env, sess): 52 | 53 | env.reset() 54 | 55 | gameOver = False 56 | currentPlayer = STONE_PLAYER1 57 | 58 | while( gameOver != True ): 59 | action = - 9999 60 | 61 | if( currentPlayer == STONE_PLAYER1 ): 62 | currentState = env.getState() 63 | else: 64 | currentState = env.getStateInverse() 65 | 66 | action = env.getAction(sess, currentState) 67 | nextState, reward, gameOver = env.act(currentPlayer, action) 68 | 69 | showBoard(env) 70 | time.sleep(3) 71 | 72 | if( currentPlayer == STONE_PLAYER1 ): 73 | currentPlayer = STONE_PLAYER2 74 | else: 75 | currentPlayer = STONE_PLAYER1 76 | #------------------------------------------------------------ 77 | 78 | 79 | 80 | #------------------------------------------------------------ 81 | # 메인 함수 82 | #------------------------------------------------------------ 83 | def main(_): 84 | 85 | # 환경 인스턴스 생성 86 | env = OmokEnvironment(gridSize) 87 | 88 | # 텐서플로우 초기화 89 | sess = tf.Session() 90 | sess.run(tf.global_variables_initializer()) 91 | 92 | # 세이브 설정 93 | saver = tf.train.Saver() 94 | 95 | # 모델 로드 96 | if( os.path.isfile(os.getcwd() + "/OmokModel.ckpt.index") == True ): 97 | saver.restore(sess, os.getcwd() + "/OmokModel.ckpt") 98 | print('saved model is loaded!') 99 | 100 | # 게임 플레이 101 | playGame(env, sess) 102 | 103 | # 세션 종료 104 | sess.close() 105 | #------------------------------------------------------------ 106 | 107 | 108 | 109 | #------------------------------------------------------------ 110 | # 메인 함수 실행 111 | #------------------------------------------------------------ 112 | if __name__ == '__main__': 113 | tf.app.run() 114 | #------------------------------------------------------------ 115 | 116 | -------------------------------------------------------------------------------- /OmokTrain.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import random 8 | import math 9 | import os 10 | 11 | 12 | 13 | #------------------------------------------------------------ 14 | # 변수 설정 15 | #------------------------------------------------------------ 16 | STONE_NONE = 0 17 | STONE_PLAYER1 = 1 18 | STONE_PLAYER2 = 2 19 | STONE_MAX = 5 20 | 21 | gridSize = 10 22 | nbActions = gridSize * gridSize 23 | nbStates = gridSize * gridSize 24 | hiddenSize = 100 25 | maxMemory = 500 26 | batchSize = 50 27 | epoch = 100 28 | epsilonStart = 1 29 | epsilonDiscount = 0.999 30 | epsilonMinimumValue = 0.1 31 | discount = 0.9 32 | learningRate = 0.2 33 | winReward = 1 34 | #------------------------------------------------------------ 35 | 36 | 37 | 38 | #------------------------------------------------------------ 39 | # 가설 설정 40 | #------------------------------------------------------------ 41 | X = tf.placeholder(tf.float32, [None, nbStates]) 42 | W1 = tf.Variable(tf.truncated_normal([nbStates, hiddenSize], stddev = 1.0 / math.sqrt(float(nbStates)))) 43 | b1 = tf.Variable(tf.truncated_normal([hiddenSize], stddev = 0.01)) 44 | input_layer = tf.nn.relu(tf.matmul(X, W1) + b1) 45 | 46 | W2 = tf.Variable(tf.truncated_normal([hiddenSize, hiddenSize], stddev = 1.0 / math.sqrt(float(hiddenSize)))) 47 | b2 = tf.Variable(tf.truncated_normal([hiddenSize], stddev = 0.01)) 48 | hidden_layer = tf.nn.relu(tf.matmul(input_layer, W2) + b2) 49 | 50 | W3 = tf.Variable(tf.truncated_normal([hiddenSize, nbActions], stddev = 1.0 / math.sqrt(float(hiddenSize)))) 51 | b3 = 
tf.Variable(tf.truncated_normal([nbActions], stddev = 0.01)) 52 | output_layer = tf.matmul(hidden_layer, W3) + b3 53 | 54 | Y = tf.placeholder(tf.float32, [None, nbActions]) 55 | cost = tf.reduce_sum(tf.square(Y - output_layer)) / (2 * batchSize) 56 | optimizer = tf.train.GradientDescentOptimizer(learningRate).minimize(cost) 57 | #------------------------------------------------------------ 58 | 59 | 60 | 61 | #------------------------------------------------------------ 62 | # 랜덤값 구함 63 | #------------------------------------------------------------ 64 | def randf(s, e): 65 | return (float(random.randrange(0, (e - s) * 9999)) / 10000) + s 66 | #------------------------------------------------------------ 67 | 68 | 69 | 70 | #------------------------------------------------------------ 71 | # 오목 환경 클래스 72 | #------------------------------------------------------------ 73 | class OmokEnvironment(): 74 | 75 | #-------------------------------- 76 | # 초기화 77 | #-------------------------------- 78 | def __init__(self, gridSize): 79 | self.gridSize = gridSize 80 | self.nbStates = self.gridSize * self.gridSize 81 | self.state = np.zeros(self.nbStates, dtype = np.uint8) 82 | 83 | 84 | 85 | #-------------------------------- 86 | # 리셋 87 | #-------------------------------- 88 | def reset(self): 89 | self.state = np.zeros(self.nbStates, dtype = np.uint8) 90 | 91 | 92 | 93 | #-------------------------------- 94 | # 현재 상태 구함 95 | #-------------------------------- 96 | def getState(self): 97 | return np.reshape(self.state, (1, self.nbStates)) 98 | 99 | 100 | 101 | #-------------------------------- 102 | # 플레이어가 바뀐 현재 상태 구함 103 | #-------------------------------- 104 | def getStateInverse(self): 105 | tempState = self.state.copy() 106 | 107 | for i in xrange(self.nbStates): 108 | if( tempState[i] == STONE_PLAYER1 ): 109 | tempState[i] = STONE_PLAYER2 110 | elif( tempState[i] == STONE_PLAYER2 ): 111 | tempState[i] = STONE_PLAYER1 112 | 113 | return np.reshape(tempState, (1, self.nbStates)) 114 | 115 | 116 | 117 | #-------------------------------- 118 | # 리워드 구함 119 | #-------------------------------- 120 | def GetReward(self, player, action): 121 | 122 | # 왼쪽 검사 123 | if( action % self.gridSize > 0 ): 124 | if( self.state[action - 1] == player ): 125 | return 0.05 126 | 127 | # 오른쪽 검사 128 | if( action % self.gridSize < self.gridSize - 1 ): 129 | if( self.state[action + 1] == player ): 130 | return 0.05 131 | 132 | # 위 검사 133 | if( action - self.gridSize >= 0 ): 134 | if( self.state[action - self.gridSize] == player ): 135 | return 0.05 136 | 137 | # 아래 검사 138 | if( action + self.gridSize < self.nbStates ): 139 | if( self.state[action + self.gridSize] == player ): 140 | return 0.05 141 | 142 | # 왼쪽 위 검사 143 | if( (action % self.gridSize > 0) and (action - self.gridSize >= 0) ): 144 | if( self.state[action - 1 - self.gridSize] == player ): 145 | return 0.05 146 | 147 | # 오른쪽 위 검사 148 | if( (action % self.gridSize < self.gridSize - 1) and (action - self.gridSize >= 0) ): 149 | if( self.state[action + 1 - self.gridSize] == player ): 150 | return 0.05 151 | 152 | # 왼쪽 아래 검사 153 | if( (action % self.gridSize > 0) and (action + self.gridSize < self.nbStates) ): 154 | if( self.state[action - 1 + self.gridSize] == player ): 155 | return 0.05 156 | 157 | # 오른쪽 아래 검사 158 | if( (action % self.gridSize < self.gridSize - 1) and (action + self.gridSize < self.nbStates) ): 159 | if( self.state[action + 1 + self.gridSize] == player ): 160 | return 0.05 161 | 162 | return 0 163 | 164 | 165 | 166 | 
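#--------------------------------
# Reward scheme (summary): GetReward above returns a small shaping
# reward of 0.05 when the stone just placed touches a stone of the
# same player in any of the eight neighbouring cells, and 0 otherwise.
# The win reward itself (winReward = 1) is assigned by isGameOver,
# which uses CheckMatch below to look for a run of STONE_MAX (5)
# stones to the right, downward, or along either diagonal.
#--------------------------------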
#-------------------------------- 167 | # 매칭 검사 168 | #-------------------------------- 169 | def CheckMatch(self, player): 170 | for y in xrange(self.gridSize): 171 | for x in xrange(self.gridSize): 172 | 173 | #-------------------------------- 174 | # 오른쪽 검사 175 | #-------------------------------- 176 | match = 0 177 | 178 | for i in xrange(STONE_MAX): 179 | if( x + i >= self.gridSize ): 180 | break 181 | 182 | if( self.state[y * self.gridSize + x + i] == player ): 183 | match += 1 184 | else: 185 | break; 186 | 187 | if( match >= STONE_MAX ): 188 | return True 189 | 190 | #-------------------------------- 191 | # 아래쪽 검사 192 | #-------------------------------- 193 | match = 0 194 | 195 | for i in xrange(STONE_MAX): 196 | if( y + i >= self.gridSize ): 197 | break 198 | 199 | if( self.state[(y + i) * self.gridSize + x] == player ): 200 | match += 1 201 | else: 202 | break; 203 | 204 | if( match >= STONE_MAX ): 205 | return True 206 | 207 | #-------------------------------- 208 | # 오른쪽 대각선 검사 209 | #-------------------------------- 210 | match = 0 211 | 212 | for i in xrange(STONE_MAX): 213 | if( (x + i >= self.gridSize) or (y + i >= self.gridSize) ): 214 | break 215 | 216 | if( self.state[(y + i) * self.gridSize + x + i] == player ): 217 | match += 1 218 | else: 219 | break; 220 | 221 | if( match >= STONE_MAX ): 222 | return True 223 | 224 | #-------------------------------- 225 | # 왼쪽 대각선 검사 226 | #-------------------------------- 227 | match = 0 228 | 229 | for i in xrange(STONE_MAX): 230 | if( (x - i < 0) or (y + i >= self.gridSize) ): 231 | break 232 | 233 | if( self.state[(y + i) * self.gridSize + x - i] == player ): 234 | match += 1 235 | else: 236 | break; 237 | 238 | if( match >= STONE_MAX ): 239 | return True 240 | 241 | return False 242 | 243 | 244 | 245 | #-------------------------------- 246 | # 게임오버 검사 247 | #-------------------------------- 248 | def isGameOver(self, player): 249 | if( self.CheckMatch(STONE_PLAYER1) == True ): 250 | if( player == STONE_PLAYER1 ): 251 | return True, winReward 252 | else: 253 | return True, 0 254 | elif( self.CheckMatch(STONE_PLAYER2) == True ): 255 | if( player == STONE_PLAYER1 ): 256 | return True, 0 257 | else: 258 | return True, winReward 259 | else: 260 | for i in xrange(self.nbStates): 261 | if( self.state[i] == STONE_NONE ): 262 | return False, 0 263 | return True, 0 264 | 265 | 266 | 267 | #-------------------------------- 268 | # 상태 업데이트 269 | #-------------------------------- 270 | def updateState(self, player, action): 271 | self.state[action] = player; 272 | 273 | 274 | 275 | #-------------------------------- 276 | # 행동 수행 277 | #-------------------------------- 278 | def act(self, player, action): 279 | self.updateState(player, action) 280 | gameOver, reward = self.isGameOver(player) 281 | 282 | if( reward == 0 ): 283 | reward = self.GetReward(player, action) 284 | 285 | if( player == STONE_PLAYER1 ): 286 | nextState = self.getState() 287 | else: 288 | nextState = self.getStateInverse() 289 | 290 | return nextState, reward, gameOver 291 | 292 | 293 | 294 | #-------------------------------- 295 | # 행동 구함 296 | #-------------------------------- 297 | def getAction(self, sess, currentState): 298 | q = sess.run(output_layer, feed_dict = {X: currentState}) 299 | 300 | while( True ): 301 | action = q.argmax() 302 | 303 | if( self.state[action] == STONE_NONE ): 304 | return action 305 | else: 306 | q[0, action] = -99999 307 | 308 | 309 | 310 | #-------------------------------- 311 | # 랜덤 행동 구함 312 | #-------------------------------- 313 | 
def getActionRandom(self): 314 | while( True ): 315 | action = random.randrange(0, nbActions) 316 | 317 | if( self.state[action] == STONE_NONE ): 318 | return action 319 | #------------------------------------------------------------ 320 | 321 | 322 | 323 | #------------------------------------------------------------ 324 | # 리플레이 메모리 클래스 325 | #------------------------------------------------------------ 326 | class ReplayMemory: 327 | 328 | #-------------------------------- 329 | # 초기화 330 | #-------------------------------- 331 | def __init__(self, gridSize, maxMemory, discount): 332 | self.maxMemory = maxMemory 333 | self.gridSize = gridSize 334 | self.nbStates = self.gridSize * self.gridSize 335 | self.discount = discount 336 | 337 | self.inputState = np.empty((self.maxMemory, self.nbStates), dtype = np.uint8) 338 | self.actions = np.zeros(self.maxMemory, dtype = np.uint8) 339 | self.nextState = np.empty((self.maxMemory, self.nbStates), dtype = np.uint8) 340 | self.gameOver = np.empty(self.maxMemory, dtype = np.bool) 341 | self.rewards = np.empty(self.maxMemory, dtype = np.int8) 342 | self.count = 0 343 | self.current = 0 344 | 345 | 346 | 347 | #-------------------------------- 348 | # 결과 기억 349 | #-------------------------------- 350 | def remember(self, currentState, action, reward, nextState, gameOver): 351 | self.actions[self.current] = action 352 | self.rewards[self.current] = reward 353 | self.inputState[self.current, ...] = currentState 354 | self.nextState[self.current, ...] = nextState 355 | self.gameOver[self.current] = gameOver 356 | self.count = max(self.count, self.current + 1) 357 | self.current = (self.current + 1) % self.maxMemory 358 | 359 | 360 | 361 | #-------------------------------- 362 | # 배치 구함 363 | #-------------------------------- 364 | def getBatch(self, model, batchSize, nbActions, nbStates, sess, X): 365 | memoryLength = self.count 366 | chosenBatchSize = min(batchSize, memoryLength) 367 | 368 | inputs = np.zeros((chosenBatchSize, nbStates)) 369 | targets = np.zeros((chosenBatchSize, nbActions)) 370 | 371 | for i in xrange(chosenBatchSize): 372 | randomIndex = random.randrange(0, memoryLength) 373 | current_inputState = np.reshape(self.inputState[randomIndex], (1, nbStates)) 374 | 375 | target = sess.run(model, feed_dict = {X: current_inputState}) 376 | 377 | current_nextState = np.reshape(self.nextState[randomIndex], (1, nbStates)) 378 | current_outputs = sess.run(model, feed_dict = {X: current_nextState}) 379 | 380 | nextStateMaxQ = np.amax(current_outputs) 381 | 382 | if( nextStateMaxQ > winReward ): 383 | nextStateMaxQ = winReward 384 | 385 | if( self.gameOver[randomIndex] == True ): 386 | target[0, [self.actions[randomIndex]]] = self.rewards[randomIndex] 387 | else: 388 | target[0, [self.actions[randomIndex]]] = self.rewards[randomIndex] + self.discount * nextStateMaxQ 389 | 390 | inputs[i] = current_inputState 391 | targets[i] = target 392 | 393 | return inputs, targets 394 | #------------------------------------------------------------ 395 | 396 | 397 | 398 | #------------------------------------------------------------ 399 | # 게임 플레이 함수 400 | #------------------------------------------------------------ 401 | def playGame(env, memory, sess, saver, epsilon, iteration): 402 | 403 | #-------------------------------- 404 | # 게임 반복 405 | #-------------------------------- 406 | winCount = 0 407 | 408 | for i in xrange(epoch): 409 | env.reset() 410 | 411 | err = 0 412 | gameOver = False 413 | currentPlayer = STONE_PLAYER1 414 | 415 | while( gameOver != 
True ): 416 | #-------------------------------- 417 | # 행동 수행 418 | #-------------------------------- 419 | action = - 9999 420 | 421 | if( currentPlayer == STONE_PLAYER1 ): 422 | currentState = env.getState() 423 | else: 424 | currentState = env.getStateInverse() 425 | 426 | if( randf(0, 1) <= epsilon ): 427 | action = env.getActionRandom() 428 | else: 429 | action = env.getAction(sess, currentState) 430 | 431 | if( epsilon > epsilonMinimumValue ): 432 | epsilon = epsilon * epsilonDiscount 433 | 434 | nextState, reward, gameOver = env.act(currentPlayer, action) 435 | 436 | if( reward == 1 and currentPlayer == STONE_PLAYER1 ): 437 | winCount = winCount + 1 438 | 439 | #-------------------------------- 440 | # 학습 수행 441 | #-------------------------------- 442 | memory.remember(currentState, action, reward, nextState, gameOver) 443 | 444 | inputs, targets = memory.getBatch(output_layer, batchSize, nbActions, nbStates, sess, X) 445 | 446 | _, loss = sess.run([optimizer, cost], feed_dict = {X: inputs, Y: targets}) 447 | err = err + loss 448 | 449 | if( currentPlayer == STONE_PLAYER1 ): 450 | currentPlayer = STONE_PLAYER2 451 | else: 452 | currentPlayer = STONE_PLAYER1 453 | 454 | print("Epoch " + str(iteration) + str(i) + ": err = " + str(err) + ": Win count = " + str(winCount) + 455 | " Win ratio = " + str(float(winCount) / float(i + 1) * 100)) 456 | 457 | print(targets) 458 | 459 | if( (i % 10 == 0) and (i != 0) ): 460 | save_path = saver.save(sess, os.getcwd() + "/OmokModel.ckpt") 461 | print("Model saved in file: %s" % save_path) 462 | #------------------------------------------------------------ 463 | 464 | 465 | 466 | #------------------------------------------------------------ 467 | # 메인 함수 468 | #------------------------------------------------------------ 469 | def main(_): 470 | 471 | print("Training new model") 472 | 473 | # 환경 인스턴스 생성 474 | env = OmokEnvironment(gridSize) 475 | 476 | # 리플레이 메모리 인스턴스 생성 477 | memory = ReplayMemory(gridSize, maxMemory, discount) 478 | 479 | # 텐서플로우 초기화 480 | sess = tf.Session() 481 | sess.run(tf.global_variables_initializer()) 482 | 483 | # 세이브 설정 484 | saver = tf.train.Saver() 485 | 486 | # 모델 로드 487 | if( os.path.isfile(os.getcwd() + "/OmokModel.ckpt.index") == True ): 488 | saver.restore(sess, os.getcwd() + "/OmokModel.ckpt") 489 | print('Saved model is loaded!') 490 | 491 | # 게임 플레이 492 | iteration = 0 493 | while( True ): 494 | playGame(env, memory, sess, saver, epsilonStart, iteration); 495 | iteration += 1 496 | 497 | # 세션 종료 498 | sess.close() 499 | #------------------------------------------------------------ 500 | 501 | 502 | 503 | #------------------------------------------------------------ 504 | # 메인 함수 실행 505 | #------------------------------------------------------------ 506 | if __name__ == '__main__': 507 | tf.app.run() 508 | #------------------------------------------------------------ 509 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | -------------------------------------------------------------------------------- /OmokTrainDeep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import random 8 | import math 9 | import os 10 | 11 | 12 | 13 | #------------------------------------------------------------ 14 | # 변수 설정 15 | #------------------------------------------------------------ 16 | STONE_NONE = 0 17 | STONE_PLAYER1 = 1 18 | STONE_PLAYER2 = 2 19 | 
STONE_MAX = 5 20 | 21 | gridSize = 8 22 | nbActions = gridSize * gridSize 23 | nbStates = gridSize * gridSize 24 | hiddenSize = 100 25 | maxMemory = 500 26 | batchSize = 50 27 | epoch = 100 28 | epsilonStart = 1 29 | epsilonDiscount = 0.999 30 | epsilonMinimumValue = 0.1 31 | discount = 0.9 32 | learningRate = 0.0001 33 | winReward = 1 34 | #------------------------------------------------------------ 35 | 36 | 37 | 38 | #------------------------------------------------------------ 39 | # 모델 설정 40 | #------------------------------------------------------------ 41 | 42 | # 첫번째 컨볼루션 레이어 43 | filter1_num = 32 44 | filter1_size = 5 45 | X = tf.placeholder(tf.float32, [None, nbStates]) 46 | x_image = tf.reshape(X, [-1, gridSize, gridSize, 1]) 47 | W_conv1 = tf.Variable(tf.truncated_normal([filter1_size, filter1_size, 1, filter1_num], stddev = 0.1)) 48 | h_conv1 = tf.nn.conv2d(x_image, W_conv1, strides = [1, 1, 1, 1], padding = 'SAME') 49 | b_conv1 = tf.Variable(tf.constant(0.1, shape = [filter1_num])) 50 | h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1) 51 | h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME') 52 | 53 | # 두번째 컨볼루션 레이어 54 | filter2_num = 64 55 | filter2_size = 5 56 | W_conv2 = tf.Variable(tf.truncated_normal([filter2_size, filter2_size, filter1_num, filter2_num], stddev = 0.1)) 57 | h_conv2 = tf.nn.conv2d(h_pool1, W_conv2, strides = [1, 1, 1, 1], padding = 'SAME') 58 | b_conv2 = tf.Variable(tf.constant(0.1, shape = [filter2_num])) 59 | h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2) 60 | h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME') 61 | 62 | # 완전 연결 레이어 63 | h_pool2_size = gridSize / 2 / 2 # 풀링 두번으로 사이즈 축소 64 | h_pool2_flat = tf.reshape(h_pool2, [-1, h_pool2_size * h_pool2_size * filter2_num]) 65 | units1_num = h_pool2_size * h_pool2_size * filter2_num 66 | units2_num = 1024 67 | w2 = tf.Variable(tf.truncated_normal([units1_num, units2_num])) 68 | b2 = tf.Variable(tf.constant(0.1, shape = [units2_num])) 69 | hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2) 70 | keep_prob = tf.placeholder(tf.float32) 71 | hidden2_drop = tf.nn.dropout(hidden2, keep_prob) 72 | w0 = tf.Variable(tf.zeros([units2_num, nbActions])) 73 | b0 = tf.Variable(tf.zeros([nbActions])) 74 | output_layer = tf.matmul(hidden2_drop, w0) + b0 75 | 76 | # 비용 함수 정의 77 | Y = tf.placeholder(tf.float32, [None, nbActions]) 78 | cost = tf.reduce_sum(tf.square(Y - output_layer)) / (2 * batchSize) 79 | optimizer = tf.train.AdamOptimizer(learningRate).minimize(cost) 80 | #------------------------------------------------------------ 81 | 82 | 83 | 84 | #------------------------------------------------------------ 85 | # 랜덤값 구함 86 | #------------------------------------------------------------ 87 | def randf(s, e): 88 | return (float(random.randrange(0, (e - s) * 9999)) / 10000) + s 89 | #------------------------------------------------------------ 90 | 91 | 92 | 93 | #------------------------------------------------------------ 94 | # 오목 환경 클래스 95 | #------------------------------------------------------------ 96 | class OmokEnvironment(): 97 | 98 | #-------------------------------- 99 | # 초기화 100 | #-------------------------------- 101 | def __init__(self, gridSize): 102 | self.gridSize = gridSize 103 | self.nbStates = self.gridSize * self.gridSize 104 | self.state = np.zeros(self.nbStates, dtype = np.uint8) 105 | 106 | 107 | 108 | #-------------------------------- 109 | # 리셋 110 | 
#-------------------------------- 111 | def reset(self): 112 | self.state = np.zeros(self.nbStates, dtype = np.uint8) 113 | 114 | 115 | 116 | #-------------------------------- 117 | # 현재 상태 구함 118 | #-------------------------------- 119 | def getState(self): 120 | return np.reshape(self.state, (1, self.nbStates)) 121 | 122 | 123 | 124 | #-------------------------------- 125 | # 플레이어가 바뀐 현재 상태 구함 126 | #-------------------------------- 127 | def getStateInverse(self): 128 | tempState = self.state.copy() 129 | 130 | for i in xrange(self.nbStates): 131 | if( tempState[i] == STONE_PLAYER1 ): 132 | tempState[i] = STONE_PLAYER2 133 | elif( tempState[i] == STONE_PLAYER2 ): 134 | tempState[i] = STONE_PLAYER1 135 | 136 | return np.reshape(tempState, (1, self.nbStates)) 137 | 138 | 139 | 140 | #-------------------------------- 141 | # 리워드 구함 142 | #-------------------------------- 143 | def GetReward(self, player, action): 144 | 145 | # 왼쪽 검사 146 | if( action % self.gridSize > 0 ): 147 | if( self.state[action - 1] == player ): 148 | return 0.05 149 | 150 | # 오른쪽 검사 151 | if( action % self.gridSize < self.gridSize - 1 ): 152 | if( self.state[action + 1] == player ): 153 | return 0.05 154 | 155 | # 위 검사 156 | if( action - self.gridSize >= 0 ): 157 | if( self.state[action - self.gridSize] == player ): 158 | return 0.05 159 | 160 | # 아래 검사 161 | if( action + self.gridSize < self.nbStates ): 162 | if( self.state[action + self.gridSize] == player ): 163 | return 0.05 164 | 165 | # 왼쪽 위 검사 166 | if( (action % self.gridSize > 0) and (action - self.gridSize >= 0) ): 167 | if( self.state[action - 1 - self.gridSize] == player ): 168 | return 0.05 169 | 170 | # 오른쪽 위 검사 171 | if( (action % self.gridSize < self.gridSize - 1) and (action - self.gridSize >= 0) ): 172 | if( self.state[action + 1 - self.gridSize] == player ): 173 | return 0.05 174 | 175 | # 왼쪽 아래 검사 176 | if( (action % self.gridSize > 0) and (action + self.gridSize < self.nbStates) ): 177 | if( self.state[action - 1 + self.gridSize] == player ): 178 | return 0.05 179 | 180 | # 오른쪽 아래 검사 181 | if( (action % self.gridSize < self.gridSize - 1) and (action + self.gridSize < self.nbStates) ): 182 | if( self.state[action + 1 + self.gridSize] == player ): 183 | return 0.05 184 | 185 | return 0 186 | 187 | 188 | 189 | #-------------------------------- 190 | # 매칭 검사 191 | #-------------------------------- 192 | def CheckMatch(self, player): 193 | for y in xrange(self.gridSize): 194 | for x in xrange(self.gridSize): 195 | 196 | #-------------------------------- 197 | # 오른쪽 검사 198 | #-------------------------------- 199 | match = 0 200 | 201 | for i in xrange(STONE_MAX): 202 | if( x + i >= self.gridSize ): 203 | break 204 | 205 | if( self.state[y * self.gridSize + x + i] == player ): 206 | match += 1 207 | else: 208 | break; 209 | 210 | if( match >= STONE_MAX ): 211 | return True 212 | 213 | #-------------------------------- 214 | # 아래쪽 검사 215 | #-------------------------------- 216 | match = 0 217 | 218 | for i in xrange(STONE_MAX): 219 | if( y + i >= self.gridSize ): 220 | break 221 | 222 | if( self.state[(y + i) * self.gridSize + x] == player ): 223 | match += 1 224 | else: 225 | break; 226 | 227 | if( match >= STONE_MAX ): 228 | return True 229 | 230 | #-------------------------------- 231 | # 오른쪽 대각선 검사 232 | #-------------------------------- 233 | match = 0 234 | 235 | for i in xrange(STONE_MAX): 236 | if( (x + i >= self.gridSize) or (y + i >= self.gridSize) ): 237 | break 238 | 239 | if( self.state[(y + i) * self.gridSize + x + i] == player ): 240 | 
match += 1 241 | else: 242 | break; 243 | 244 | if( match >= STONE_MAX ): 245 | return True 246 | 247 | #-------------------------------- 248 | # 왼쪽 대각선 검사 249 | #-------------------------------- 250 | match = 0 251 | 252 | for i in xrange(STONE_MAX): 253 | if( (x - i < 0) or (y + i >= self.gridSize) ): 254 | break 255 | 256 | if( self.state[(y + i) * self.gridSize + x - i] == player ): 257 | match += 1 258 | else: 259 | break; 260 | 261 | if( match >= STONE_MAX ): 262 | return True 263 | 264 | return False 265 | 266 | 267 | 268 | #-------------------------------- 269 | # 게임오버 검사 270 | #-------------------------------- 271 | def isGameOver(self, player): 272 | if( self.CheckMatch(STONE_PLAYER1) == True ): 273 | if( player == STONE_PLAYER1 ): 274 | return True, winReward 275 | else: 276 | return True, 0 277 | elif( self.CheckMatch(STONE_PLAYER2) == True ): 278 | if( player == STONE_PLAYER1 ): 279 | return True, 0 280 | else: 281 | return True, winReward 282 | else: 283 | for i in xrange(self.nbStates): 284 | if( self.state[i] == STONE_NONE ): 285 | return False, 0 286 | return True, 0 287 | 288 | 289 | 290 | #-------------------------------- 291 | # 상태 업데이트 292 | #-------------------------------- 293 | def updateState(self, player, action): 294 | self.state[action] = player; 295 | 296 | 297 | 298 | #-------------------------------- 299 | # 행동 수행 300 | #-------------------------------- 301 | def act(self, player, action): 302 | self.updateState(player, action) 303 | gameOver, reward = self.isGameOver(player) 304 | 305 | if( reward == 0 ): 306 | reward = self.GetReward(player, action) 307 | 308 | if( player == STONE_PLAYER1 ): 309 | nextState = self.getState() 310 | else: 311 | nextState = self.getStateInverse() 312 | 313 | return nextState, reward, gameOver 314 | 315 | 316 | #-------------------------------- 317 | # 행동 구함 318 | #-------------------------------- 319 | def getAction(self, sess, currentState): 320 | q = sess.run(output_layer, feed_dict = {X: currentState, keep_prob:1.0}) 321 | 322 | while( True ): 323 | action = q.argmax() 324 | 325 | if( self.state[action] == STONE_NONE ): 326 | return action 327 | else: 328 | q[0, action] = -99999 329 | 330 | 331 | 332 | #-------------------------------- 333 | # 랜덤 행동 구함 334 | #-------------------------------- 335 | def getActionRandom(self): 336 | while( True ): 337 | action = random.randrange(0, nbActions) 338 | 339 | if( self.state[action] == STONE_NONE ): 340 | return action 341 | #------------------------------------------------------------ 342 | 343 | 344 | 345 | #------------------------------------------------------------ 346 | # 리플레이 메모리 클래스 347 | #------------------------------------------------------------ 348 | class ReplayMemory: 349 | 350 | #-------------------------------- 351 | # 초기화 352 | #-------------------------------- 353 | def __init__(self, gridSize, maxMemory, discount): 354 | self.maxMemory = maxMemory 355 | self.gridSize = gridSize 356 | self.nbStates = self.gridSize * self.gridSize 357 | self.discount = discount 358 | 359 | self.inputState = np.empty((self.maxMemory, self.nbStates), dtype = np.uint8) 360 | self.actions = np.zeros(self.maxMemory, dtype = np.uint8) 361 | self.nextState = np.empty((self.maxMemory, self.nbStates), dtype = np.uint8) 362 | self.gameOver = np.empty(self.maxMemory, dtype = np.bool) 363 | self.rewards = np.empty(self.maxMemory, dtype = np.int8) 364 | self.count = 0 365 | self.current = 0 366 | 367 | 368 | 369 | #-------------------------------- 370 | # 결과 기억 371 | 
#-------------------------------- 372 | def remember(self, currentState, action, reward, nextState, gameOver): 373 | self.actions[self.current] = action 374 | self.rewards[self.current] = reward 375 | self.inputState[self.current, ...] = currentState 376 | self.nextState[self.current, ...] = nextState 377 | self.gameOver[self.current] = gameOver 378 | self.count = max(self.count, self.current + 1) 379 | self.current = (self.current + 1) % self.maxMemory 380 | 381 | 382 | 383 | #-------------------------------- 384 | # 배치 구함 385 | #-------------------------------- 386 | def getBatch(self, model, batchSize, nbActions, nbStates, sess): 387 | memoryLength = self.count 388 | chosenBatchSize = min(batchSize, memoryLength) 389 | 390 | inputs = np.zeros((chosenBatchSize, nbStates)) 391 | targets = np.zeros((chosenBatchSize, nbActions)) 392 | 393 | for i in xrange(chosenBatchSize): 394 | randomIndex = random.randrange(0, memoryLength) 395 | current_inputState = np.reshape(self.inputState[randomIndex], (1, nbStates)) 396 | 397 | target = sess.run(model, feed_dict = {X: current_inputState, keep_prob:1.0}) 398 | 399 | current_nextState = np.reshape(self.nextState[randomIndex], (1, nbStates)) 400 | current_outputs = sess.run(model, feed_dict = {X: current_nextState, keep_prob:1.0}) 401 | 402 | nextStateMaxQ = np.amax(current_outputs) 403 | 404 | if( nextStateMaxQ > winReward ): 405 | nextStateMaxQ = winReward 406 | 407 | if( self.gameOver[randomIndex] == True ): 408 | target[0, [self.actions[randomIndex]]] = self.rewards[randomIndex] 409 | else: 410 | target[0, [self.actions[randomIndex]]] = self.rewards[randomIndex] + self.discount * nextStateMaxQ 411 | 412 | inputs[i] = current_inputState 413 | targets[i] = target 414 | 415 | return inputs, targets 416 | #------------------------------------------------------------ 417 | 418 | 419 | 420 | #------------------------------------------------------------ 421 | # 게임 플레이 함수 422 | #------------------------------------------------------------ 423 | def playGame(env, memory, sess, saver, epsilon, iteration): 424 | 425 | #-------------------------------- 426 | # 게임 반복 427 | #-------------------------------- 428 | winCount = 0 429 | 430 | for i in xrange(epoch): 431 | env.reset() 432 | 433 | err = 0 434 | gameOver = False 435 | currentPlayer = STONE_PLAYER1 436 | 437 | while( gameOver != True ): 438 | #-------------------------------- 439 | # 행동 수행 440 | #-------------------------------- 441 | action = - 9999 442 | 443 | if( currentPlayer == STONE_PLAYER1 ): 444 | currentState = env.getState() 445 | else: 446 | currentState = env.getStateInverse() 447 | 448 | if( randf(0, 1) <= epsilon ): 449 | action = env.getActionRandom() 450 | else: 451 | action = env.getAction(sess, currentState) 452 | 453 | if( epsilon > epsilonMinimumValue ): 454 | epsilon = epsilon * epsilonDiscount 455 | 456 | nextState, reward, gameOver = env.act(currentPlayer, action) 457 | 458 | if( reward == 1 and currentPlayer == STONE_PLAYER1 ): 459 | winCount = winCount + 1 460 | 461 | #-------------------------------- 462 | # 학습 수행 463 | #-------------------------------- 464 | memory.remember(currentState, action, reward, nextState, gameOver) 465 | 466 | inputs, targets = memory.getBatch(output_layer, batchSize, nbActions, nbStates, sess) 467 | 468 | _, loss = sess.run([optimizer, cost], feed_dict = {X: inputs, Y: targets, keep_prob:1.0}) 469 | err = err + loss 470 | 471 | if( currentPlayer == STONE_PLAYER1 ): 472 | currentPlayer = STONE_PLAYER2 473 | else: 474 | currentPlayer = STONE_PLAYER1 
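# Self-play note: both players are driven by the same network. On
# player 2's turn the board is colour-swapped with getStateInverse(),
# so the network always evaluates the position from the side to move;
# winCount (reported in the per-game summary printed below) only
# counts wins by STONE_PLAYER1.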
475 | 476 | print("Epoch " + str(iteration) + str(i) + ": err = " + str(err) + ": Win count = " + str(winCount) + 477 | " Win ratio = " + str(float(winCount) / float(i + 1) * 100)) 478 | 479 | print(targets) 480 | 481 | if( (i % 10 == 0) and (i != 0) ): 482 | save_path = saver.save(sess, os.getcwd() + "/OmokModelDeep.ckpt") 483 | print("Model saved in file: %s" % save_path) 484 | #------------------------------------------------------------ 485 | 486 | 487 | 488 | #------------------------------------------------------------ 489 | # 메인 함수 490 | #------------------------------------------------------------ 491 | def main(_): 492 | 493 | # 환경 인스턴스 생성 494 | env = OmokEnvironment(gridSize) 495 | 496 | # 리플레이 메모리 인스턴스 생성 497 | memory = ReplayMemory(gridSize, maxMemory, discount) 498 | 499 | # 텐서플로우 초기화 500 | sess = tf.Session() 501 | sess.run(tf.global_variables_initializer()) 502 | 503 | # 세이브 설정 504 | saver = tf.train.Saver() 505 | 506 | # 모델 로드 507 | if( os.path.isfile(os.getcwd() + "/OmokModelDeep.ckpt.index") == True ): 508 | saver.restore(sess, os.getcwd() + "/OmokModelDeep.ckpt") 509 | print('Saved model is loaded!') 510 | 511 | # 게임 플레이 512 | iteration = 0 513 | while( True ): 514 | playGame(env, memory, sess, saver, epsilonStart, iteration); 515 | iteration += 1 516 | 517 | # 세션 종료 518 | sess.close() 519 | #------------------------------------------------------------ 520 | 521 | 522 | 523 | #------------------------------------------------------------ 524 | # 메인 함수 실행 525 | #------------------------------------------------------------ 526 | if __name__ == '__main__': 527 | tf.app.run() 528 | #------------------------------------------------------------ 529 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | --------------------------------------------------------------------------------
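Note on the scripts above: they target Python 2 and the TensorFlow 1.x graph/session API (xrange, tf.placeholder, tf.Session, tf.train.Saver, tf.app.run). The snippet below is a minimal, self-contained restatement of the Q-learning target that ReplayMemory.getBatch builds for each replayed transition, written for Python 3 with plain NumPy so the update rule can be checked independently of TensorFlow; the helper name compute_q_target and its argument names are illustrative and do not appear in the repository.

import numpy as np

def compute_q_target(q_state, q_next_state, action, reward, game_over,
                     discount=0.9, win_reward=1.0):
    # q_state:      network output for the stored state,    shape (1, nbActions)
    # q_next_state: network output for the following state, shape (1, nbActions)
    target = q_state.copy()
    # getBatch clips the bootstrapped value at winReward before discounting
    next_max = min(float(np.max(q_next_state)), win_reward)
    if game_over:
        target[0, action] = reward                      # terminal: no bootstrapping
    else:
        target[0, action] = reward + discount * next_max
    return target

# Example with made-up numbers: a non-terminal move that earned the 0.05
# adjacency reward, on the 8x8 board used by OmokTrainDeep (64 actions).
q_s  = np.zeros((1, 64))
q_s2 = np.full((1, 64), 0.3)
print(compute_q_target(q_s, q_s2, action=10, reward=0.05, game_over=False)[0, 10])
# 0.05 + 0.9 * 0.3 = 0.32

Two sizing remarks for OmokTrainDeep.py: after the two 2x2 max-poolings with stride 2, the 8x8 board shrinks to 4x4 and then 2x2, so h_pool2_flat carries 2 * 2 * filter2_num = 256 features into the fully connected layer. The expression h_pool2_size = gridSize / 2 / 2 stays an integer only under Python 2; under Python 3 it would need floor division (gridSize // 2 // 2) for the tf.reshape shape to remain integral. Likewise, running any of these scripts on TensorFlow 2.x would require the tf.compat.v1 API (with v2 behaviour/eager execution disabled) or a rewrite, since tf.placeholder, tf.Session and tf.app.run are no longer available at the top level.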