├── dqn_agent
│   ├── __init__.py
│   ├── train_mineral_shards.py
│   └── deepq_mineral_shards.py
├── scripted_agent
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── simple_agent_protoss.cpython-36.pyc
│   └── simple_agent_protoss.py
└── README.md

/dqn_agent/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/scripted_agent/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/scripted_agent/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nailo2c/pysc2-tutorial/HEAD/scripted_agent/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/scripted_agent/__pycache__/simple_agent_protoss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nailo2c/pysc2-tutorial/HEAD/scripted_agent/__pycache__/simple_agent_protoss.cpython-36.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pysc2-tutorial

This repo implements a simple rule-based agent and a DQN agent for StarCraft II.

# Dependencies

* Python 3.6
* Anaconda
* TensorFlow
* PySC2
* Baselines

# Getting Started

First, install StarCraft II and sign up for a (free) Battle.net account. These slides walk through the installation:
https://goo.gl/d5L4yD

Next, install the required packages. The commands below assume macOS Sierra; when installing Anaconda, just keep pressing Enter and answering Yes.

```
wget https://repo.continuum.io/archive/Anaconda3-5.0.0-MacOSX-x86_64.sh
bash Anaconda3-5.0.0-MacOSX-x86_64.sh
source .bash_profile
pip install tensorflow
pip install baselines
pip install pysc2
pip install absl-py
```

# How to run

* scripted agent
```
python -m pysc2.bin.agent --map Simple64 --agent scripted_agent.simple_agent_protoss.RuleBaseAgent --agent_race protoss
```

* dqn agent
```
python dqn_agent/train_mineral_shards.py
```

# Result

* scripted agent

Consistently beats the built-in AI on the easiest difficulty.

* dqn agent

Plateaus at a score of around 13–14 and does not improve further.

# Slide

Slides from the talk given at the Taiwan R User Group / MLDM meetup on 2017-10-02:
https://goo.gl/oeEFvr


# References

[deepmind/pysc2](https://github.com/deepmind/pysc2)
[openai/baselines](https://github.com/openai/baselines)
[Building a Basic PySC2 Agent](https://medium.com/@skjb/building-a-basic-pysc2-agent-b109cde1477c)
[chris-chris/pysc2-examples](https://github.com/chris-chris/pysc2-examples)
[xhujoy/pysc2-agents](https://github.com/xhujoy/pysc2-agents)
--------------------------------------------------------------------------------
/dqn_agent/train_mineral_shards.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import sys

from absl import flags
from baselines import deepq

from pysc2.env import sc2_env
from pysc2.lib import actions

import deepq_mineral_shards

# Define the constants
_MOVE_SCREEN = actions.FUNCTIONS.Move_screen.id
_SELECT_ARMY = actions.FUNCTIONS.select_army.id
_SELECT_ALL = [0]
_NOT_QUEUED = [0]

step_mul = 8

FLAGS = flags.FLAGS

# Main function: create the env, define the model, learn from observations and save the model
def main():
    FLAGS(sys.argv)
    with sc2_env.SC2Env(map_name="CollectMineralShards", step_mul=step_mul) as env:
        # CNN
        model = deepq.models.cnn_to_mlp(
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256],
            dueling=True
        )

        act = deepq_mineral_shards.learn(
            env,
            q_func=model,
            num_actions=4,  # only four output actions: up, down, left, right
            lr=1e-5,
            max_timesteps=2000000,
            buffer_size=100000,
            exploration_fraction=0.5,
            exploration_final_eps=0.01,
            train_freq=4,
            learning_starts=100000,
            target_network_update_freq=1000,
            gamma=0.99,
            prioritized_replay=True
        )
        act.save("mineral_shards.pkl")


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/scripted_agent/simple_agent_protoss.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from pysc2.agents import base_agent
from pysc2.lib import actions
from pysc2.lib import features

import time

# Look up the function ids; functions are invoked with actions.FunctionCall
_BUILD_PYLON = actions.FUNCTIONS.Build_Pylon_screen.id
_BUILD_GATEWAY = actions.FUNCTIONS.Build_Gateway_screen.id
_NOOP = actions.FUNCTIONS.no_op.id
_SELECT_POINT = actions.FUNCTIONS.select_point.id
_TRAIN_ZEALOT = actions.FUNCTIONS.Train_Zealot_quick.id
_RALLY_UNITS_MINIMAP = actions.FUNCTIONS.Rally_Units_minimap.id
_SELECT_ARMY = actions.FUNCTIONS.select_army.id
_ATTACK_MINIMAP = actions.FUNCTIONS.Attack_minimap.id

# Screen feature indices: who owns each pixel, and each unit's type id
_PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index  # value in [0, 4], denoting [background, self, ally, neutral, enemy]
_UNIT_TYPE = features.SCREEN_FEATURES.unit_type.index  # a very long list of every unit's type id

# Unit IDs
# https://github.com/Blizzard/s2client-api/blob/master/include/sc2api/sc2_typeenums.h
_PROTOSS_GATEWAY = 62
_PROTOSS_NEXUS = 59
_PROTOSS_PYLON = 60
_PROTOSS_PROBE = 84

# Parameters
_PLAYER_SELF = 1  # used with _PLAYER_RELATIVE to locate our own units
_SUPPLY_USED = 3  # https://github.com/deepmind/pysc2/blob/master/docs/environment.md
_SUPPLY_MAX = 4   # in the structured general player information, index 3 is the current supply and 4 is the supply cap; used to keep producing units until supply is full
_SCREEN = [0]
_MINIMAP = [1]  # the screen is id 0, the minimap is id 1
_QUEUED = [1]   # add to queue, https://github.com/deepmind/pysc2/blob/master/pysc2/lib/actions.py#L204
_SELECT_ALL = [0]

class RuleBaseAgent(base_agent.BaseAgent):
    nexus_top_left = None
    pylon_built = False
    probe_selected = False
    gateway_built = False
    gateway_selected = False
    gateway_rallied = False
    army_selected = False
    army_rallied = False

    # On Simple64 the spawn point is always either top-left or bottom-right, so this small helper
    # decides where buildings are placed. If we are not in the top-left corner, the Nexus sits at
    # larger coordinates, so the offsets are subtracted to build above/left of it; if we are in the
    # top-left, the offsets are added to build below/right of it.
    def transformLocation(self, x, x_distance, y, y_distance):
        if not self.nexus_top_left:
            return [x - x_distance, y - y_distance]
        else:
            return [x + x_distance, y + y_distance]

    # Override step(); the agent is driven by a chain of conditional rules
    def step(self, obs):
        super(RuleBaseAgent, self).step(obs)

        # Slow the game down so it can be watched; comment this line out to run experiments faster
        time.sleep(0.01)

        # Determine whether the Nexus is in the top-left or the bottom-right
        if self.nexus_top_left is None:
            # Use the player_relative minimap feature to locate our own base;
            # since the map is Simple64, a mean y coordinate <= 31 means we spawned in the top-left
            nexus_y, nexus_x = (obs.observation["feature_minimap"][_PLAYER_RELATIVE] == _PLAYER_SELF).nonzero()
            self.nexus_top_left = nexus_y.mean() <= 31

        # rule 1: if no Pylon has been built and no Probe is selected yet, select a Probe;
        #         if no Pylon has been built but a Probe is already selected, build a Pylon
        if not self.pylon_built:
            if not self.probe_selected:
                unit_type = obs.observation["feature_screen"][_UNIT_TYPE]  # type id of every unit on the screen
                probe_y, probe_x = (unit_type == _PROTOSS_PROBE).nonzero()

                target = [probe_x[0], probe_y[0]]  # pick the first Probe

                self.probe_selected = True
                # select_point expects args shaped like [[0], [23, 46]]: [0] means the screen, the second list is the coordinate
                return actions.FunctionCall(_SELECT_POINT, [_SCREEN, target])

            # Check whether building a Pylon is an available action in this observation
            elif _BUILD_PYLON in obs.observation["available_actions"]:
                unit_type = obs.observation["feature_screen"][_UNIT_TYPE]
                nexus_y, nexus_x = (unit_type == _PROTOSS_NEXUS).nonzero()  # locate the Nexus

                # Pick a spot above or below the Nexus and build there
                target = self.transformLocation(int(nexus_x.mean()), 0, int(nexus_y.mean()), 20)

                self.pylon_built = True
                return actions.FunctionCall(_BUILD_PYLON, [_SCREEN, target])

        # rule 2: if the Pylon is built but the Gateway (the production building) is not, build a Gateway
        elif not self.gateway_built:
            if _BUILD_GATEWAY in obs.observation["available_actions"]:
                unit_type = obs.observation["feature_screen"][_UNIT_TYPE]
                pylon_y, pylon_x = (unit_type == _PROTOSS_PYLON).nonzero()

                target = self.transformLocation(int(pylon_x.mean()), 10, int(pylon_y.mean()), 0)

                # Only leave this rule once a Gateway actually shows up on the screen
                if (unit_type == _PROTOSS_GATEWAY).any():
                    self.gateway_built = True

                return actions.FunctionCall(_BUILD_GATEWAY, [_SCREEN, target])

        # rule 3: once the Pylon and the Gateway are built, rally units to guard the ramp (ramp coordinates are hardcoded)
        elif not self.gateway_rallied:
            # The Gateway must be selected before a rally point can be set
            if not self.gateway_selected:
                unit_type = obs.observation["feature_screen"][_UNIT_TYPE]
                gateway_y, gateway_x = (unit_type == _PROTOSS_GATEWAY).nonzero()

                # Make sure a Gateway was actually found
                if gateway_y.any():
                    target = [int(gateway_x.mean()), int(gateway_y.mean())]
                    self.gateway_selected = True
                    return actions.FunctionCall(_SELECT_POINT, [_SCREEN, target])
            else:
                self.gateway_rallied = True
                if self.nexus_top_left:
                    return actions.FunctionCall(_RALLY_UNITS_MINIMAP, [_MINIMAP, [29, 21]])
                else:
                    return actions.FunctionCall(_RALLY_UNITS_MINIMAP, [_MINIMAP, [29, 46]])

        # rule 4: while supply is below the cap, keep training Zealots
        elif obs.observation["player"][_SUPPLY_USED] < obs.observation["player"][_SUPPLY_MAX] and \
                _TRAIN_ZEALOT in obs.observation["available_actions"]:
            return actions.FunctionCall(_TRAIN_ZEALOT, [_QUEUED])  # [1] means queued, i.e. queue up a Zealot

        # rule 5: once supply is capped, attack the enemy base
        elif not self.army_rallied:  # gather the army
            if not self.army_selected:  # select the army
                if _SELECT_ARMY in obs.observation["available_actions"]:
                    self.army_selected = True

                    return actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])
            elif _ATTACK_MINIMAP in obs.observation["available_actions"]:
                self.army_rallied = True

                # Attack the corner opposite our own Nexus (hardcoded minimap coordinates)
                if self.nexus_top_left:
                    self.army_selected = False
                    self.army_rallied = False
                    return actions.FunctionCall(_ATTACK_MINIMAP, [_MINIMAP, [39, 45]])
                else:
                    self.army_selected = False
                    self.army_rallied = False
                    return actions.FunctionCall(_ATTACK_MINIMAP, [_MINIMAP, [21, 24]])

        # If the current observation matches none of the rules above, do nothing
        return actions.FunctionCall(_NOOP, [])
--------------------------------------------------------------------------------
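A quick worked example of the `transformLocation` helper above (a sketch with made-up coordinates, not output from a real game): the same offset is mirrored depending on which corner the Nexus spawned in.

```python
# Illustrative only: mirror a build offset depending on the spawn corner.
from scripted_agent.simple_agent_protoss import RuleBaseAgent

agent = RuleBaseAgent()

agent.nexus_top_left = True                       # spawned top-left
print(agent.transformLocation(30, 0, 30, 20))     # [30, 50] -> build below the Nexus

agent.nexus_top_left = False                      # spawned bottom-right
print(agent.transformLocation(30, 0, 30, 20))     # [30, 10] -> build above the Nexus
```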
/dqn_agent/deepq_mineral_shards.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import numpy as np
import os
import dill
import tempfile
import tensorflow as tf
import zipfile

import baselines.common.tf_util as U

from baselines import logger
from baselines.common.schedules import LinearSchedule
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer

from pysc2.lib import actions as sc2_actions
from pysc2.env import environment
from pysc2.lib import features
from pysc2.lib import actions

from absl import flags

# The overall structure of this file follows:
# https://github.com/openai/baselines/blob/master/baselines/deepq/simple.py

# Define constants and function ids
_PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index
_PLAYER_FRIENDLY = 1

_NO_OP = actions.FUNCTIONS.no_op.id
_MOVE_SCREEN = actions.FUNCTIONS.Move_screen.id
_ATTACK_SCREEN = actions.FUNCTIONS.Attack_screen.id
_SELECT_ARMY = actions.FUNCTIONS.select_army.id

_NOT_QUEUED = [0]
_SELECT_ALL = [0]

FLAGS = flags.FLAGS


# A lightly adapted version of the ActWrapper from baselines
class ActWrapper(object):
    def __init__(self, act):
        self._act = act

    @staticmethod
    def load(path, act_params, num_cpu=16):
        with open(path, "rb") as f:
            model_data = dill.load(f)
        act = deepq.build_act(**act_params)
        sess = U.make_session(num_cpu=num_cpu)
        sess.__enter__()
        with tempfile.TemporaryDirectory() as td:
            arc_path = os.path.join(td, "packed.zip")
            with open(arc_path, "wb") as f:
                f.write(model_data)

            zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
            U.load_state(os.path.join(td, "model"))

        return ActWrapper(act)

    def __call__(self, *args, **kwargs):
        return self._act(*args, **kwargs)

    def save(self, path):
        """Save model to a pickle located at `path`"""
        with tempfile.TemporaryDirectory() as td:
            U.save_state(os.path.join(td, "model"))
            arc_name = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_name, 'w') as zipf:
                for root, dirs, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_name:
                            zipf.write(file_path, os.path.relpath(file_path, td))
            with open(arc_name, "rb") as f:
                model_data = f.read()
        with open(path, "wb") as f:
            dill.dump((model_data), f)


def load(path, act_params, num_cpu=16):
    return ActWrapper.load(path, act_params=act_params, num_cpu=num_cpu)


def learn(env,
          q_func,
          num_actions=4,
          lr=5e-4,
          max_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=1,
          checkpoint_freq=10000,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          num_cpu=16,
          param_noise=False,
          param_noise_threshold=0.05,
          callback=None):

    # Create all the functions necessary to train the model

    # Returns a session that will use CPUs only
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()

    # Creates a placeholder for a batch of tensors of a given shape and dtype
    def make_obs_ph(name):
        return U.BatchInput((64, 64), name=name)

    # act, train and update_target are functions; debug is a dict
    act, train, update_target, debug = deepq.build_train(
        make_obs_ph=make_obs_ph,
        q_func=q_func,
        num_actions=num_actions,
        optimizer=tf.train.AdamOptimizer(learning_rate=lr),
        gamma=gamma,
        grad_norm_clipping=10
    )

    # Choose between a prioritized replay buffer and a normal replay buffer
    if prioritized_replay:
        replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha=prioritized_replay_alpha)
        if prioritized_replay_beta_iters is None:
            prioritized_replay_beta_iters = max_timesteps
        beta_schedule = LinearSchedule(prioritized_replay_beta_iters,
                                       initial_p=prioritized_replay_beta0,
                                       final_p=1.0)
    else:
        replay_buffer = ReplayBuffer(buffer_size)
        beta_schedule = None

    # Create the schedule for exploration starting from 1
    exploration = LinearSchedule(schedule_timesteps=int(exploration_fraction * max_timesteps),
                                 initial_p=1.0,
                                 final_p=exploration_final_eps)

    # The SC2-specific part starts here

    # Initialize the parameters and copy them to the target network.
    U.initialize()
    update_target()

    episode_rewards = [0.0]
    saved_mean_reward = None

    path_memory = np.zeros((64, 64))

    obs = env.reset()

    # Select all marines
    obs = env.step(actions=[sc2_actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])])

    # obs is a tuple; obs[0] is a pysc2.env.environment.TimeStep and obs[0].observation is a dict
    player_relative = obs[0].observation["screen"][_PLAYER_RELATIVE]

    # path_memory remembers the trail that has already been walked
    screen = player_relative + path_memory

    # Get the centre position of the two marines
    player_y, player_x = (player_relative == _PLAYER_FRIENDLY).nonzero()
    player = [int(player_x.mean()), int(player_y.mean())]


    reset = True
    with tempfile.TemporaryDirectory() as td:
        model_saved = False
        model_file = os.path.join(td, "model")

        for t in range(max_timesteps):
            if callback is not None:
                if callback(locals(), globals()):
                    break
            # Take action and update exploration to the newest value
            kwargs = {}
            if not param_noise:
                update_eps = exploration.value(t)
                update_param_noise_threshold = 0.
            else:
                update_eps = 0.
                if param_noise_threshold >= 0.:
                    update_param_noise_threshold = param_noise_threshold
                else:
                    update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(num_actions))
                kwargs['reset'] = reset
                kwargs['update_param_noise_threshold'] = update_param_noise_threshold
                kwargs['update_param_noise_scale'] = True
            # np.array(...)[None] wraps an extra dimension around the array, e.g. [1] -> [[1]]
            action = act(np.array(screen)[None], update_eps=update_eps, **kwargs)[0]
            reset = False

            coord = [player[0], player[1]]
            rew = 0

            # There are only four actions (up, down, left, right). After a move, the traversed path is
            # filled with -3 so that it cancels out the mineral shards' id (= 3), marking them as collected.
            path_memory_ = np.array(path_memory, copy=True)
            if (action == 0):  # UP

                if (player[1] >= 16):
                    coord = [player[0], player[1] - 16]
                    path_memory_[player[1] - 16 : player[1], player[0]] = -3
                elif (player[1] > 0):
                    coord = [player[0], 0]
                    path_memory_[0 : player[1], player[0]] = -3

            elif (action == 1):  # DOWN

                if (player[1] <= 47):
                    coord = [player[0], player[1] + 16]
                    path_memory_[player[1] : player[1] + 16, player[0]] = -3
                elif (player[1] > 47):
                    coord = [player[0], 63]
                    path_memory_[player[1] : 63, player[0]] = -3

            elif (action == 2):  # LEFT

                if (player[0] >= 16):
                    coord = [player[0] - 16, player[1]]
                    path_memory_[player[1], player[0] - 16 : player[0]] = -3
                elif (player[0] < 16):
                    coord = [0, player[1]]
                    path_memory_[player[1], 0 : player[0]] = -3

            elif (action == 3):  # RIGHT

                if (player[0] <= 47):
                    coord = [player[0] + 16, player[1]]
                    path_memory_[player[1], player[0] : player[0] + 16] = -3
                elif (player[0] > 47):
                    coord = [63, player[1]]
                    path_memory_[player[1], player[0] : 63] = -3

            # Update path_memory
            path_memory = np.array(path_memory_)

            # If the marines cannot be moved, they are probably not selected yet, so select them
            if _MOVE_SCREEN not in obs[0].observation["available_actions"]:
                obs = env.step(actions=[sc2_actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])])

            # Move the marines
            new_action = [sc2_actions.FunctionCall(_MOVE_SCREEN, [_NOT_QUEUED, coord])]

            # Get the next observation from the environment
            obs = env.step(actions=new_action)

            # Re-read player_relative here: the obs above is a tuple carrying several kinds of
            # information, but only the reduced screen view goes into the replay buffer
            player_relative = obs[0].observation["screen"][_PLAYER_RELATIVE]
            new_screen = player_relative + path_memory

            # Get the reward
            rew = obs[0].reward

            # StepType.LAST means the episode is done
            done = obs[0].step_type == environment.StepType.LAST

            # Store transition in the replay buffer
            replay_buffer.add(screen, action, rew, new_screen, float(done))

            # Once the transition is stored, the new screen replaces the old one
            screen = new_screen

            episode_rewards[-1] += rew

            if done:
                # Reset the env to get a fresh friendly/neutral/enemy map
                obs = env.reset()
                # player_relative = obs[0].observation["screen"][_PLAYER_RELATIVE]

                # # still not clear why path_memory should be added here
                # screen = player_relative + path_memory

                # player_y, player_x = (player_relative == _PLAYER_FRIENDLY).nonzero()
                # player = [int(player_x.mean()), int(player_y.mean())]

                # # select all marines (why do this on the done observation?)
                # env.step(actions=[sc2_actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])])
                episode_rewards.append(0.0)

                # Clear path_memory
                path_memory = np.zeros((64, 64))

                reset = True

            # Periodically sample experience from the replay buffer to train on, and update the target network
            if t > learning_starts and t % train_freq == 0:
                # Minimize the error in Bellman's equation on a batch sampled from the replay buffer.
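                # With prioritized replay, transitions are sampled in proportion to their stored
                # priority (their last |TD error|), and the returned importance-sampling weights
                # correct for that bias; beta_schedule anneals beta towards 1.0. With the plain
                # buffer, sampling is uniform and every weight is 1.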
                if prioritized_replay:
                    experience = replay_buffer.sample(batch_size, beta=beta_schedule.value(t))
                    (obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
                else:
                    obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
                    weights, batch_idxes = np.ones_like(rewards), None
                # train comes from deepq.build_train
                td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights)
                if prioritized_replay:
                    new_priorities = np.abs(td_errors) + prioritized_replay_eps
                    replay_buffer.update_priorities(batch_idxes, new_priorities)

            # target network
            if t > learning_starts and t % target_network_update_freq == 0:
                # update_target likewise comes from deepq.build_train
                # Update the target network periodically
                update_target()

            # Log progress to keep track of the reward
            mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
            num_episodes = len(episode_rewards)
            if done and print_freq is not None and len(episode_rewards) % print_freq == 0:
                logger.record_tabular("steps", t)
                logger.record_tabular("episodes", num_episodes)
                logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
                logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
                logger.dump_tabular()

            # Save a checkpoint whenever the model improves
            if (checkpoint_freq is not None and t > learning_starts and
                    num_episodes > 100 and t % checkpoint_freq == 0):
                if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
                    if print_freq is not None:
                        logger.log("Saving model due to mean reward increase: {} -> {}".format(
                            saved_mean_reward, mean_100ep_reward))
                    U.save_state(model_file)
                    model_saved = True
                    saved_mean_reward = mean_100ep_reward
        if model_saved:
            if print_freq is not None:
                logger.log("Restored model with mean reward: {}".format(saved_mean_reward))
            U.load_state(model_file)

    return ActWrapper(act)
--------------------------------------------------------------------------------
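Not part of the original repo, but as a usage sketch: the policy that train_mineral_shards.py saves via act.save("mineral_shards.pkl") could be reloaded with the load() helper above. The act_params keys mirror what deepq.build_act expects, and the network definition is assumed to match the one used during training; treat this as an untested illustration rather than an official entry point.

```python
# Hypothetical sketch: rebuild the act function and restore the saved weights.
import baselines.common.tf_util as U
from baselines import deepq

import deepq_mineral_shards

# Assumed to match the cnn_to_mlp model used in train_mineral_shards.py.
model = deepq.models.cnn_to_mlp(
    convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
    hiddens=[256],
    dueling=True
)

act = deepq_mineral_shards.load(
    "mineral_shards.pkl",
    act_params={
        "make_obs_ph": lambda name: U.BatchInput((64, 64), name=name),
        "q_func": model,
        "num_actions": 4,
    }
)
# act(np.array(screen)[None])[0] then yields an action index (0-3) for a 64x64 screen.
```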