├── .gitignore
├── LICENSE
├── README.md
├── dqn.py
├── minerva_agent.py
├── trainingRL.py
└── trainingSL.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/*
2 | *.pyc
3 | saved/*
4 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Min ByeongUk
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MinervaSc2
2 | 
3 | A machine learning project using DeepMind's [PySC2](https://github.com/deepmind/pysc2) and [TensorFlow](https://github.com/tensorflow/tensorflow).
4 | 
5 | I referred to [Sunghun Kim's repository](https://github.com/hunkim/ReinforcementZeroToAll/) for the DQN class and related utilities.
6 | 
7 | 
8 | ## Usage
9 | 
10 | First, clone this repository and install the prerequisites.
11 | 
12 | Here's an example:
13 | 
14 | ```shell
15 | git clone https://github.com/phraust1612/MinervaSc2.git
16 | sudo pip3 install pysc2
17 | sudo pip3 install numpy
18 | sudo pip3 install tensorflow
19 | ```
20 | 
21 | Specify your saving directory. By default, the DQN parameters are saved under 'saved/',
22 | 
23 | so make sure the 'saved/' directory exists or prepare your own.
24 | 
25 | ### Reinforcement Learning
26 | 
27 | 
28 | ```shell
29 | python3 trainingRL.py --start_episode (##) --num_episodes (##)
30 | ```
31 | 
32 | By default, the starting episode and the total number of episodes are 0 and 100.
33 | 
34 | 
35 | ### Supervised Learning (with replays)
36 | 
37 | ```shell
38 | python3 trainingSL.py --replay (your replay directory) --repeat (number of repetitions)
39 | ```
40 | 
41 | By default, this looks in ~/StarCraftII/Replays/ for replay files.
42 | 
43 | 
44 | ## Composition
45 | 
46 | * trainingRL.py : runs reinforcement learning loops via DQN.
47 | * trainingSL.py : runs supervised learning loops with your replay files.
48 | * minerva_agent.py : contains the agent class that decides an action at every step.
49 | * dqn.py : the DQN network class, defined separately so that the target and learning networks can be split (see the usage sketch below).
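
For reference, here is a minimal sketch of how the DQN class in dqn.py is constructed and queried. The constructor signature, the default sizes, and the predict()/predictSpatial() input format are taken from trainingRL.py and minerva_agent.py; the session wrapper and the commented-out variable names (q_ids, q_args, obs) are only illustrative.

```python
import tensorflow as tf
from pysc2.lib import actions as actlib
import dqn

output_size = len(actlib.FUNCTIONS)  # one Q entry per pysc2 action id

with tf.Session() as sess:
    # screen size, minimap size and learning rate mirror the defaults in trainingRL.py
    mainDQN = dqn.DQN(sess, 64, 64, output_size, 0.001, name="main")

    # predict() expects a list of [observation] pairs, e.g. [[obs.observation]],
    # and returns Q values over action ids; predictSpatial() returns the 13
    # argument heads (screen, minimap, screen2 and the non-spatial arguments).
    # q_ids  = mainDQN.predict([[obs.observation]])
    # q_args = mainDQN.predictSpatial([[obs.observation]])
```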
50 | -------------------------------------------------------------------------------- /dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | class DQN: 5 | 6 | def __init__(self, session: tf.Session, screen_size: int, minimap_size: int, output_size: int, learning_rate:int, name: str="main") -> None: 7 | """DQN Agent can 8 | 1) Build network 9 | 2) Predict Q_value given state 10 | 3) Train parameters 11 | Args: 12 | session (tf.Session): Tensorflow session 13 | screen_size : screen width pixel size, default=64 14 | minimap_size : minimap width pixel size, default=64 15 | output_size (int): Number of discrete actions 16 | learning_rate (int): do I need any more explanation? 17 | name (str, optional): TF Graph will be built under this name scope 18 | """ 19 | self.session = session 20 | self.output_size = output_size 21 | self.screen_size = screen_size 22 | self.minimap_size = minimap_size 23 | self.net_name = name 24 | self.l_rate = learning_rate 25 | self._build_network() 26 | 27 | def _build_network(self) -> None: 28 | """DQN Network architecture (FullyConv : check out for DeepMind's sc2le paper) 29 | """ 30 | with tf.variable_scope(self.net_name): 31 | 32 | self._X_minimap = tf.placeholder(tf.float32, [None, 7, self.minimap_size, self.minimap_size], name="x_minimap") 33 | self._X_screen = tf.placeholder(tf.float32, [None, 13, self.screen_size, self.screen_size], name="x_screen") 34 | self._X_select = tf.placeholder(tf.float32, [None, 1, 7], name="x_select") 35 | self._X_player = tf.placeholder(tf.float32, [None, 11], name="x_player") 36 | self._X_control_group = tf.placeholder(tf.float32, [None, 10, 2], name="x_control_group") 37 | self._X_score = tf.placeholder(tf.float32, [None, 13], name="x_score") 38 | _X_minimap = tf.transpose(self._X_minimap, perm=[0,2,3,1]) 39 | _X_screen = tf.transpose(self._X_screen, perm=[0,2,3,1]) 40 | 41 | W1_minimap = tf.Variable(tf.random_normal([3,3,7,12],stddev=0.1),name='W1_minimap') 42 | L1_minimap = tf.nn.conv2d(_X_minimap, W1_minimap, strides=[1,1,1,1], padding="SAME") 43 | L1_minimap = tf.nn.relu(L1_minimap) 44 | L1_minimap = tf.nn.max_pool(L1_minimap, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 45 | 46 | W2_minimap = tf.Variable(tf.random_normal([3,3,12,12],stddev=0.1),name='W2_minimap') 47 | L2_minimap = tf.nn.conv2d(L1_minimap, W2_minimap, strides=[1,1,1,1], padding="SAME") 48 | L2_minimap = tf.nn.relu(L2_minimap) 49 | L2_minimap = tf.nn.max_pool(L2_minimap, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 50 | # for default, L2_minimap shape : [-1, 16, 16, 12] 51 | 52 | W1_screen = tf.Variable(tf.random_normal([3,3,13,16],stddev=0.1),name='W1_screen') 53 | L1_screen = tf.nn.conv2d(_X_screen, W1_screen, strides=[1,1,1,1], padding="SAME") 54 | L1_screen = tf.nn.relu(L1_screen) 55 | L1_screen = tf.nn.max_pool(L1_screen, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 56 | 57 | W2_screen = tf.Variable(tf.random_normal([3,3,16,16],stddev=0.1),name='W2_screen') 58 | L2_screen = tf.nn.conv2d(L1_screen, W2_screen, strides=[1,1,1,1], padding="SAME") 59 | L2_screen = tf.nn.relu(L2_screen) 60 | L2_screen = tf.nn.max_pool(L2_screen, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 61 | # for default, L2_screen shape : [-1, 16, 16, 16] 62 | 63 | W1_player = tf.Variable(tf.random_normal([11, 256],stddev=0.1), name="W1_player") 64 | L1_player = tf.matmul(self._X_player, W1_player) 65 | L1_player = tf.nn.relu(L1_player) 66 | L1_player = 
tf.reshape(L1_player,[-1, 16,16,1]) 67 | 68 | _X_select = tf.reshape(self._X_select,[-1,7]) 69 | W1_select = tf.Variable(tf.random_normal([7, 256],stddev=0.1),name='W1_select') 70 | L1_select = tf.matmul(_X_select, W1_select) 71 | L1_select = tf.nn.relu(L1_select) 72 | L1_select = tf.reshape(L1_select,[-1, 16,16,1]) 73 | 74 | _X_control = tf.reshape(self._X_control_group,[-1,20]) 75 | W1_control = tf.Variable(tf.random_normal([20, 256],stddev=0.1),name='W1_control') 76 | L1_control = tf.matmul(_X_control, W1_control) 77 | L1_control = tf.nn.relu(L1_control) 78 | L1_control = tf.reshape(L1_control,[-1, 16,16,1]) 79 | 80 | W1_score = tf.Variable(tf.random_normal([13, 256],stddev=0.1),name='W1_score') 81 | L1_score = tf.matmul(self._X_score, W1_score) 82 | L1_score = tf.nn.relu(L1_score) 83 | L1_score = tf.reshape(L1_score,[-1, 16,16,1]) 84 | 85 | # for default, _X_State shape : [-1, 16, 16, 32] 86 | _X_State = tf.concat([L2_minimap, L2_screen, L1_player, L1_select, L1_control, L1_score], axis=-1) 87 | 88 | # *_ID : nets for classifying action_id (e.g. move_camera etc) 89 | W1_ID = tf.Variable(tf.random_normal([3,3,32,32], stddev=0.1), name="W1_ID") 90 | L1_ID = tf.nn.conv2d(_X_State, W1_ID, strides=[1,1,1,1], padding="SAME") 91 | L1_ID = tf.nn.relu(L1_ID) 92 | L1_ID = tf.nn.max_pool(L1_ID, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 93 | # L1_ID shape : [-1, 8, 8, 32] 94 | 95 | W2_ID = tf.Variable(tf.random_normal([3,3,32,64], stddev=0.1), name="W2_ID") 96 | L2_ID = tf.nn.conv2d(L1_ID, W2_ID, strides=[1,1,1,1], padding="SAME") 97 | L2_ID = tf.nn.relu(L2_ID) 98 | L2_ID = tf.nn.max_pool(L2_ID, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 99 | # L2_ID shape : [-1, 4, 4, 64] 100 | L2_ID = tf.reshape(L2_ID,[-1, 1024]) 101 | 102 | W3_ID = tf.Variable(tf.random_normal([1024, self.output_size]),name='W3_ID') 103 | self._Qpred = tf.matmul(L2_ID, W3_ID) 104 | 105 | W_screen_policy = tf.Variable(tf.random_normal([1024, self.screen_size*self.screen_size]),name='W_screen_policy') 106 | W_minimap_policy = tf.Variable(tf.random_normal([1024, self.minimap_size*self.minimap_size]),name="W_minimap_policy") 107 | W_screen2_policy = tf.Variable(tf.random_normal([1024, self.screen_size*self.screen_size]),name="W_screen2_policy") 108 | self._screen_policy_Qpred = tf.matmul(L2_ID, W_screen_policy) 109 | self._minimap_policy_Qpred = tf.matmul(L2_ID, W_minimap_policy) 110 | self._screen2_policy_Qpred = tf.matmul(L2_ID, W_screen2_policy) 111 | 112 | W_nonspatial3 = tf.Variable(tf.random_normal([1024, 2]),name='W_nonspatial3') 113 | W_nonspatial4 = tf.Variable(tf.random_normal([1024, 5]),name='W_nonspatial4') 114 | W_nonspatial5 = tf.Variable(tf.random_normal([1024, 10]),name='W_nonspatial5') 115 | W_nonspatial6 = tf.Variable(tf.random_normal([1024, 4]),name='W_nonspatial6') 116 | W_nonspatial7 = tf.Variable(tf.random_normal([1024, 2]),name='W_nonspatial7') 117 | W_nonspatial8 = tf.Variable(tf.random_normal([1024, 4]),name='W_nonspatial8') 118 | W_nonspatial9 = tf.Variable(tf.random_normal([1024, 500]),name='W_nonspatial9') 119 | W_nonspatial10 = tf.Variable(tf.random_normal([1024, 4]),name='W_nonspatial10') 120 | W_nonspatial11 = tf.Variable(tf.random_normal([1024, 10]),name='W_nonspatial11') 121 | W_nonspatial12 = tf.Variable(tf.random_normal([1024, 500]),name='W_nonspatial12') 122 | self._nonspatial3_Qpred = tf.matmul(L2_ID, W_nonspatial3) 123 | self._nonspatial4_Qpred = tf.matmul(L2_ID, W_nonspatial4) 124 | self._nonspatial5_Qpred = tf.matmul(L2_ID, W_nonspatial5) 125 | self._nonspatial6_Qpred = 
tf.matmul(L2_ID, W_nonspatial6) 126 | self._nonspatial7_Qpred = tf.matmul(L2_ID, W_nonspatial7) 127 | self._nonspatial8_Qpred = tf.matmul(L2_ID, W_nonspatial8) 128 | self._nonspatial9_Qpred = tf.matmul(L2_ID, W_nonspatial9) 129 | self._nonspatial10_Qpred = tf.matmul(L2_ID, W_nonspatial10) 130 | self._nonspatial11_Qpred = tf.matmul(L2_ID, W_nonspatial11) 131 | self._nonspatial12_Qpred = tf.matmul(L2_ID, W_nonspatial12) 132 | 133 | self._Y = tf.placeholder(tf.float32, shape=[None, self.output_size]) 134 | self._Y_screen = tf.placeholder(tf.float32, shape=[None, self.screen_size*self.screen_size]) 135 | self._Y_minimap = tf.placeholder(tf.float32, shape=[None, self.minimap_size*self.minimap_size]) 136 | self._Y_screen2 = tf.placeholder(tf.float32, shape=[None, self.screen_size*self.screen_size]) 137 | self._Y_nonspatial3 = tf.placeholder(tf.float32, shape=[None, 2]) 138 | self._Y_nonspatial4 = tf.placeholder(tf.float32, shape=[None, 5]) 139 | self._Y_nonspatial5 = tf.placeholder(tf.float32, shape=[None, 10]) 140 | self._Y_nonspatial6 = tf.placeholder(tf.float32, shape=[None, 4]) 141 | self._Y_nonspatial7 = tf.placeholder(tf.float32, shape=[None, 2]) 142 | self._Y_nonspatial8 = tf.placeholder(tf.float32, shape=[None, 4]) 143 | self._Y_nonspatial9 = tf.placeholder(tf.float32, shape=[None, 500]) 144 | self._Y_nonspatial10 = tf.placeholder(tf.float32, shape=[None, 4]) 145 | self._Y_nonspatial11 = tf.placeholder(tf.float32, shape=[None, 10]) 146 | self._Y_nonspatial12 = tf.placeholder(tf.float32, shape=[None, 500]) 147 | 148 | _loss = tf.losses.mean_squared_error(self._Y, self._Qpred) * 10 149 | _loss += tf.losses.mean_squared_error(self._Y_screen, self._screen_policy_Qpred) 150 | _loss += tf.losses.mean_squared_error(self._Y_minimap, self._minimap_policy_Qpred) 151 | _loss += tf.losses.mean_squared_error(self._Y_screen2, self._screen2_policy_Qpred) 152 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial3, self._nonspatial3_Qpred) 153 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial4, self._nonspatial4_Qpred) 154 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial5, self._nonspatial5_Qpred) 155 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial6, self._nonspatial6_Qpred) 156 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial7, self._nonspatial7_Qpred) 157 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial8, self._nonspatial8_Qpred) 158 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial9, self._nonspatial9_Qpred) 159 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial10, self._nonspatial10_Qpred) 160 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial11, self._nonspatial11_Qpred) 161 | _loss += tf.losses.mean_squared_error(self._Y_nonspatial12, self._nonspatial12_Qpred) 162 | self._loss = _loss 163 | 164 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.l_rate) 165 | self._train = optimizer.minimize(self._loss) 166 | 167 | self.session.run(tf.global_variables_initializer()) 168 | self.saver = tf.train.Saver({ 169 | 'W1_minimap':W1_minimap, 170 | 'W2_minimap':W2_minimap, 171 | 'W1_screen':W1_screen, 172 | 'W2_screen':W2_screen, 173 | 'W1_player':W1_player, 174 | 'W1_select':W1_select, 175 | 'W1_control':W1_control, 176 | 'W1_score':W1_score, 177 | 'W1_ID':W1_ID, 178 | 'W2_ID':W2_ID, 179 | 'W3_ID':W3_ID, 180 | 'W_screen_policy':W_screen_policy, 181 | 'W_minimap_policy':W_minimap_policy, 182 | 'W_screen2_policy':W_screen2_policy, 183 | 'W_nonspatial3':W_nonspatial3, 184 | 'W_nonspatial4':W_nonspatial4, 185 | 
'W_nonspatial5':W_nonspatial5, 186 | 'W_nonspatial6':W_nonspatial6, 187 | 'W_nonspatial7':W_nonspatial7, 188 | 'W_nonspatial8':W_nonspatial8, 189 | 'W_nonspatial9':W_nonspatial9, 190 | 'W_nonspatial10':W_nonspatial10, 191 | 'W_nonspatial11':W_nonspatial11, 192 | 'W_nonspatial12':W_nonspatial12 193 | }) 194 | 195 | try: 196 | self.saver.restore(self.session, "saved/model") 197 | print("DQN : weight params are restored") 198 | except: 199 | print("DQN : no params were restored") 200 | 201 | def predict(self, state) -> np.ndarray: 202 | """Returns Q(s, a) <- here the Q only predicts for action id, not arguments 203 | Args: 204 | state (array): State array, shape (n, ) 205 | Returns: 206 | np.ndarray: Q value array, shape (n, output_dim) 207 | """ 208 | _minimap = np.vstack([x[0]['minimap'] for x in state]) 209 | _screen = np.vstack([x[0]['screen'] for x in state]) 210 | _control = np.vstack([x[0]['control_groups'] for x in state]) 211 | _player = np.array([x[0]['player'] for x in state]) 212 | _score = np.array([x[0]['score_cumulative'] for x in state]) 213 | _select = np.array([x[0]['single_select'] for x in state]) 214 | _multiselect = np.array([x[0]['multi_select'] for x in state]) 215 | 216 | _minimap = np.reshape(_minimap, [-1, 7, self.minimap_size, self.minimap_size]) 217 | _screen = np.reshape(_screen, [-1, 13, self.screen_size, self.screen_size]) 218 | _control = np.reshape(_control, [-1, 10, 2]) 219 | _player = np.reshape(_player, [-1, 11]) 220 | _score = np.reshape(_score, [-1, 13]) 221 | _select = np.reshape(_select, [-1, 1, 7]) 222 | for i in range(len(_multiselect)): 223 | if _multiselect[i].shape[0] > 0: 224 | _select[i][0] = _multiselect[i][0] 225 | 226 | feed = { 227 | self._X_minimap: _minimap, 228 | self._X_screen: _screen, 229 | self._X_control_group: _control, 230 | self._X_player: _player, 231 | self._X_score: _score, 232 | self._X_select: _select 233 | } 234 | return self.session.run(self._Qpred, feed_dict=feed) 235 | 236 | def predictSpatial(self, state): 237 | """Returns spatial/nonspatial argument Q values 238 | """ 239 | _minimap = np.vstack([x[0]['minimap'] for x in state]) 240 | _screen = np.vstack([x[0]['screen'] for x in state]) 241 | _control = np.vstack([x[0]['control_groups'] for x in state]) 242 | _player = np.array([x[0]['player'] for x in state]) 243 | _score = np.array([x[0]['score_cumulative'] for x in state]) 244 | _select = np.array([x[0]['single_select'] for x in state]) 245 | _multiselect = np.array([x[0]['multi_select'] for x in state]) 246 | 247 | _minimap = np.reshape(_minimap, [-1, 7, self.minimap_size, self.minimap_size]) 248 | _screen = np.reshape(_screen, [-1, 13, self.screen_size, self.screen_size]) 249 | _control = np.reshape(_control, [-1, 10, 2]) 250 | _player = np.reshape(_player, [-1, 11]) 251 | _score = np.reshape(_score, [-1, 13]) 252 | _select = np.reshape(_select, [-1, 1, 7]) 253 | for i in range(len(_multiselect)): 254 | if _multiselect[i].shape[0] > 0: 255 | _select[i][0] = _multiselect[i][0] 256 | 257 | 258 | feed = { 259 | self._X_minimap: _minimap, 260 | self._X_screen: _screen, 261 | self._X_control_group: _control, 262 | self._X_player: _player, 263 | self._X_score: _score, 264 | self._X_select: _select 265 | } 266 | return self.session.run([ 267 | self._screen_policy_Qpred, 268 | self._minimap_policy_Qpred, 269 | self._screen2_policy_Qpred, 270 | self._nonspatial3_Qpred, 271 | self._nonspatial4_Qpred, 272 | self._nonspatial5_Qpred, 273 | self._nonspatial6_Qpred, 274 | self._nonspatial7_Qpred, 275 | self._nonspatial8_Qpred, 276 
| self._nonspatial9_Qpred, 277 | self._nonspatial10_Qpred, 278 | self._nonspatial11_Qpred, 279 | self._nonspatial12_Qpred], feed_dict=feed) 280 | 281 | def update(self, state, y_stack, y_spatial) -> list: 282 | """Performs updates on given X and y and returns a result 283 | Args: 284 | x_stack (array): State array, shape (n, ) 285 | y_stack (array): Target action id Q array, shape (n, output_dim) 286 | y_spatial (array) : Target action argument Q array (13, n, ) 287 | Returns: 288 | list: First element is loss, second element is a result from train step 289 | 290 | """ 291 | _minimap = np.vstack([x[0]['minimap'] for x in state]) 292 | _screen = np.vstack([x[0]['screen'] for x in state]) 293 | _control = np.vstack([x[0]['control_groups'] for x in state]) 294 | _player = np.array([x[0]['player'] for x in state]) 295 | _score = np.array([x[0]['score_cumulative'] for x in state]) 296 | _select = np.array([x[0]['single_select'] for x in state]) 297 | _multiselect = np.array([x[0]['multi_select'] for x in state]) 298 | 299 | _minimap = np.reshape(_minimap, [-1, 7, self.minimap_size, self.minimap_size]) 300 | _screen = np.reshape(_screen, [-1, 13, self.screen_size, self.screen_size]) 301 | _control = np.reshape(_control, [-1, 10, 2]) 302 | _player = np.reshape(_player, [-1, 11]) 303 | _score = np.reshape(_score, [-1, 13]) 304 | _select = np.reshape(_select, [-1, 1, 7]) 305 | for i in range(len(_multiselect)): 306 | if _multiselect[i].shape[0] > 0: 307 | _select[i][0] = _multiselect[i][0] 308 | 309 | 310 | feed = { 311 | self._X_minimap: _minimap, 312 | self._X_screen: _screen, 313 | self._X_control_group: _control, 314 | self._X_player: _player, 315 | self._X_score: _score, 316 | self._X_select: _select, 317 | self._Y: y_stack, 318 | self._Y_screen: y_spatial[0], 319 | self._Y_minimap: y_spatial[1], 320 | self._Y_screen2: y_spatial[2], 321 | self._Y_nonspatial3: y_spatial[3], 322 | self._Y_nonspatial4: y_spatial[4], 323 | self._Y_nonspatial5: y_spatial[5], 324 | self._Y_nonspatial6: y_spatial[6], 325 | self._Y_nonspatial7: y_spatial[7], 326 | self._Y_nonspatial8: y_spatial[8], 327 | self._Y_nonspatial9: y_spatial[9], 328 | self._Y_nonspatial10: y_spatial[10], 329 | self._Y_nonspatial11: y_spatial[11], 330 | self._Y_nonspatial12: y_spatial[12], 331 | } 332 | return self.session.run([self._loss, self._train], feed) 333 | 334 | def saveWeight(self): 335 | self.saver.save(self.session, 'saved/model') 336 | -------------------------------------------------------------------------------- /minerva_agent.py: -------------------------------------------------------------------------------- 1 | from pysc2.agents import base_agent 2 | from pysc2.lib import actions 3 | import numpy as np 4 | 5 | """ 6 | ====================================================== 7 | self.step : 24 8 | obs type : 9 | ------------------------------------------------------ 10 | obs.step_type : 11 | StepType.MID [FIRST, MID, LAST] 12 | ------------------------------------------------------ 13 | obs.reward : 14 | 0 15 | ------------------------------------------------------ 16 | obs.discount : 17 | 1.0 18 | ------------------------------------------------------ 19 | obs.observation : - {str : numpy.ndarray} 20 | obs.observation['build_queue'] : (n, 7) 21 | ..['build_queue'][i][j] : same as single_select 22 | obs.observation['game_loop'] : (1,) 23 | obs.observation['cargo_slots_available'] : (1,) 24 | obs.observation['player'] : (11,) 25 | ..['player'][0] : player_id 26 | ..['player'][1] : mineral 27 | ..['player'][2] : vespine 28 | 
..['player'][3] : food used
29 | ..['player'][4] : food cap
30 | ..['player'][5] : food used by army
31 | ..['player'][6] : food used by workers
32 | ..['player'][7] : idle worker count
33 | ..['player'][8] : army count
34 | ..['player'][9] : warp gate count
35 | ..['player'][10] : larva count
36 | obs.observation['available_actions'] : (n)
37 | ..['available_actions'][i] : available action id
38 | obs.observation['minimap'] : (7, 64, 64)
39 | ..['minimap'][0] : height_map
40 | ..['minimap'][1] : visibility
41 | ..['minimap'][2] : creep
42 | ..['minimap'][3] : camera
43 | ..['minimap'][4] : player_id
44 | ..['minimap'][5] : player_relative < [0,4] < [background, self, ally, neutral, enemy]
45 | ..['minimap'][6] : selected < 0 for not selected, 1 for selected
46 | obs.observation['cargo'] : (n, 7) - n is the number of all units in a transport
47 | ..['cargo'][i][j] : same as single_select[0][j]
48 | obs.observation['multi_select'] : (n, 7)
49 | ..['multi_select'][i][j] : same as single_select[0][j]
50 | -> does not coexist with single_select:
51 | when single_select is used, multi_select = []
52 | when multi_select is used, single_select = [[0,0,0,0,0,0,0]]
53 | obs.observation['score_cumulative'] : (13,)
54 | obs.observation['control_groups'] : (10, 2)
55 | ..['control_groups'][i][0] : i'th unit leader type
56 | ..['control_groups'][i][1] : count
57 | obs.observation['single_select'] : (1, 7)
58 | ..['single_select'][0][0] : unit_type
59 | ..['single_select'][0][1] : player_relative < [0,4] < [background, self, ally, neutral, enemy]
60 | ..['single_select'][0][2] : health
61 | ..['single_select'][0][3] : shields
62 | ..['single_select'][0][4] : energy
63 | ..['single_select'][0][5] : transport slot
64 | ..['single_select'][0][6] : build progress as percentage
65 | obs.observation['screen'] : (13, 84, 84)
66 | ..['screen'][0] : height_map
67 | ..['screen'][1] : visibility
68 | ..['screen'][2] : creep
69 | ..['screen'][3] : power < protoss power
70 | ..['screen'][4] : player_id
71 | ..['screen'][5] : player_relative < [0,4] < [background, self, ally, neutral, enemy]
72 | ..['screen'][6] : unit_type
73 | ..['screen'][7] : selected < 0 for not selected, 1 for selected
74 | ..['screen'][8] : hit_points
75 | ..['screen'][9] : energy
76 | ..['screen'][10] : shields
77 | ..['screen'][11] : unit_density
78 | ..['screen'][12] : unit_density_aa
79 | ======================================================
80 | """
81 | 
82 | def intToCoordinate(num, size=64):
83 |     if size!=64:
84 |         num = num * size * size // 4096
85 |     y = num // size
86 |     x = num - size * y
87 |     return [x, y]
88 | 
89 | class MinervaAgent(base_agent.BaseAgent):
90 |     def __init__(self, mainDQN=None):
91 |         super(MinervaAgent, self).__init__()
92 |         self.mainDQN = mainDQN
93 | 
94 |     def close(self):
95 |         self.mainDQN = None
96 |         self.obs_spec = None
97 |         self.action_spec = None
98 | 
99 |     def setup(self, obs_spec, action_spec):
100 |         super(MinervaAgent, self).setup(obs_spec, action_spec)
101 | 
102 |     def step(self, obs, exploit):
103 |         super(MinervaAgent, self).step(obs)
104 | 
105 |         # if exploit == 0, choose a random action for exploration
106 |         if exploit == 0:
107 |             ans_id = np.random.choice(obs.observation["available_actions"])
108 | 
109 |         ################## find action id ######################
110 | 
111 |         # otherwise choose an action to exploit
112 |         # Qs[0] : ndarray([584]) -> Qs[0][i] is the Q value of the action whose id = i
113 |         else:
114 |             Qs = self.mainDQN.predict([[obs.observation]])
115 |             for i in range(len(Qs[0])):
116 |                 if i not in obs.observation["available_actions"]:
117 | 
Qs[0][i] = -100 118 | 119 | ans_id = np.argmax(Qs[0]) 120 | if Qs[0][ans_id] <= -100: 121 | ans_id = 0 122 | 123 | ############# find minimap/screen coordinate etc. ################# 124 | 125 | spatialQs = self.mainDQN.predictSpatial([[obs.observation]]) 126 | spatialInt = [] 127 | for i in range(13): 128 | spatialInt.append(np.argmax(spatialQs[i][0])) 129 | 130 | ans_arg = [] 131 | for arg in self.action_spec.functions[ans_id].args: 132 | if arg.id in range(3): 133 | ans_arg.append(intToCoordinate(spatialInt[arg.id], arg.sizes[0])) 134 | else: 135 | ans_arg.append([spatialInt[arg.id]]) 136 | 137 | print("step :", self.steps, "exploit :",exploit,"action id :", ans_id, "arg :",ans_arg) 138 | """ 139 | 173 Attributes: 140 | 174 0 screen: A point on the screen. 141 | 175 1 minimap: A point on the minimap. 142 | 176 2 screen2: The second point for a rectangle. This is needed so that no 143 | 177 function takes the same type twice. 144 | 178 3 queued: Whether the action should be done now or later. size<2 145 | 179 4 control_group_act: What to do with the control group. size<5 146 | 180 5 control_group_id: Which control group to do it with. size<10 147 | 181 6 select_point_act: What to do with the unit at the point. size<4 148 | 182 7 select_add: Whether to add the unit to the selection or replace it. size<2 149 | 183 8 select_unit_act: What to do when selecting a unit by id. size<4 150 | 184 9 select_unit_id: Which unit to select by id. size<500 151 | 185 10 select_worker: What to do when selecting a worker. size<4 152 | 186 11 build_queue_id: Which build queue index to target. size<10 153 | 187 12 unload_id: Which unit to target in a transport/nydus/command center. size<500 154 | """ 155 | return actions.FunctionCall(ans_id, ans_arg) 156 | -------------------------------------------------------------------------------- /trainingRL.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from pysc2.env import sc2_env 4 | from pysc2.lib import actions as actlib 5 | from pysc2.lib import app 6 | from collections import deque 7 | from typing import List 8 | import random 9 | import minerva_agent 10 | import dqn 11 | import sys 12 | import time 13 | import gflags as flags 14 | import psutil 15 | import resource 16 | FLAGS = flags.FLAGS 17 | 18 | output_size = len(actlib.FUNCTIONS) # no of possible actions 19 | flags.DEFINE_integer("start_episode", 0, "starting episode number") 20 | flags.DEFINE_integer("num_episodes", 100, "total episodes number") 21 | flags.DEFINE_integer("screen_size", 64, "screen width pixels") 22 | flags.DEFINE_integer("minimap_size", 64, "minimap width pixels") 23 | 24 | flags.DEFINE_integer("learning_rate", 0.001, "learning rate") 25 | flags.DEFINE_integer("discount", 0.99, "discount factor") 26 | flags.DEFINE_integer("batch_size", 16, "size of mini-batch") 27 | flags.DEFINE_integer("max_buffer_size", 50000, "maximum deque size") 28 | flags.DEFINE_integer("update_frequency", 16, "update target frequency") 29 | 30 | flags.DEFINE_bool("visualize", False, "visualize") 31 | flags.DEFINE_string("agent_race", "T", "agent race") 32 | flags.DEFINE_string("bot_race", "R", "bot race") 33 | flags.DEFINE_string("map_name","AscensiontoAiur", "map name") 34 | flags.DEFINE_string("difficulty","1", "bot difficulty") 35 | 36 | # below is a list of possible map_name 37 | # AbyssalReef 38 | # Acolyte 39 | # AscensiontoAiur 40 | # BelShirVestige 41 | # BloodBoil 42 | # CactusValley 43 | # DefendersLanding 44 | # 
Frost 45 | # Honorgrounds 46 | # Interloper 47 | # MechDepot 48 | # NewkirkPrecinct 49 | # Odyssey 50 | # PaladinoTerminal 51 | # ProximaTerminal 52 | # Sequencer 53 | 54 | def coordinateToInt(coor, size=64): 55 | return coor[0] + size*coor[1] 56 | 57 | def batch_train(env, mainDQN, targetDQN, train_batch: list) -> float: 58 | """Trains `mainDQN` with target Q values given by `targetDQN` 59 | Args: 60 | mainDQN (dqn.DQN): Main DQN that will be trained 61 | targetDQN (dqn.DQN): Target DQN that will predict Q_target 62 | train_batch (list): Minibatch of stored buffer 63 | Each element is (s, a, r, s', done) 64 | [(state, action, reward, next_state, done), ...] 65 | Returns: 66 | float: After updating `mainDQN`, it returns a `loss` 67 | """ 68 | states = np.vstack([x[0] for x in train_batch]) 69 | actions_id = np.array([x[1] for x in train_batch]) 70 | rewards = np.array([x[3] for x in train_batch]) 71 | next_states = np.vstack([x[4] for x in train_batch]) 72 | done = np.array([x[5] for x in train_batch]) 73 | 74 | # actions_arg[i] : arguments whose id=i 75 | actions_arg = np.ones([13,FLAGS.batch_size],dtype=np.int32) 76 | actions_arg *= -1 77 | 78 | batch_index = 0 79 | for x in train_batch: 80 | action_id = x[1] 81 | arg_index = 0 82 | 83 | for arg in env.action_spec().functions[action_id].args: 84 | if arg.id in range(3): 85 | actions_arg[arg.id][batch_index] = coordinateToInt(x[2][arg_index]) 86 | else: 87 | actions_arg[arg.id][batch_index] = (int) (x[2][arg_index][0]) 88 | arg_index += 1 89 | batch_index += 1 90 | 91 | X = states 92 | 93 | Q_target = rewards + FLAGS.discount * np.max(targetDQN.predict(next_states), axis=1) * ~done 94 | spatial_Q_target = [] 95 | spatial_predict = targetDQN.predictSpatial(next_states) 96 | for i in range(13): 97 | spatial_Q_target.append( rewards + FLAGS.discount * np.max(spatial_predict[i], axis=1) *~done ) 98 | 99 | # y shape : [batch_size, output_size] 100 | y = mainDQN.predict(states) 101 | y[np.arange(len(X)), actions_id] = Q_target 102 | 103 | # ySpatial shape : [13, batch_size, arg_size(id)] 104 | ySpatial = mainDQN.predictSpatial(states) 105 | for j in range(13): 106 | for i in range(len(X)): 107 | if actions_arg[j][i] >= 0: 108 | ySpatial[j][i][actions_arg[j][i]] = spatial_Q_target[j][i] 109 | 110 | # Train our network using target and predicted Q values on each episode 111 | return mainDQN.update(X, y, ySpatial) 112 | 113 | 114 | def get_copy_var_ops(*, dest_scope_name: str, src_scope_name: str) -> List[tf.Operation]: 115 | """Creates TF operations that copy weights from `src_scope` to `dest_scope` 116 | Args: 117 | dest_scope_name (str): Destination weights (copy to) 118 | src_scope_name (str): Source weight (copy from) 119 | Returns: 120 | List[tf.Operation]: Update operations are created and returned 121 | """ 122 | # Copy variables src_scope to dest_scope 123 | op_holder = [] 124 | 125 | src_vars = tf.get_collection( 126 | tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope_name) 127 | dest_vars = tf.get_collection( 128 | tf.GraphKeys.TRAINABLE_VARIABLES, scope=dest_scope_name) 129 | 130 | for src_var, dest_var in zip(src_vars, dest_vars): 131 | op_holder.append(dest_var.assign(src_var.value())) 132 | 133 | return op_holder 134 | 135 | # returns pysc2.env.environment.TimeStep after end of the game 136 | def run_loop(agents, env, sess, e, mainDQN, targetDQN, copy_ops, max_frames=0): 137 | total_frames = 0 138 | stored_buffer = deque(maxlen=FLAGS.max_buffer_size) 139 | start_time = time.time() 140 | 141 | action_spec = env.action_spec() 142 | 
observation_spec = env.observation_spec() 143 | for agent in agents: 144 | agent.setup(observation_spec, action_spec) 145 | 146 | timesteps = env.reset() 147 | state = timesteps[0].observation 148 | step_count = 0 149 | 150 | for a in agents: 151 | a.reset() 152 | try: 153 | while True: 154 | total_frames += 1 155 | if np.random.rand(1) < e: 156 | # choose a random action and explore 157 | actions = [agent.step(timestep, 0) 158 | for agent, timestep in zip(agents, timesteps)] 159 | else: 160 | # choose an action by 'exploit' 161 | actions = [agent.step(timestep, 1) 162 | for agent, timestep in zip(agents, timesteps)] 163 | 164 | if max_frames and total_frames >= max_frames: 165 | return timesteps 166 | 167 | timesteps = env.step(actions) 168 | next_state = timesteps[0].observation 169 | reward = timesteps[0].reward 170 | done = timesteps[0].last() 171 | 172 | if done: 173 | break 174 | 175 | stored_buffer.append( (state, actions[0].function, actions[0].arguments, reward, next_state, done) ) 176 | 177 | if len(stored_buffer) > FLAGS.batch_size: 178 | minibatch = random.sample(stored_buffer, FLAGS.batch_size) 179 | loss, _ = batch_train(env, mainDQN, targetDQN, minibatch) 180 | 181 | if step_count % FLAGS.update_frequency == 0: 182 | sess.run(copy_ops) 183 | 184 | state = next_state 185 | step_count += 1 186 | 187 | except KeyboardInterrupt: 188 | return timesteps 189 | finally: 190 | elapsed_time = time.time() - start_time 191 | print("Took %.3f seconds for %s steps: %.3f fps" % ( 192 | elapsed_time, total_frames, total_frames / elapsed_time)) 193 | return timesteps 194 | 195 | def main(unusued_argv): 196 | parent_proc = psutil.Process() 197 | with tf.Session() as sess: 198 | mainDQN = dqn.DQN(sess, FLAGS.screen_size, FLAGS.minimap_size, output_size, FLAGS.learning_rate, name="main") 199 | targetDQN = dqn.DQN(sess, FLAGS.screen_size, FLAGS.minimap_size, output_size, FLAGS.learning_rate, name="target") 200 | 201 | copy_ops = get_copy_var_ops(dest_scope_name="target", src_scope_name="main") 202 | sess.run(copy_ops) 203 | print("memory before starting the iteration : %s (kb)"%(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)) 204 | 205 | for episode in range(FLAGS.start_episode, FLAGS.num_episodes): 206 | e = 1.0 / ((episode / 50) + 2.0) # decaying exploration rate 207 | with sc2_env.SC2Env( 208 | FLAGS.map_name, 209 | screen_size_px=(FLAGS.screen_size, FLAGS.screen_size), 210 | minimap_size_px=(FLAGS.minimap_size, FLAGS.minimap_size), 211 | agent_race=FLAGS.agent_race, 212 | bot_race=FLAGS.bot_race, 213 | difficulty=FLAGS.difficulty, 214 | visualize=FLAGS.visualize) as env: 215 | 216 | agent = minerva_agent.MinervaAgent(mainDQN) 217 | run_result = run_loop([agent], env, sess, e, mainDQN, targetDQN, copy_ops, 5000) 218 | agent.close() 219 | reward = run_result[0].reward 220 | if reward > 0: 221 | env.save_replay("victory/") 222 | #else: 223 | # env.save_replay("defeat/") 224 | 225 | children = parent_proc.children(recursive=True) 226 | for child in children: 227 | print("remaining child proc :", child) 228 | print("memory after exit %d'th sc2env : %s (kb)"%(episode, resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)) 229 | 230 | mainDQN.saveWeight() 231 | print("networks were saved, %d'th game result :"%episode,reward) 232 | 233 | def _main(): 234 | argv = FLAGS(sys.argv) 235 | app.really_start(main) 236 | 237 | if __name__ == "__main__": 238 | sys.exit(_main()) 239 | -------------------------------------------------------------------------------- /trainingSL.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """train network via supervised learning with replay files""" 16 | 17 | import tensorflow as tf 18 | import numpy as np 19 | import dqn 20 | 21 | import os 22 | import platform 23 | import sys 24 | import time 25 | 26 | from pysc2 import maps 27 | from pysc2 import run_configs 28 | from pysc2.lib import actions as actlib 29 | from pysc2.lib import stopwatch 30 | from pysc2.lib import features 31 | from pysc2.lib import app 32 | 33 | import gflags as flags 34 | from s2clientprotocol import sc2api_pb2 as sc_pb 35 | 36 | REPLAY_HOME = os.path.expanduser("~") + "/StarCraftII/Replays/" 37 | FLAGS = flags.FLAGS 38 | 39 | output_size = len(actlib.FUNCTIONS) 40 | flags.DEFINE_string("replay",None,"replay path relative to REPLAY_HOME") 41 | flags.DEFINE_integer("repeat",1,"number of iteration") 42 | flags.DEFINE_bool("win_only", True, "learn only for the player who won if this flag is True") 43 | flags.DEFINE_integer("screen_size", 64, "screen width pixels") 44 | flags.DEFINE_integer("minimap_size", 64, "minimap width pixels") 45 | flags.DEFINE_integer("learning_rate", 0.001, "learning rate") 46 | flags.DEFINE_string("agent_race", "T", "agent race") 47 | flags.DEFINE_string("map_name","AscensiontoAiur", "map name") 48 | 49 | def coordinateToInt(coor, size=64): 50 | return coor[0] + size*coor[1] 51 | 52 | def raceToCode(race): 53 | if race == "R": 54 | return sc_pb.Random 55 | elif race == "P": 56 | return sc_pb.Protoss 57 | elif race == "T": 58 | return sc_pb.Terran 59 | else: 60 | return sc_pb.Zerg 61 | 62 | def mapNameMatch(name:str): 63 | name = name.replace(' ','') 64 | name = name.replace('LE','') 65 | name = name.replace('TE','') 66 | name = name.lower() 67 | name2 = FLAGS.map_name.lower() 68 | return name == name2 69 | 70 | def train(mainDQN, obs, action, action_spec): 71 | states = [[obs]] 72 | if len(action) > 0: 73 | actions_id = action[0].function 74 | actions_arg = np.zeros([13],dtype=np.int32) 75 | 76 | arg_index = 0 77 | for arg in action_spec.functions[actions_id].args: 78 | if arg.id in range(3): 79 | actions_arg[arg.id] = coordinateToInt(action[0].arguments[arg_index]) 80 | else: 81 | actions_arg[arg.id] = (int) (action[0].arguments[arg_index][0]) 82 | arg_index += 1 83 | 84 | else: 85 | # in case of doing nothing 86 | actions_id = 0 87 | actions_arg = np.zeros([13],dtype=np.int32) 88 | 89 | X = states 90 | 91 | Q_target = np.array([actions_id]) 92 | spatial_Q_target = actions_arg 93 | 94 | # y shape : [1, output_size] 95 | y = mainDQN.predict(states) 96 | y[np.arange(len(X)), actions_id] = Q_target 97 | 98 | # ySpatial shape : [13, 1, arg_size(id)] 99 | ySpatial = mainDQN.predictSpatial(states) 100 | for j in range(13): 101 | if actions_arg[j] >= 0: 102 | ySpatial[j][0,actions_arg[j]] = spatial_Q_target[j] 103 | 104 | # Train our network 
using target and predicted Q values on each episode 105 | return mainDQN.update(X, y, ySpatial) 106 | 107 | def run_loop(replay, player_id, mainDQN): 108 | """Run SC2 to play a game or a replay.""" 109 | stopwatch.sw.enabled = False 110 | stopwatch.sw.trace = False 111 | 112 | if not replay: 113 | sys.exit("Must supply a replay.") 114 | 115 | if replay and not replay.lower().endswith("sc2replay"): 116 | sys.exit("Replay must end in .SC2Replay.") 117 | 118 | run_config = run_configs.get() 119 | 120 | interface = sc_pb.InterfaceOptions() 121 | interface.raw = False 122 | interface.score = True 123 | interface.feature_layer.width = 24 124 | interface.feature_layer.resolution.x = FLAGS.screen_size 125 | interface.feature_layer.resolution.y = FLAGS.screen_size 126 | interface.feature_layer.minimap_resolution.x = FLAGS.minimap_size 127 | interface.feature_layer.minimap_resolution.y = FLAGS.minimap_size 128 | 129 | max_episode_steps = 0 130 | 131 | replay_data = run_config.replay_data(replay) 132 | start_replay = sc_pb.RequestStartReplay( 133 | replay_data=replay_data, 134 | options=interface, 135 | disable_fog=False, 136 | observed_player_id=player_id) 137 | 138 | with run_config.start(full_screen=False) as controller: 139 | info = controller.replay_info(replay_data) 140 | infomap = info.map_name 141 | inforace = info.player_info[player_id-1].player_info.race_actual 142 | inforesult = info.player_info[player_id-1].player_result.result 143 | if FLAGS.map_name and not mapNameMatch(infomap): 144 | print("map doesn't match, continue...") 145 | print("map_name:",FLAGS.map_name,"infomap:",infomap) 146 | return 147 | if FLAGS.agent_race and raceToCode(FLAGS.agent_race) != inforace: 148 | print("agent race doesn't match, continue...") 149 | print("agent_race:",raceToCode(FLAGS.agent_race),"inforace:",inforace) 150 | return 151 | if FLAGS.win_only and not inforesult: 152 | print("this player was defeated, continue...") 153 | print("result:",inforesult) 154 | return 155 | else: 156 | print("condition's satisfied, training starts :",replay) 157 | print("map :",infomap) 158 | print("player id :", player_id) 159 | print("race :", inforace) 160 | print("result :", inforesult) 161 | 162 | map_path = info.local_map_path 163 | if map_path: 164 | start_replay.map_data = run_config.map_data(map_path) 165 | controller.start_replay(start_replay) 166 | 167 | game_info = controller.game_info() 168 | _features = features.Features(game_info) 169 | action_spec = _features.action_spec() 170 | 171 | try: 172 | while True: 173 | frame_start_time = time.time() 174 | controller.step(1) 175 | obs = controller.observe() 176 | actions = obs.actions 177 | real_obs = _features.transform_obs(obs.observation) 178 | real_actions = [] 179 | for action in actions: 180 | try: 181 | real_actions.append(_features.reverse_action(action)) 182 | except ValueError: 183 | real_actions.append(actlib.FunctionCall(function=0,arguments=[])) 184 | train(mainDQN, real_obs, real_actions, action_spec) 185 | 186 | if obs.player_result: 187 | break 188 | #time.sleep(max(0, frame_start_time + 1 / FLAGS.fps - time.time())) 189 | 190 | except KeyboardInterrupt: 191 | pass 192 | 193 | print("Score: ", obs.observation.score.score) 194 | print("Result: ", obs.player_result) 195 | 196 | def main(unused_argv): 197 | replay_list = [] 198 | if FLAGS.replay: 199 | REPLAY_PATH = REPLAY_HOME + FLAGS.replay 200 | else: 201 | REPLAY_PATH = REPLAY_HOME 202 | 203 | for root, dirs, files in os.walk(REPLAY_PATH): 204 | for subdir in dirs: 205 | tmp = os.path.join(root, 
subdir) 206 | if tmp[-10:] == '.SC2Replay': 207 | replay_list.append(tmp) 208 | for file1 in files: 209 | tmp = os.path.join(root, file1) 210 | if tmp[-10:] == '.SC2Replay': 211 | replay_list.append(tmp) 212 | 213 | with tf.Session() as sess: 214 | mainDQN = dqn.DQN(sess, FLAGS.screen_size, FLAGS.minimap_size, output_size, FLAGS.learning_rate, name="main") 215 | 216 | for iter in range(FLAGS.repeat): 217 | for replay in replay_list: 218 | start_time = time.time() 219 | run_loop(replay, 1, mainDQN) 220 | run_loop(replay, 2, mainDQN) 221 | mainDQN.saveWeight() 222 | print("networks were updated / replay :",replay) 223 | elapsed_time = time.time() - start_time 224 | print("Took %.3f seconds... " % (elapsed_time)) 225 | 226 | def _main(): 227 | argv = FLAGS(sys.argv) 228 | app.really_start(main) 229 | 230 | if __name__ == "__main__": 231 | sys.exit(_main()) 232 | --------------------------------------------------------------------------------
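
A note on the spatial-argument encoding shared by trainingRL.py, trainingSL.py and minerva_agent.py: screen/minimap points are flattened into a single integer index when building training targets, and recovered again when the agent assembles a FunctionCall. The two helpers below are copied from the repository; the example coordinate, the assertion, and the print are added here only to illustrate the round trip.

```python
def coordinateToInt(coor, size=64):
    # flatten an (x, y) point into one index, as in trainingRL.py / trainingSL.py
    return coor[0] + size * coor[1]

def intToCoordinate(num, size=64):
    # inverse mapping used by minerva_agent.py when building action arguments;
    # for argument sizes other than 64 the index is rescaled from the 64x64 grid
    if size != 64:
        num = num * size * size // 4096
    y = num // size
    x = num - size * y
    return [x, y]

if __name__ == "__main__":
    point = [23, 41]                      # arbitrary example coordinate
    idx = coordinateToInt(point)          # 23 + 64 * 41 = 2647
    assert intToCoordinate(idx) == point  # round trip is exact when size == 64
    print(idx, intToCoordinate(idx))
```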