├── .gitattributes ├── model.h5 ├── objects │   ├── D.pkl │   ├── time.pkl │   └── epsilon.pkl ├── unit1_test_gym.py ├── clear_objects.py ├── .gitignore ├── unit2_test_selenium.py ├── unit4_test_DQN.py ├── LICENSE ├── environment.yaml ├── README.md ├── Visualize training Progress.ipynb ├── model.json ├── unit3_test_environment.py ├── unit5_dino.py └── Reinforcement Learning Dino Run.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pkl filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeivan1007/DinoRunTutorial/HEAD/model.h5 -------------------------------------------------------------------------------- /objects/D.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeivan1007/DinoRunTutorial/HEAD/objects/D.pkl -------------------------------------------------------------------------------- /objects/time.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeivan1007/DinoRunTutorial/HEAD/objects/time.pkl -------------------------------------------------------------------------------- /objects/epsilon.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeivan1007/DinoRunTutorial/HEAD/objects/epsilon.pkl -------------------------------------------------------------------------------- /unit1_test_gym.py: -------------------------------------------------------------------------------- 1 | import random 2 | import gym 3 | env = gym.make('MountainCar-v0') 4 | env.reset() 5 | # 6 | print('Starting the game') 7 | print('Press ctrl-c in the terminal to stop the game') 8 | random_number = lambda:random.randint(0,2) 9 | 10 | while True: 11 | env.step(random_number()) 12 | env.render() 13 | -------------------------------------------------------------------------------- /clear_objects.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import deque 3 | 4 | def save_obj(obj, name ): 5 | with open('objects/'+ name + '.pkl', 'wb') as f: #dump files into objects folder 6 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 7 | 8 | # reset the cached epsilon, timestep counter, and replay memory 9 | save_obj(0.1,"epsilon") 10 | t = 0 11 | save_obj(t,"time") 12 | D = deque() 13 | save_obj(D,"D") 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | MANIFEST 2 | build 3 | dist 4 | _build 5 | docs/man/*.gz 6 | docs/source/api/generated 7 | docs/source/config.rst 8 | docs/gh-pages 9 | notebook/i18n/*/LC_MESSAGES/*.mo 10 | notebook/i18n/*/LC_MESSAGES/nbjs.json 11 | notebook/static/components 12 | notebook/static/style/*.min.css* 13 | notebook/static/*/js/built/ 14 | notebook/static/*/built/ 15 | notebook/static/built/ 16 | notebook/static/*/js/main.min.js* 17 | notebook/static/lab/*bundle.js 18 | node_modules 19 | *.py[co] 20 | __pycache__ 21 | *.egg-info 22 | *~ 23 | *.bak 24 | .ipynb_checkpoints 25 | .tox 26 | .DS_Store 27 | \#*# 28 | .#* 29 | .coverage 30 | src 31 | 32 | *.swp 33 | *.map 34 | .idea/ 35 | Read the Docs 36 | config.rst 37 | 38 | /.project 39 | /.pydevproject 
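clear_objects.py above resets the cached training state (epsilon, the timestep counter, and the replay-memory deque). As a sanity check, here is a minimal sketch of the matching read path, assuming the same objects/ layout; the load_obj helper mirrors the one used in unit3_test_environment.py and unit5_dino.py.
```
import pickle

def load_obj(name):
    # counterpart of save_obj in clear_objects.py: read a pickled object back from the objects folder
    with open('objects/' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

print('epsilon:', load_obj('epsilon'))      # exploration rate, 0.1 right after a reset
print('time:', load_obj('time'))            # training timestep, 0 right after a reset
print('replay size:', len(load_obj('D')))   # replay memory, empty right after a reset
```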
-------------------------------------------------------------------------------- /unit2_test_selenium.py: -------------------------------------------------------------------------------- 1 | from selenium import webdriver 2 | from selenium.webdriver.chrome.options import Options 3 | from selenium.webdriver.common.keys import Keys 4 | 5 | # TODO: handle chromedriver paths for different operating systems 6 | chrome_driver_path = '../chromedriver' # set the chromedriver path 7 | init_script = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'" # give the canvas element an id 8 | chrome_options = Options() 9 | chrome_options.add_argument("--mute-audio") # mute the sound effects 10 | driver = webdriver.Chrome(executable_path = chrome_driver_path,chrome_options=chrome_options) # load the options 11 | driver.set_window_position(x=-10,y=0) # set the window position 12 | 13 | driver.get('chrome://dino') # open the page 14 | driver.execute_script("Runner.config.ACCELERATION=0") # disable acceleration 15 | driver.execute_script(init_script) # run the javascript snippet 16 | 17 | print('Successfully launched dino!') 18 | -------------------------------------------------------------------------------- /unit4_test_DQN.py: -------------------------------------------------------------------------------- 1 | from keras.models import Sequential 2 | from keras.layers.core import Dense, Dropout, Activation, Flatten 3 | from keras.layers.convolutional import Conv2D, MaxPooling2D 4 | from keras.optimizers import SGD , Adam 5 | from keras.callbacks import TensorBoard 6 | 7 | 8 | 9 | model = Sequential() 10 | model.add(Conv2D(32, (8, 8), padding='same',strides=(4, 4),input_shape=(80,80,4))) #80*80*4 11 | model.add(MaxPooling2D(pool_size=(2,2))) 12 | model.add(Activation('relu')) 13 | model.add(Conv2D(64, (4, 4),strides=(2, 2), padding='same')) 14 | model.add(MaxPooling2D(pool_size=(2,2))) 15 | model.add(Activation('relu')) 16 | model.add(Conv2D(64, (3, 3),strides=(1, 1), padding='same')) 17 | model.add(MaxPooling2D(pool_size=(2,2))) 18 | model.add(Activation('relu')) 19 | model.add(Flatten()) 20 | model.add(Dense(512)) 21 | model.add(Activation('relu')) 22 | model.add(Dense(2)) 23 | adam = Adam(lr=1e-4) 24 | model.compile(loss='mse',optimizer=adam) 25 | 26 | print(model.summary()) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ravi Munde 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: dino_rl 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ca-certificates=2019.8.28=0 6 | - certifi=2019.9.11=py36_0 7 | - libcxx=4.0.1=hcfea43d_1 8 | - libcxxabi=4.0.1=hcfea43d_1 9 | - libedit=3.1.20181209=hb402a30_0 10 | - libffi=3.2.1=h475c297_4 11 | - ncurses=6.1=h0a44026_1 12 | - openssl=1.0.2t=h1de35cc_1 13 | - pip=19.2.3=py36_0 14 | - python=3.6.5=hc167b69_1 15 | - readline=7.0=h1de35cc_5 16 | - setuptools=41.4.0=py36_0 17 | - sqlite=3.30.0=ha441bb4_0 18 | - tk=8.6.8=ha441bb4_0 19 | - wheel=0.33.6=py36_0 20 | - xz=5.2.4=h1de35cc_4 21 | - zlib=1.2.11=h1de35cc_3 22 | - pip: 23 | - absl-py==0.8.1 24 | - astor==0.8.0 25 | - click==7.0 26 | - cloudpickle==1.2.2 27 | - future==0.18.1 28 | - gast==0.3.2 29 | - grpcio==1.24.1 30 | - gym==0.15.3 31 | - h5py==2.10.0 32 | - joblib==0.14.0 33 | - keras==2.3.1 34 | - keras-applications==1.0.8 35 | - keras-preprocessing==1.1.0 36 | - markdown==3.1.1 37 | - mock==3.0.5 38 | - numpy==1.17.3 39 | - opencv-python==4.1.1.26 40 | - pandas==0.25.2 41 | - pillow==6.2.0 42 | - protobuf==3.10.0 43 | - pyglet==1.3.2 44 | - python-dateutil==2.8.0 45 | - pytz==2019.3 46 | - pyyaml==5.1.2 47 | - scipy==1.3.1 48 | - selenium==3.141.0 49 | - six==1.12.0 50 | - tensorboard==1.13.1 51 | - tensorflow==1.13.1 52 | - tensorflow-estimator==1.13.0 53 | - termcolor==1.1.0 54 | - tqdm==4.36.1 55 | - urllib3==1.25.6 56 | - werkzeug==0.16.0 57 | prefix: /anaconda3/envs/dino_rl 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dino Run Tutorial 2 | 3 | 4 | If you want to train the model, first follow these steps: [Download and build the env](https://github.com/leeivan1007/DinoRunTutorial#%E4%B8%8B%E8%BC%89%E8%88%87%E7%92%B0%E5%A2%83%E5%BB%BA%E7%BD%AE), [Install baselines](https://github.com/leeivan1007/DinoRunTutorial#baselines%E5%AE%89%E8%A3%9D), [Download chromedriver](https://github.com/leeivan1007/DinoRunTutorial#chromedriver%E5%AE%89%E8%A3%9D), and [Training&Predicting](https://github.com/leeivan1007/DinoRunTutorial#%E5%9F%B7%E8%A1%8C%E8%A8%93%E7%B7%B4%E9%A0%90%E6%B8%AC). 5 | 6 | You can press ctrl+c at any time to stop training; the next run resumes from the saved checkpoint. 7 | 8 | If you get a pickle error while loading the saved objects, follow the steps under **object_error**. 9 | 10 | [![Video Sample](https://media.giphy.com/media/Ahh7X6z7jZSSl4veLf/giphy.gif)](http://www.youtube.com/watch?v=w1Rqf2oxcPU) 11 | 12 | # Download and build the env 13 | ``` 14 | git clone https://github.com/leeivan1007/DinoRunTutorial.git 15 | cd DinoRunTutorial 16 | conda env create --file environment.yaml 17 | source activate dino_rl 18 | ``` 19 | ## Install baselines 20 | OpenAI's reinforcement learning baselines 21 | ``` 22 | git clone https://github.com/openai/baselines.git 23 | cd baselines 24 | pip install -e . 25 | ``` 26 | ## Download chromedriver 27 | 28 | Open [https://chromedriver.chromium.org/](https://chromedriver.chromium.org/), select the latest stable release, and download the build for your OS. 29 | 30 | After downloading, extract chromedriver to the expected path. 31 | 32 | If you follow this layout, chromedriver is placed at the same level as the repository folder; the expected layout is shown after the optional check below.
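A quick way to confirm the driver is wired up correctly (this is essentially what unit2_test_selenium.py does) is the minimal sketch below; it assumes the default '../chromedriver' path and the selenium version pinned in environment.yaml.
```
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

options = Options()
options.add_argument("--mute-audio")            # keep the game silent
driver = webdriver.Chrome(executable_path='../chromedriver', chrome_options=options)
driver.get('chrome://dino')                     # the built-in dino page
print('chromedriver OK')
driver.quit()
```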
33 | ``` 34 | root_home/ 35 | chromedriver 36 | DinoRunTutorial 37 | ``` 38 | Alternatively, edit the path variable (chrome_driver_path) in the scripts to point at wherever your chromedriver lives. 39 | 40 | ## Training&Predicting 41 | While the dino is running, the model is trained continuously and checkpointed at a fixed interval of steps. 42 | ``` 43 | python unit5_dino.py 44 | ``` 45 | 46 | ## object_error 47 | 48 | If loading the pickled objects raises an error, or you simply want to reset the saved training state, run: 49 | ``` 50 | python clear_objects.py 51 | ``` 52 | 53 | ## Original reference 54 | 55 | Accompanying code for the Paperspace tutorial ["Build an AI to play Dino Run"](https://blog.paperspace.com/dino-run/) 56 | -------------------------------------------------------------------------------- /Visualize training Progress.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Visualize the progress of training\n", 8 | "All paths are relative." 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": { 15 | "collapsed": true 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import time\n", 20 | "%matplotlib inline \n", 21 | "from matplotlib import pyplot as plt\n", 22 | "plt.rcParams['figure.figsize'] = (15, 9)\n", 23 | "import seaborn as sns\n", 24 | "import pandas as pd\n", 25 | "import numpy as np" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import pandas as pd\n", 35 | "start = 0\n", 36 | "interval = 10\n", 37 | "scores_df = pd.read_csv(\"./objects/scores_df.csv\")\n", 38 | "mean_scores = pd.DataFrame(columns =['score'])\n", 39 | "actions_df = pd.read_csv(\"./objects/actions_df.csv\")\n", 40 | "max_scores = pd.DataFrame(columns =['max_score'])\n", 41 | "q_max_scores = pd.DataFrame(columns =['q_max'])\n", 42 | "while interval <= len(scores_df):\n", 43 | " mean_scores.loc[len(mean_scores)] = (scores_df.loc[start:interval].mean()['scores'])\n", 44 | " max_scores.loc[len(max_scores)] = (scores_df.loc[start:interval].max()['scores'])\n", 45 | " start = interval\n", 46 | " interval = interval + 10\n", 47 | "\n", 48 | "q_max_df = pd.read_csv(\"./objects/q_values.csv\")\n", 49 | "\n", 50 | "start = 0\n", 51 | "interval = 1000\n", 52 | "while interval <=len(q_max_df):\n", 53 | " q_max_scores.loc[len(q_max_scores)] = (q_max_df.loc[start:interval].mean()['actions'])\n", 54 | " start = interval\n", 55 | " interval = interval + 1000\n", 56 | " \n", 57 | "mean_scores.plot()\n", 58 | "max_scores.plot()\n", 59 | "q_max_scores.plot()" 60 | ] 61 | } 62 | ], 63 | "metadata": { 64 | "kernelspec": { 65 | "display_name": "Python 3", 66 | "language": "python", 67 | "name": "python3" 68 | }, 69 | "language_info": { 70 | "codemirror_mode": { 71 | "name": "ipython", 72 | "version": 3 73 | }, 74 | "file_extension": ".py", 75 | "mimetype": "text/x-python", 76 | "name": "python", 77 | "nbconvert_exporter": "python", 78 | "pygments_lexer": "ipython3", 79 | "version": "3.6.3" 80 | } 81 | }, 82 | "nbformat": 4, 83 | "nbformat_minor": 2 84 | } 85 | -------------------------------------------------------------------------------- /model.json: -------------------------------------------------------------------------------- 1 | "{\"class_name\": \"Sequential\", \"config\": [{\"class_name\": \"Conv2D\", \"config\": {\"name\": \"conv2d_10\", \"trainable\": true, \"batch_input_shape\": [null, 80, 80, 4], \"dtype\": \"float32\", 
\"filters\": 32, \"kernel_size\": [8, 8], \"strides\": [4, 4], \"padding\": \"same\", \"data_format\": \"channels_last\", \"dilation_rate\": [1, 1], \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"scale\": 1.0, \"mode\": \"fan_avg\", \"distribution\": \"uniform\", \"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_10\", \"trainable\": true, \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2], \"data_format\": \"channels_last\"}}, {\"class_name\": \"Activation\", \"config\": {\"name\": \"activation_13\", \"trainable\": true, \"activation\": \"relu\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"name\": \"conv2d_11\", \"trainable\": true, \"filters\": 64, \"kernel_size\": [4, 4], \"strides\": [2, 2], \"padding\": \"same\", \"data_format\": \"channels_last\", \"dilation_rate\": [1, 1], \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"scale\": 1.0, \"mode\": \"fan_avg\", \"distribution\": \"uniform\", \"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_11\", \"trainable\": true, \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2], \"data_format\": \"channels_last\"}}, {\"class_name\": \"Activation\", \"config\": {\"name\": \"activation_14\", \"trainable\": true, \"activation\": \"relu\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"name\": \"conv2d_12\", \"trainable\": true, \"filters\": 64, \"kernel_size\": [3, 3], \"strides\": [1, 1], \"padding\": \"same\", \"data_format\": \"channels_last\", \"dilation_rate\": [1, 1], \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"scale\": 1.0, \"mode\": \"fan_avg\", \"distribution\": \"uniform\", \"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_12\", \"trainable\": true, \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2], \"data_format\": \"channels_last\"}}, {\"class_name\": \"Activation\", \"config\": {\"name\": \"activation_15\", \"trainable\": true, \"activation\": \"relu\"}}, {\"class_name\": \"Flatten\", \"config\": {\"name\": \"flatten_4\", \"trainable\": true, \"data_format\": \"channels_last\"}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_7\", \"trainable\": true, \"units\": 512, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"scale\": 1.0, \"mode\": \"fan_avg\", \"distribution\": \"uniform\", \"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, 
{\"class_name\": \"Activation\", \"config\": {\"name\": \"activation_16\", \"trainable\": true, \"activation\": \"relu\"}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_8\", \"trainable\": true, \"units\": 2, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"scale\": 1.0, \"mode\": \"fan_avg\", \"distribution\": \"uniform\", \"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}], \"keras_version\": \"2.1.6\", \"backend\": \"tensorflow\"}" -------------------------------------------------------------------------------- /unit3_test_environment.py: -------------------------------------------------------------------------------- 1 | from selenium import webdriver 2 | from selenium.webdriver.chrome.options import Options 3 | from selenium.webdriver.common.keys import Keys 4 | 5 | chrome_driver_path = '../chromedriver' 6 | init_script = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'" 7 | 8 | def save_obj(obj, name ): 9 | with open('objects/'+ name + '.pkl', 'wb') as f: #dump files into objects folder 10 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 11 | def load_obj(name ): 12 | with open('objects/' + name + '.pkl', 'rb') as f: 13 | return pickle.load(f) 14 | def grab_screen(_driver): 15 | image_b64 = _driver.execute_script(getbase64Script) 16 | screen = np.array(Image.open(BytesIO(base64.b64decode(image_b64)))) 17 | image = process_img(screen)#processing image as required 18 | return image 19 | def process_img(image): 20 | 21 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) #RGB to Grey Scale 22 | image = image[:300, :500] #Crop Region of Interest(ROI) 23 | image = cv2.resize(image, (80,80)) 24 | return image 25 | def show_img(graphs = False): 26 | """ 27 | Show images in new window 28 | """ 29 | while True: 30 | screen = (yield) 31 | window_title = "logs" if graphs else "game_play" 32 | cv2.namedWindow(window_title, cv2.WINDOW_NORMAL) 33 | imS = cv2.resize(screen, (800, 400)) 34 | cv2.imshow(window_title, screen) 35 | if (cv2.waitKey(1) & 0xFF == ord('q')): 36 | cv2.destroyAllWindows() 37 | break 38 | 39 | class Game: 40 | def __init__(self,custom_config=True): 41 | chrome_options = Options() 42 | chrome_options.add_argument("disable-infobars") 43 | chrome_options.add_argument("--mute-audio") 44 | self._driver = webdriver.Chrome(executable_path = chrome_driver_path,chrome_options=chrome_options) 45 | self._driver.set_window_position(x=-10,y=0) 46 | self._driver.get('chrome://dino') 47 | self._driver.execute_script("Runner.config.ACCELERATION=0") 48 | self._driver.execute_script(init_script) 49 | # 這邊用class function: 50 | def get_crashed(self): # 確定小恐龍有無因撞到障礙物結束遊戲 51 | return self._driver.execute_script("return Runner.instance_.crashed") 52 | def get_playing(self): # 確認在遊戲過程狀態中有無結束遊戲 53 | return self._driver.execute_script("return Runner.instance_.playing") 54 | def restart(self): # 重啟遊戲 55 | self._driver.execute_script("Runner.instance_.restart()") 56 | def press_up(self): # 往上跳躍 57 | self._driver.find_element_by_tag_name("body").send_keys(Keys.ARROW_UP) 58 | def get_score(self): # 取得分數 59 | score_array = self._driver.execute_script("return Runner.instance_.distanceMeter.digits") 60 | score = ''.join(score_array) # the javascript object is of type array with score in the formate[1,0,0] which is 100. 
61 | return int(score) 62 | def pause(self): # 暫停 63 | return self._driver.execute_script("return Runner.instance_.stop()") 64 | def resume(self): # 重啟暫停狀態 65 | return self._driver.execute_script("return Runner.instance_.play()") 66 | def end(self): # 結束selenium 67 | self._driver.close() 68 | 69 | class DinoAgent: 70 | def __init__(self,game): #takes game as input for taking actions 71 | self._game = game 72 | self.jump(); # 要先跳第一步,遊戲才能開始 73 | def is_running(self): 74 | return self._game.get_playing() 75 | def is_crashed(self): 76 | return self._game.get_crashed() 77 | def jump(self): 78 | self._game.press_up() 79 | def duck(self): # 這邊有實作往下,但實際上我們不會用到,進階版如果有直撲而來的鳥就有需要了。 80 | self._game.press_down() 81 | 82 | class Game_state: 83 | def __init__(self,agent,game): 84 | self._agent = agent 85 | self._game = game 86 | self._display = show_img() # 顯示小畫面 87 | self._display.__next__() # python語法,產生iter效果 88 | # 這邊先介紹要輸入的動作action,action[0]為不跳,action[1]為跳 89 | def get_state(self,actions): 90 | score = self._game.get_score() # 跟selenium要最新畫面 91 | reward = 0.1 # 先設定reward=0.1,只要能持續在場上皆會有0.1獎勵 92 | is_over = False # 遊戲是否結束 93 | if actions[1] == 1: # 決定是否跳躍 94 | self._agent.jump() 95 | image = grab_screen(self._game._driver) # 圖像前處理,下章介紹 96 | self._display.send(image) # 傳送畫面到小視窗 97 | if self._agent.is_crashed(): # 確定是否結束 98 | self._game.restart() # 重新開啟 99 | reward = -1 # 給負獎懲 100 | is_over = True # 確認結束 101 | return image, reward, is_over #return the Experience tuple 102 | # get_state對照gym就是step,這裡也會回傳下個state、獎勵跟terminate。 103 | 104 | game = Game() 105 | dino = DinoAgent(game) 106 | game_state = Game_state(dino,game) 107 | 108 | print('成功運行Game_state!') -------------------------------------------------------------------------------- /unit5_dino.py: -------------------------------------------------------------------------------- 1 | 2 | # import package 3 | import numpy as np 4 | from PIL import Image 5 | import cv2 #opencv 6 | import io 7 | import time 8 | import pandas as pd 9 | import numpy as np 10 | # from IPython.display import clear_output 11 | from random import randint 12 | import os 13 | 14 | from selenium import webdriver 15 | from selenium.webdriver.chrome.options import Options 16 | from selenium.webdriver.common.keys import Keys 17 | 18 | #keras imports 19 | from keras.models import model_from_json 20 | from keras.models import Sequential 21 | from keras.layers.core import Dense, Dropout, Activation, Flatten 22 | from keras.layers.convolutional import Conv2D, MaxPooling2D 23 | from keras.optimizers import SGD , Adam 24 | from keras.callbacks import TensorBoard 25 | from collections import deque 26 | import random 27 | import pickle 28 | from io import BytesIO 29 | import base64 30 | import json 31 | 32 | #path variables 33 | game_url = "chrome://dino" 34 | chrome_driver_path = "chromedriver" 35 | loss_file_path = "./objects/loss_df.csv" 36 | actions_file_path = "./objects/actions_df.csv" 37 | q_value_file_path = "./objects/q_values.csv" 38 | scores_file_path = "./objects/scores_df.csv" 39 | 40 | #scripts 41 | #create id for canvas for faster selection from DOM 42 | init_script = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'" 43 | 44 | #get image from canvas 45 | getbase64Script = "canvasRunner = document.getElementById('runner-canvas'); \ 46 | return canvasRunner.toDataURL().substring(22)" 47 | 48 | ''' 49 | * Game class: Selenium interfacing between the python and browser 50 | * __init__(): Launch the broswer window using the attributes in chrome_options 51 | 
* get_crashed() : return true if the agent as crashed on an obstacles. Gets javascript variable from game decribing the state 52 | * get_playing(): true if game in progress, false is crashed or paused 53 | * restart() : sends a signal to browser-javascript to restart the game 54 | * press_up(): sends a single to press up get to the browser 55 | * get_score(): gets current game score from javascript variables. 56 | * pause(): pause the game 57 | * resume(): resume a paused game if not crashed 58 | * end(): close the browser and end the game 59 | ''' 60 | class Game: 61 | def __init__(self,custom_config=True): 62 | chrome_options = Options() 63 | chrome_options.add_argument("disable-infobars") 64 | chrome_options.add_argument("--mute-audio") 65 | self._driver = webdriver.Chrome(executable_path = chrome_driver_path,chrome_options=chrome_options) 66 | self._driver.set_window_position(x=-10,y=0) 67 | time.sleep(5) 68 | self._driver.get('chrome://dino') 69 | self._driver.execute_script("Runner.config.ACCELERATION=0") 70 | self._driver.execute_script(init_script) 71 | def get_crashed(self): 72 | return self._driver.execute_script("return Runner.instance_.crashed") 73 | def get_playing(self): 74 | return self._driver.execute_script("return Runner.instance_.playing") 75 | def restart(self): 76 | self._driver.execute_script("Runner.instance_.restart()") 77 | def press_up(self): 78 | self._driver.find_element_by_tag_name("body").send_keys(Keys.ARROW_UP) 79 | def get_score(self): 80 | score_array = self._driver.execute_script("return Runner.instance_.distanceMeter.digits") 81 | score = ''.join(score_array) # the javascript object is of type array with score in the formate[1,0,0] which is 100. 82 | return int(score) 83 | def pause(self): 84 | return self._driver.execute_script("return Runner.instance_.stop()") 85 | def resume(self): 86 | return self._driver.execute_script("return Runner.instance_.play()") 87 | def end(self): 88 | self._driver.close() 89 | class DinoAgent: 90 | def __init__(self,game): #takes game as input for taking actions 91 | self._game = game; 92 | self.jump(); #to start the game, we need to jump once 93 | def is_running(self): 94 | return self._game.get_playing() 95 | def is_crashed(self): 96 | return self._game.get_crashed() 97 | def jump(self): 98 | self._game.press_up() 99 | def duck(self): 100 | self._game.press_down() 101 | class Game_sate: 102 | def __init__(self,agent,game): 103 | self._agent = agent 104 | self._game = game 105 | self._display = show_img() #display the processed image on screen using openCV, implemented using python coroutine 106 | self._display.__next__() # initiliaze the display coroutine 107 | def get_state(self,actions): 108 | actions_df.loc[len(actions_df)] = actions[1] # storing actions in a dataframe 109 | score = self._game.get_score() 110 | reward = 0.1 111 | is_over = False #game over 112 | if actions[1] == 1: 113 | self._agent.jump() 114 | image = grab_screen(self._game._driver) 115 | self._display.send(image) #display the image on screen 116 | if self._agent.is_crashed(): 117 | scores_df.loc[len(loss_df)] = score # log the score when game is over 118 | self._game.restart() 119 | reward = -1 120 | is_over = True 121 | return image, reward, is_over #return the Experience tuple 122 | 123 | def save_obj(obj, name ): 124 | with open('objects/'+ name + '.pkl', 'wb') as f: #dump files into objects folder 125 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 126 | def load_obj(name ): 127 | with open('objects/' + name + '.pkl', 'rb') as f: 128 | return 
pickle.load(f) 129 | def grab_screen(_driver): 130 | image_b64 = _driver.execute_script(getbase64Script) 131 | screen = np.array(Image.open(BytesIO(base64.b64decode(image_b64)))) 132 | image = process_img(screen)#processing image as required 133 | return image 134 | def process_img(image): 135 | 136 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) #RGB to Grey Scale 137 | image = image[:300, :500] #Crop Region of Interest(ROI) 138 | image = cv2.resize(image, (80,80)) 139 | return image 140 | def show_img(graphs = False): 141 | """ 142 | Show images in new window 143 | """ 144 | while True: 145 | screen = (yield) 146 | window_title = "logs" if graphs else "game_play" 147 | cv2.namedWindow(window_title, cv2.WINDOW_NORMAL) 148 | imS = cv2.resize(screen, (800, 400)) 149 | cv2.imshow(window_title, screen) 150 | if (cv2.waitKey(1) & 0xFF == ord('q')): 151 | cv2.destroyAllWindows() 152 | break 153 | 154 | #Intialize log structures from file if exists else create new 155 | loss_df = pd.read_csv(loss_file_path) if os.path.isfile(loss_file_path) else pd.DataFrame(columns =['loss']) 156 | scores_df = pd.read_csv(scores_file_path) if os.path.isfile(loss_file_path) else pd.DataFrame(columns = ['scores']) 157 | actions_df = pd.read_csv(actions_file_path) if os.path.isfile(actions_file_path) else pd.DataFrame(columns = ['actions']) 158 | q_values_df =pd.read_csv(actions_file_path) if os.path.isfile(q_value_file_path) else pd.DataFrame(columns = ['qvalues']) 159 | 160 | #game parameters 161 | ACTIONS = 2 # possible actions: jump, do nothing 162 | GAMMA = 0.99 # decay rate of past observations original 0.99 163 | OBSERVATION = 100. # timesteps to observe before training 164 | EXPLORE = 100000 # frames over which to anneal epsilon 165 | FINAL_EPSILON = 0.0001 # final value of epsilon 166 | INITIAL_EPSILON = 0.1 # starting value of epsilon 167 | REPLAY_MEMORY = 50000 # number of previous transitions to remember 168 | BATCH = 16 # size of minibatch 169 | FRAME_PER_ACTION = 1 170 | LEARNING_RATE = 1e-4 171 | img_rows , img_cols = 80,80 172 | img_channels = 4 #We stack 4 frames 173 | 174 | # training variables saved as checkpoints to filesystem to resume training from the same step 175 | def init_cache(): 176 | """initial variable caching, done only once""" 177 | save_obj(INITIAL_EPSILON,"epsilon") 178 | t = 0 179 | save_obj(t,"time") 180 | D = deque() 181 | save_obj(D,"D") 182 | 183 | '''Call only once to init file structure 184 | ''' 185 | #init_cache() 186 | 187 | def buildmodel(): 188 | print("Now we build the model") 189 | model = Sequential() 190 | model.add(Conv2D(32, (8, 8), padding='same',strides=(4, 4),input_shape=(img_cols,img_rows,img_channels))) #80*80*4 191 | model.add(MaxPooling2D(pool_size=(2,2))) 192 | model.add(Activation('relu')) 193 | model.add(Conv2D(64, (4, 4),strides=(2, 2), padding='same')) 194 | model.add(MaxPooling2D(pool_size=(2,2))) 195 | model.add(Activation('relu')) 196 | model.add(Conv2D(64, (3, 3),strides=(1, 1), padding='same')) 197 | model.add(MaxPooling2D(pool_size=(2,2))) 198 | model.add(Activation('relu')) 199 | model.add(Flatten()) 200 | model.add(Dense(512)) 201 | model.add(Activation('relu')) 202 | model.add(Dense(ACTIONS)) 203 | adam = Adam(lr=LEARNING_RATE) 204 | model.compile(loss='mse',optimizer=adam) 205 | 206 | #create model file if not present 207 | if not os.path.isfile(loss_file_path): 208 | model.save_weights('model.h5') 209 | print("We finish building the model") 210 | return model 211 | 212 | ''' 213 | main training module 214 | Parameters: 215 | * model => 
Keras Model to be trained 216 | * game_state => Game State module with access to game environment and dino 217 | * observe => flag to indicate wherther the model is to be trained(weight updates), else just play 218 | ''' 219 | def trainNetwork(model,game_state,observe=False): 220 | last_time = time.time() 221 | # store the previous observations in replay memory 222 | D = load_obj("D") #load from file system 223 | # get the first state by doing nothing 224 | do_nothing = np.zeros(ACTIONS) 225 | do_nothing[0] =1 #0 => do nothing, 226 | #1=> jump 227 | 228 | x_t, r_0, terminal = game_state.get_state(do_nothing) # get next step after performing the action 229 | 230 | 231 | s_t = np.stack((x_t, x_t, x_t, x_t), axis=2) # stack 4 images to create placeholder input 232 | s_t = s_t.reshape(1, s_t.shape[0], s_t.shape[1], s_t.shape[2]) #1*20*40*4 233 | initial_state = s_t 234 | 235 | if observe : 236 | OBSERVE = 999999999 #We keep observe, never train 237 | epsilon = FINAL_EPSILON 238 | print ("Now we load weight") 239 | model.load_weights("model.h5") 240 | adam = Adam(lr=LEARNING_RATE) 241 | model.compile(loss='mse',optimizer=adam) 242 | print ("Weight load successfully") 243 | else: #We go to training mode 244 | OBSERVE = OBSERVATION 245 | epsilon = load_obj("epsilon") 246 | model.load_weights("model.h5") 247 | adam = Adam(lr=LEARNING_RATE) 248 | model.compile(loss='mse',optimizer=adam) 249 | t = load_obj("time") # resume from the previous time step stored in file system 250 | while (True): #endless running 251 | 252 | loss = 0 253 | Q_sa = 0 254 | action_index = 0 255 | r_t = 0 #reward at 4 256 | a_t = np.zeros([ACTIONS]) # action at t 257 | 258 | #choose an action epsilon greedy 259 | if t % FRAME_PER_ACTION == 0: #parameter to skip frames for actions 260 | if random.random() <= epsilon: #randomly explore an action 261 | print("----------Random Action----------") 262 | action_index = random.randrange(ACTIONS) 263 | a_t[action_index] = 1 264 | else: # predict the output 265 | q = model.predict(s_t) #input a stack of 4 images, get the prediction 266 | max_Q = np.argmax(q) # chosing index with maximum q value 267 | action_index = max_Q 268 | a_t[action_index] = 1 # o=> do nothing, 1=> jump 269 | 270 | #We reduced the epsilon (exploration parameter) gradually 271 | if epsilon > FINAL_EPSILON and t > OBSERVE: 272 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE 273 | 274 | #run the selected action and observed next state and reward 275 | x_t1, r_t, terminal = game_state.get_state(a_t) 276 | print('fps: {0}'.format(1 / (time.time()-last_time))) # helpful for measuring frame rate 277 | last_time = time.time() 278 | x_t1 = x_t1.reshape(1, x_t1.shape[0], x_t1.shape[1], 1) #1x20x40x1 279 | s_t1 = np.append(x_t1, s_t[:, :, :, :3], axis=3) # append the new image to input stack and remove the first one 280 | 281 | 282 | # store the transition in D 283 | D.append((s_t, action_index, r_t, s_t1, terminal)) 284 | if len(D) > REPLAY_MEMORY: 285 | D.popleft() 286 | 287 | #only train if done observing 288 | if t > OBSERVE: 289 | 290 | #sample a minibatch to train on 291 | minibatch = random.sample(D, BATCH) 292 | inputs = np.zeros((BATCH, s_t.shape[1], s_t.shape[2], s_t.shape[3])) #32, 20, 40, 4 293 | targets = np.zeros((inputs.shape[0], ACTIONS)) #32, 2 294 | 295 | #Now we do the experience replay 296 | for i in range(0, len(minibatch)): 297 | state_t = minibatch[i][0] # 4D stack of images 298 | action_t = minibatch[i][1] #This is action index 299 | reward_t = minibatch[i][2] #reward at state_t due to 
action_t 300 | state_t1 = minibatch[i][3] #next state 301 | terminal = minibatch[i][4] #wheather the agent died or survided due the action 302 | 303 | 304 | inputs[i:i + 1] = state_t 305 | 306 | targets[i] = model.predict(state_t) # predicted q values 307 | Q_sa = model.predict(state_t1) #predict q values for next step 308 | 309 | if terminal: 310 | targets[i, action_t] = reward_t # if terminated, only equals reward 311 | else: 312 | targets[i, action_t] = reward_t + GAMMA * np.max(Q_sa) 313 | 314 | loss += model.train_on_batch(inputs, targets) 315 | loss_df.loc[len(loss_df)] = loss 316 | q_values_df.loc[len(q_values_df)] = np.max(Q_sa) 317 | s_t = initial_state if terminal else s_t1 #reset game to initial frame if terminate 318 | t = t + 1 319 | 320 | # save progress every 1000 iterations 321 | if t % 1000 == 0: 322 | print("Now we save model") 323 | game_state._game.pause() #pause game while saving to filesystem 324 | model.save_weights("model.h5", overwrite=True) 325 | save_obj(D,"D") #saving episodes 326 | save_obj(t,"time") #caching time steps 327 | save_obj(epsilon,"epsilon") #cache epsilon to avoid repeated randomness in actions 328 | loss_df.to_csv("./objects/loss_df.csv",index=False) 329 | scores_df.to_csv("./objects/scores_df.csv",index=False) 330 | actions_df.to_csv("./objects/actions_df.csv",index=False) 331 | q_values_df.to_csv(q_value_file_path,index=False) 332 | with open("model.json", "w") as outfile: 333 | json.dump(model.to_json(), outfile) 334 | # clear_output() 335 | game_state._game.resume() 336 | # print info 337 | state = "" 338 | if t <= OBSERVE: 339 | state = "observe" 340 | elif t > OBSERVE and t <= OBSERVE + EXPLORE: 341 | state = "explore" 342 | else: 343 | state = "train" 344 | 345 | print("TIMESTEP", t, "/ STATE", state, "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t, "/ Q_MAX " , np.max(Q_sa), "/ Loss ", loss) 346 | 347 | print("Episode finished!") 348 | print("************************") 349 | 350 | 351 | #main function 352 | def playGame(observe=False): 353 | game = Game() 354 | dino = DinoAgent(game) 355 | game_state = Game_sate(dino,game) 356 | model = buildmodel() 357 | try: 358 | trainNetwork(model,game_state,observe=observe) 359 | except StopIteration: 360 | game.end() 361 | 362 | playGame(observe=False) -------------------------------------------------------------------------------- /Reinforcement Learning Dino Run.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from PIL import Image\n", 11 | "import cv2 #opencv\n", 12 | "import io\n", 13 | "import time\n", 14 | "import pandas as pd\n", 15 | "import numpy as np\n", 16 | "from IPython.display import clear_output\n", 17 | "from random import randint\n", 18 | "import os\n", 19 | "\n", 20 | "from selenium import webdriver\n", 21 | "from selenium.webdriver.chrome.options import Options\n", 22 | "from selenium.webdriver.common.keys import Keys\n", 23 | "\n", 24 | "#keras imports\n", 25 | "from keras.models import model_from_json\n", 26 | "from keras.models import Sequential\n", 27 | "from keras.layers.core import Dense, Dropout, Activation, Flatten\n", 28 | "from keras.layers.convolutional import Conv2D, MaxPooling2D\n", 29 | "from keras.optimizers import SGD , Adam\n", 30 | "from keras.callbacks import TensorBoard\n", 31 | "from collections import deque\n", 32 | "import random\n", 33 
| "import pickle\n", 34 | "from io import BytesIO\n", 35 | "import base64\n", 36 | "import json" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "collapsed": true 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "#path variables\n", 48 | "game_url = \"chrome://dino\"\n", 49 | "chrome_driver_path = \"../chromedriver\"\n", 50 | "loss_file_path = \"./objects/loss_df.csv\"\n", 51 | "actions_file_path = \"./objects/actions_df.csv\"\n", 52 | "q_value_file_path = \"./objects/q_values.csv\"\n", 53 | "scores_file_path = \"./objects/scores_df.csv\"\n", 54 | "\n", 55 | "#scripts\n", 56 | "#create id for canvas for faster selection from DOM\n", 57 | "init_script = \"document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'\"\n", 58 | "\n", 59 | "#get image from canvas\n", 60 | "getbase64Script = \"canvasRunner = document.getElementById('runner-canvas'); \\\n", 61 | "return canvasRunner.toDataURL().substring(22)\"" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "collapsed": true 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "'''\n", 73 | "* Game class: Selenium interfacing between the python and browser\n", 74 | "* __init__(): Launch the broswer window using the attributes in chrome_options\n", 75 | "* get_crashed() : return true if the agent as crashed on an obstacles. Gets javascript variable from game decribing the state\n", 76 | "* get_playing(): true if game in progress, false is crashed or paused\n", 77 | "* restart() : sends a signal to browser-javascript to restart the game\n", 78 | "* press_up(): sends a single to press up get to the browser\n", 79 | "* get_score(): gets current game score from javascript variables.\n", 80 | "* pause(): pause the game\n", 81 | "* resume(): resume a paused game if not crashed\n", 82 | "* end(): close the browser and end the game\n", 83 | "'''\n", 84 | "class Game:\n", 85 | " def __init__(self,custom_config=True):\n", 86 | " chrome_options = Options()\n", 87 | " chrome_options.add_argument(\"disable-infobars\")\n", 88 | " chrome_options.add_argument(\"--mute-audio\")\n", 89 | " self._driver = webdriver.Chrome(executable_path = chrome_driver_path,chrome_options=chrome_options)\n", 90 | " self._driver.set_window_position(x=-10,y=0)\n", 91 | " self._driver.get('chrome://dino')\n", 92 | " self._driver.execute_script(\"Runner.config.ACCELERATION=0\")\n", 93 | " self._driver.execute_script(init_script)\n", 94 | " def get_crashed(self):\n", 95 | " return self._driver.execute_script(\"return Runner.instance_.crashed\")\n", 96 | " def get_playing(self):\n", 97 | " return self._driver.execute_script(\"return Runner.instance_.playing\")\n", 98 | " def restart(self):\n", 99 | " self._driver.execute_script(\"Runner.instance_.restart()\")\n", 100 | " def press_up(self):\n", 101 | " self._driver.find_element_by_tag_name(\"body\").send_keys(Keys.ARROW_UP)\n", 102 | " def get_score(self):\n", 103 | " score_array = self._driver.execute_script(\"return Runner.instance_.distanceMeter.digits\")\n", 104 | " score = ''.join(score_array) # the javascript object is of type array with score in the formate[1,0,0] which is 100.\n", 105 | " return int(score)\n", 106 | " def pause(self):\n", 107 | " return self._driver.execute_script(\"return Runner.instance_.stop()\")\n", 108 | " def resume(self):\n", 109 | " return self._driver.execute_script(\"return Runner.instance_.play()\")\n", 110 | " def end(self):\n", 111 | " self._driver.close()" 112 | ] 113 | }, 114 | { 
115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "collapsed": true 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "class DinoAgent:\n", 123 | " def __init__(self,game): #takes game as input for taking actions\n", 124 | " self._game = game; \n", 125 | " self.jump(); #to start the game, we need to jump once\n", 126 | " def is_running(self):\n", 127 | " return self._game.get_playing()\n", 128 | " def is_crashed(self):\n", 129 | " return self._game.get_crashed()\n", 130 | " def jump(self):\n", 131 | " self._game.press_up()\n", 132 | " def duck(self):\n", 133 | " self._game.press_down()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "collapsed": true 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "class Game_sate:\n", 145 | " def __init__(self,agent,game):\n", 146 | " self._agent = agent\n", 147 | " self._game = game\n", 148 | " self._display = show_img() #display the processed image on screen using openCV, implemented using python coroutine \n", 149 | " self._display.__next__() # initiliaze the display coroutine \n", 150 | " def get_state(self,actions):\n", 151 | " actions_df.loc[len(actions_df)] = actions[1] # storing actions in a dataframe\n", 152 | " score = self._game.get_score() \n", 153 | " reward = 0.1\n", 154 | " is_over = False #game over\n", 155 | " if actions[1] == 1:\n", 156 | " self._agent.jump()\n", 157 | " image = grab_screen(self._game._driver) \n", 158 | " self._display.send(image) #display the image on screen\n", 159 | " if self._agent.is_crashed():\n", 160 | " scores_df.loc[len(loss_df)] = score # log the score when game is over\n", 161 | " self._game.restart()\n", 162 | " reward = -1\n", 163 | " is_over = True\n", 164 | " return image, reward, is_over #return the Experience tuple" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "collapsed": true 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "def save_obj(obj, name ):\n", 176 | " with open('objects/'+ name + '.pkl', 'wb') as f: #dump files into objects folder\n", 177 | " pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)\n", 178 | "def load_obj(name ):\n", 179 | " with open('objects/' + name + '.pkl', 'rb') as f:\n", 180 | " return pickle.load(f)\n", 181 | "\n", 182 | "def grab_screen(_driver):\n", 183 | " image_b64 = _driver.execute_script(getbase64Script)\n", 184 | " screen = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))\n", 185 | " image = process_img(screen)#processing image as required\n", 186 | " return image\n", 187 | "\n", 188 | "def process_img(image):\n", 189 | " \n", 190 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) #RGB to Grey Scale\n", 191 | " image = image[:300, :500] #Crop Region of Interest(ROI)\n", 192 | " image = cv2.resize(image, (80,80))\n", 193 | " return image\n", 194 | "\n", 195 | "def show_img(graphs = False):\n", 196 | " \"\"\"\n", 197 | " Show images in new window\n", 198 | " \"\"\"\n", 199 | " while True:\n", 200 | " screen = (yield)\n", 201 | " window_title = \"logs\" if graphs else \"game_play\"\n", 202 | " cv2.namedWindow(window_title, cv2.WINDOW_NORMAL) \n", 203 | " imS = cv2.resize(screen, (800, 400)) \n", 204 | " cv2.imshow(window_title, screen)\n", 205 | " if (cv2.waitKey(1) & 0xFF == ord('q')):\n", 206 | " cv2.destroyAllWindows()\n", 207 | " break" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "collapsed": true 215 | }, 216 | "outputs": 
[], 217 | "source": [ 218 | "#Intialize log structures from file if exists else create new\n", 219 | "loss_df = pd.read_csv(loss_file_path) if os.path.isfile(loss_file_path) else pd.DataFrame(columns =['loss'])\n", 220 | "scores_df = pd.read_csv(scores_file_path) if os.path.isfile(loss_file_path) else pd.DataFrame(columns = ['scores'])\n", 221 | "actions_df = pd.read_csv(actions_file_path) if os.path.isfile(actions_file_path) else pd.DataFrame(columns = ['actions'])\n", 222 | "q_values_df =pd.read_csv(actions_file_path) if os.path.isfile(q_value_file_path) else pd.DataFrame(columns = ['qvalues'])" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": { 229 | "collapsed": true 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "#game parameters\n", 234 | "ACTIONS = 2 # possible actions: jump, do nothing\n", 235 | "GAMMA = 0.99 # decay rate of past observations original 0.99\n", 236 | "OBSERVATION = 100. # timesteps to observe before training\n", 237 | "EXPLORE = 100000 # frames over which to anneal epsilon\n", 238 | "FINAL_EPSILON = 0.0001 # final value of epsilon\n", 239 | "INITIAL_EPSILON = 0.1 # starting value of epsilon\n", 240 | "REPLAY_MEMORY = 50000 # number of previous transitions to remember\n", 241 | "BATCH = 16 # size of minibatch\n", 242 | "FRAME_PER_ACTION = 1\n", 243 | "LEARNING_RATE = 1e-4\n", 244 | "img_rows , img_cols = 80,80\n", 245 | "img_channels = 4 #We stack 4 frames" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "collapsed": true 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "# training variables saved as checkpoints to filesystem to resume training from the same step\n", 257 | "def init_cache():\n", 258 | " \"\"\"initial variable caching, done only once\"\"\"\n", 259 | " save_obj(INITIAL_EPSILON,\"epsilon\")\n", 260 | " t = 0\n", 261 | " save_obj(t,\"time\")\n", 262 | " D = deque()\n", 263 | " save_obj(D,\"D\")" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "'''Call only once to init file structure\n", 273 | "'''\n", 274 | "#init_cache()" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "def buildmodel():\n", 284 | " print(\"Now we build the model\")\n", 285 | " model = Sequential()\n", 286 | " model.add(Conv2D(32, (8, 8), padding='same',strides=(4, 4),input_shape=(img_cols,img_rows,img_channels))) #80*80*4\n", 287 | " model.add(MaxPooling2D(pool_size=(2,2)))\n", 288 | " model.add(Activation('relu'))\n", 289 | " model.add(Conv2D(64, (4, 4),strides=(2, 2), padding='same'))\n", 290 | " model.add(MaxPooling2D(pool_size=(2,2)))\n", 291 | " model.add(Activation('relu'))\n", 292 | " model.add(Conv2D(64, (3, 3),strides=(1, 1), padding='same'))\n", 293 | " model.add(MaxPooling2D(pool_size=(2,2)))\n", 294 | " model.add(Activation('relu'))\n", 295 | " model.add(Flatten())\n", 296 | " model.add(Dense(512))\n", 297 | " model.add(Activation('relu'))\n", 298 | " model.add(Dense(ACTIONS))\n", 299 | " adam = Adam(lr=LEARNING_RATE)\n", 300 | " model.compile(loss='mse',optimizer=adam)\n", 301 | " \n", 302 | " #create model file if not present\n", 303 | " if not os.path.isfile(loss_file_path):\n", 304 | " model.save_weights('model.h5')\n", 305 | " print(\"We finish building the model\")\n", 306 | " return model" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 
311 | "execution_count": null, 312 | "metadata": { 313 | "collapsed": true 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "''' \n", 318 | "main training module\n", 319 | "Parameters:\n", 320 | "* model => Keras Model to be trained\n", 321 | "* game_state => Game State module with access to game environment and dino\n", 322 | "* observe => flag to indicate wherther the model is to be trained(weight updates), else just play\n", 323 | "'''\n", 324 | "def trainNetwork(model,game_state,observe=False):\n", 325 | " last_time = time.time()\n", 326 | " # store the previous observations in replay memory\n", 327 | " D = load_obj(\"D\") #load from file system\n", 328 | " # get the first state by doing nothing\n", 329 | " do_nothing = np.zeros(ACTIONS)\n", 330 | " do_nothing[0] =1 #0 => do nothing,\n", 331 | " #1=> jump\n", 332 | " \n", 333 | " x_t, r_0, terminal = game_state.get_state(do_nothing) # get next step after performing the action\n", 334 | " \n", 335 | "\n", 336 | " s_t = np.stack((x_t, x_t, x_t, x_t), axis=2) # stack 4 images to create placeholder input\n", 337 | " \n", 338 | "\n", 339 | " \n", 340 | " s_t = s_t.reshape(1, s_t.shape[0], s_t.shape[1], s_t.shape[2]) #1*20*40*4\n", 341 | " \n", 342 | " initial_state = s_t \n", 343 | "\n", 344 | " if observe :\n", 345 | " OBSERVE = 999999999 #We keep observe, never train\n", 346 | " epsilon = FINAL_EPSILON\n", 347 | " print (\"Now we load weight\")\n", 348 | " model.load_weights(\"model.h5\")\n", 349 | " adam = Adam(lr=LEARNING_RATE)\n", 350 | " model.compile(loss='mse',optimizer=adam)\n", 351 | " print (\"Weight load successfully\") \n", 352 | " else: #We go to training mode\n", 353 | " OBSERVE = OBSERVATION\n", 354 | " epsilon = load_obj(\"epsilon\") \n", 355 | " model.load_weights(\"model.h5\")\n", 356 | " adam = Adam(lr=LEARNING_RATE)\n", 357 | " model.compile(loss='mse',optimizer=adam)\n", 358 | "\n", 359 | " t = load_obj(\"time\") # resume from the previous time step stored in file system\n", 360 | " while (True): #endless running\n", 361 | " \n", 362 | " loss = 0\n", 363 | " Q_sa = 0\n", 364 | " action_index = 0\n", 365 | " r_t = 0 #reward at 4\n", 366 | " a_t = np.zeros([ACTIONS]) # action at t\n", 367 | " \n", 368 | " #choose an action epsilon greedy\n", 369 | " if t % FRAME_PER_ACTION == 0: #parameter to skip frames for actions\n", 370 | " if random.random() <= epsilon: #randomly explore an action\n", 371 | " print(\"----------Random Action----------\")\n", 372 | " action_index = random.randrange(ACTIONS)\n", 373 | " a_t[action_index] = 1\n", 374 | " else: # predict the output\n", 375 | " q = model.predict(s_t) #input a stack of 4 images, get the prediction\n", 376 | " max_Q = np.argmax(q) # chosing index with maximum q value\n", 377 | " action_index = max_Q \n", 378 | " a_t[action_index] = 1 # o=> do nothing, 1=> jump\n", 379 | " \n", 380 | " #We reduced the epsilon (exploration parameter) gradually\n", 381 | " if epsilon > FINAL_EPSILON and t > OBSERVE:\n", 382 | " epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE \n", 383 | "\n", 384 | " #run the selected action and observed next state and reward\n", 385 | " x_t1, r_t, terminal = game_state.get_state(a_t)\n", 386 | " print('fps: {0}'.format(1 / (time.time()-last_time))) # helpful for measuring frame rate\n", 387 | " last_time = time.time()\n", 388 | " x_t1 = x_t1.reshape(1, x_t1.shape[0], x_t1.shape[1], 1) #1x20x40x1\n", 389 | " s_t1 = np.append(x_t1, s_t[:, :, :, :3], axis=3) # append the new image to input stack and remove the first one\n", 390 | " \n", 391 | " \n", 
392 | " # store the transition in D\n", 393 | " D.append((s_t, action_index, r_t, s_t1, terminal))\n", 394 | " if len(D) > REPLAY_MEMORY:\n", 395 | " D.popleft()\n", 396 | "\n", 397 | " #only train if done observing\n", 398 | " if t > OBSERVE: \n", 399 | " \n", 400 | " #sample a minibatch to train on\n", 401 | " minibatch = random.sample(D, BATCH)\n", 402 | " inputs = np.zeros((BATCH, s_t.shape[1], s_t.shape[2], s_t.shape[3])) #32, 20, 40, 4\n", 403 | " targets = np.zeros((inputs.shape[0], ACTIONS)) #32, 2\n", 404 | "\n", 405 | " #Now we do the experience replay\n", 406 | " for i in range(0, len(minibatch)):\n", 407 | " state_t = minibatch[i][0] # 4D stack of images\n", 408 | " action_t = minibatch[i][1] #This is action index\n", 409 | " reward_t = minibatch[i][2] #reward at state_t due to action_t\n", 410 | " state_t1 = minibatch[i][3] #next state\n", 411 | " terminal = minibatch[i][4] #wheather the agent died or survided due the action\n", 412 | " \n", 413 | "\n", 414 | " inputs[i:i + 1] = state_t \n", 415 | "\n", 416 | " targets[i] = model.predict(state_t) # predicted q values\n", 417 | " Q_sa = model.predict(state_t1) #predict q values for next step\n", 418 | " \n", 419 | " if terminal:\n", 420 | " targets[i, action_t] = reward_t # if terminated, only equals reward\n", 421 | " else:\n", 422 | " targets[i, action_t] = reward_t + GAMMA * np.max(Q_sa)\n", 423 | "\n", 424 | " loss += model.train_on_batch(inputs, targets)\n", 425 | " loss_df.loc[len(loss_df)] = loss\n", 426 | " q_values_df.loc[len(q_values_df)] = np.max(Q_sa)\n", 427 | " s_t = initial_state if terminal else s_t1 #reset game to initial frame if terminate\n", 428 | " t = t + 1\n", 429 | " \n", 430 | " # save progress every 1000 iterations\n", 431 | " if t % 1000 == 0:\n", 432 | " print(\"Now we save model\")\n", 433 | " game_state._game.pause() #pause game while saving to filesystem\n", 434 | " model.save_weights(\"model.h5\", overwrite=True)\n", 435 | " save_obj(D,\"D\") #saving episodes\n", 436 | " save_obj(t,\"time\") #caching time steps\n", 437 | " save_obj(epsilon,\"epsilon\") #cache epsilon to avoid repeated randomness in actions\n", 438 | " loss_df.to_csv(\"./objects/loss_df.csv\",index=False)\n", 439 | " scores_df.to_csv(\"./objects/scores_df.csv\",index=False)\n", 440 | " actions_df.to_csv(\"./objects/actions_df.csv\",index=False)\n", 441 | " q_values_df.to_csv(q_value_file_path,index=False)\n", 442 | " with open(\"model.json\", \"w\") as outfile:\n", 443 | " json.dump(model.to_json(), outfile)\n", 444 | " clear_output()\n", 445 | " game_state._game.resume()\n", 446 | " # print info\n", 447 | " state = \"\"\n", 448 | " if t <= OBSERVE:\n", 449 | " state = \"observe\"\n", 450 | " elif t > OBSERVE and t <= OBSERVE + EXPLORE:\n", 451 | " state = \"explore\"\n", 452 | " else:\n", 453 | " state = \"train\"\n", 454 | "\n", 455 | " print(\"TIMESTEP\", t, \"/ STATE\", state, \"/ EPSILON\", epsilon, \"/ ACTION\", action_index, \"/ REWARD\", r_t, \"/ Q_MAX \" , np.max(Q_sa), \"/ Loss \", loss)\n", 456 | "\n", 457 | " print(\"Episode finished!\")\n", 458 | " print(\"************************\")\n" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "collapsed": true 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "#main function\n", 470 | "def playGame(observe=False):\n", 471 | " game = Game()\n", 472 | " dino = DinoAgent(game)\n", 473 | " game_state = Game_sate(dino,game) \n", 474 | " model = buildmodel()\n", 475 | " try:\n", 476 | " 
trainNetwork(model,game_state,observe=observe)\n", 477 | " except StopIteration:\n", 478 | " game.end()" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "metadata": { 485 | "scrolled": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "playGame(observe=False);" 490 | ] 491 | } 492 | ], 493 | "metadata": { 494 | "kernelspec": { 495 | "display_name": "Python 3", 496 | "language": "python", 497 | "name": "python3" 498 | }, 499 | "language_info": { 500 | "codemirror_mode": { 501 | "name": "ipython", 502 | "version": 3 503 | }, 504 | "file_extension": ".py", 505 | "mimetype": "text/x-python", 506 | "name": "python", 507 | "nbconvert_exporter": "python", 508 | "pygments_lexer": "ipython3", 509 | "version": "3.6.3" 510 | } 511 | }, 512 | "nbformat": 4, 513 | "nbformat_minor": 2 514 | } 515 | --------------------------------------------------------------------------------
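As a closing note, here is a minimal sketch of reloading the checkpointed network outside the training script, assuming model.json and model.h5 from this repository sit in the working directory and the keras version pinned in environment.yaml is installed. model.json stores the architecture as a JSON-encoded string (written via json.dump(model.to_json(), ...)), so it is decoded with json.load before being handed to model_from_json; the observe=True path in unit5_dino.py achieves the same effect by rebuilding the model and loading only the weights.
```
import json
import numpy as np
from keras.models import model_from_json

with open('model.json') as f:
    model = model_from_json(json.load(f))   # json.load returns the architecture string saved by the training loop
model.load_weights('model.h5')               # trained weights, checkpointed every 1000 steps
model.compile(loss='mse', optimizer='adam')

dummy_state = np.zeros((1, 80, 80, 4))       # one stack of four 80x80 frames
print(model.predict(dummy_state))            # two Q-values: [do nothing, jump]
```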