├── Melee └── FindAndDefeatDronesAA.SC2Map ├── mini_games └── FindAndDefeatDrones.SC2Map ├── random_agent1.py ├── sc2Processor.py ├── run_loop1.py ├── sc2Policy.py ├── run_loop_xun.py ├── TestScripted_V2.py ├── TestScripted_V1.py ├── README.md ├── prioReplayBuffer.py ├── exec_2agents.py ├── env.py ├── agent.py ├── scripted_agent1.py ├── sc2DqnAgent.py └── sc2_env_xun.py /Melee/FindAndDefeatDronesAA.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xunger99/SAAC-StarCraft-Adversary-Agent-Challenge/HEAD/Melee/FindAndDefeatDronesAA.SC2Map -------------------------------------------------------------------------------- /mini_games/FindAndDefeatDrones.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xunger99/SAAC-StarCraft-Adversary-Agent-Challenge/HEAD/mini_games/FindAndDefeatDrones.SC2Map -------------------------------------------------------------------------------- /random_agent1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A random agent for starcraft.""" 15 | 16 | # updated from pysc2 code. 
xun, June 17, 2021 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy 23 | 24 | from pysc2.agents import base_agent 25 | from pysc2.lib import actions 26 | import pdb 27 | 28 | class RandomAgent1(base_agent.BaseAgent): 29 | """A random agent for starcraft.""" 30 | 31 | def step(self, obs): 32 | super(RandomAgent1, self).step(obs) 33 | pdb.set_trace() 34 | function_id = numpy.random.choice(obs.observation.available_actions) 35 | args = [[numpy.random.randint(0, size) for size in arg.sizes] 36 | for arg in self.action_spec.functions[function_id].args] 37 | return actions.FunctionCall(function_id, args) 38 | -------------------------------------------------------------------------------- /sc2Processor.py: -------------------------------------------------------------------------------- 1 | from rl.core import Processor 2 | import numpy as np 3 | import pdb 4 | 5 | class Sc2Processor(Processor): 6 | 7 | def __init__(self, screen=32): 8 | super(Processor, self).__init__() 9 | self._SCREEN = screen 10 | 11 | def process_state_batch(self, batch): 12 | # reshape cause batch is (bs, 1, 2, screen, screen) 13 | #pdb.set_trace() 14 | size_first_dim = len(batch) 15 | size_second_dim = len(batch[0,]) 16 | 17 | return np.reshape(batch, (size_first_dim, size_second_dim, self._SCREEN, self._SCREEN)) 18 | 19 | # observation, reward, done, info = env.step(action) 20 | def process_observation(self, observation): 21 | 22 | # small_observation = observation[0].observation["feature_screen"][5] 23 | 24 | # print(smallObservation, observation[0].reward, observation[0].last(), "lol") 25 | 26 | # small_observation = small_observation.reshape(1, small_observation.shape[0], small_observation.shape[0], 1) 27 | 28 | # print(smallObservation.shape) 29 | 30 | # print(smallObservation, observation[0].reward, observation[0].last(), "lol") 31 | 32 | # fix dim from 1 1 2 16 16 to 1 2 16 16 33 | # observation = observation[0] 34 | 35 | return observation 36 | 37 | # assert observation.ndim == 3 # (height, width, channel) 38 | # img = Image.fromarray(observation) 39 | # img = img.resize(INPUT_SHAPE).convert('L') # resize and convert to grayscale 40 | # processed_observation = np.array(img) 41 | # assert processed_observation.shape == INPUT_SHAPE 42 | # return processed_observation.astype('uint8') # saves storage in experience memory 43 | 44 | -------------------------------------------------------------------------------- /run_loop1.py: -------------------------------------------------------------------------------- 1 | # The code is developed based on run_loop from pysc2. 2 | # Xun, June 17, 2021. 
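# Usage sketch (for illustration only; it assumes an SC2Env and an agent built as in
# TestScripted_V1.py / TestScripted_V2.py, where agent_format is the
# sc2_env.AgentInterfaceFormat constructed there):
#
#   agent = scripted_agent1.FindAndDefeatZergling_4()
#   with sc2_env.SC2Env(map_name="FindAndDefeatDrones",
#                       players=[sc2_env.Agent(sc2_env.Race.terran)],
#                       agent_interface_format=[agent_format],
#                       step_mul=8,
#                       game_steps_per_episode=2880) as env:
#       run_loop1([agent], env, max_frames=2880, max_episodes=10)
#
# The loop below ends an episode when timesteps[0].last() is True and stops after
# max_episodes episodes (the max_frames check is currently disabled, see the comments
# inside the loop).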
3 | 4 | """A run loop for agent/environment interaction.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import time 11 | import pdb 12 | 13 | 14 | def run_loop1(agents, env, max_frames=0, max_episodes=0): 15 | """A run loop to have agents and an environment interact.""" 16 | total_frames = 0 17 | total_episodes = 0 18 | start_time = time.time() 19 | 20 | observation_spec = env.observation_spec() 21 | action_spec = env.action_spec() 22 | 23 | for agent, obs_spec, act_spec in zip(agents, observation_spec, action_spec): 24 | agent.setup(obs_spec, act_spec) 25 | 26 | try: 27 | while not max_episodes or total_episodes < max_episodes: 28 | total_episodes += 1 29 | timesteps = env.reset() 30 | for a in agents: 31 | a.reset() 32 | while True: 33 | total_frames += 1 34 | actions = [agent.step(timestep) 35 | for agent, timestep in zip(agents, timesteps)] 36 | # Disabled by Xun, cosz the code seems to contain a bug. 37 | #if max_frames and total_frames >= max_frames: 38 | #pdb.set_trace() 39 | # return 40 | #if 360 - total_frames <= 1: # by xun, try this code, here 360 = 2880/step_mul, the latter is 8 41 | #pdb.set_trace() 42 | #total_frames = 0 43 | #break 44 | 45 | if timesteps[0].last(): 46 | #pdb.set_trace() 47 | break 48 | timesteps = env.step(actions) 49 | except KeyboardInterrupt: 50 | pass 51 | finally: 52 | elapsed_time = time.time() - start_time 53 | print("Took %.3f seconds for %s steps: %.3f fps" % ( 54 | elapsed_time, total_frames, total_frames / elapsed_time)) -------------------------------------------------------------------------------- /sc2Policy.py: -------------------------------------------------------------------------------- 1 | # map the dqn output to actions understandable by pysc2 2 | 3 | from rl.policy import Policy 4 | import numpy as np 5 | from sc2DqnAgent import Sc2Action 6 | import pdb 7 | 8 | class Sc2Policy(Policy): 9 | 10 | def __init__(self, env, nb_actions=3, eps=0.1, testing=False): 11 | super(Sc2Policy, self).__init__() 12 | self.eps = eps 13 | self.nb_pixels = env._SCREEN 14 | #pdb.set_trace() 15 | self.nb_actions = nb_actions 16 | self.testing = testing 17 | 18 | def select_action(self, q_values): 19 | """Return the selected action 20 | 21 | # Arguments 22 | q_values (numpy array of shape (2, ?)): 23 | one List of q-estimates for action-selection 24 | one array of shape (screensize, screensize) for position selection 25 | 26 | # Returns 27 | Selection action 28 | """ 29 | 30 | action = Sc2Action() 31 | 32 | # Epsilon-Greedy 33 | # pdb.set_trace() 34 | egran=np.random.uniform() 35 | if egran < self.eps and not self.testing: 36 | action.action = np.random.random_integers(0, self.nb_actions-1) 37 | action.coords = (np.random.random_integers(0, self.nb_pixels-1), np.random.random_integers(0, self.nb_pixels-1)) 38 | if self.eps <0.05: 39 | print('eps:',self.eps) 40 | 41 | else: 42 | # greedy. 
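# Note on the two q_values heads (assuming the two-headed model built in
# exec_2agents.py, which outputs [act_out, coord_out]): q_values[0] holds the
# per-action Q-estimates, so its argmax is the discrete action id; q_values[1] is the
# coordinate head, expected to have shape (1, screen, screen, 1), so unravel_index of
# its argmax yields (batch, row, col, channel) and the [1:3] slice keeps (row, col),
# i.e. the selected screen pixel.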
43 | action.action = np.argmax(q_values[0]) 44 | #pdb.set_trace() 45 | action.coords = np.unravel_index(q_values[1].argmax(), q_values[1].shape)[1:3] 46 | 47 | # action.coords = np.unravel_index(np.reshape(q_values[1][0][:][:], (16, 16)).argmax(), np.reshape( 48 | 49 | assert len(action.coords) == 2 50 | 51 | return action 52 | 53 | def get_config(self): 54 | """Return configurations of EpsGreedyPolicy 55 | 56 | # Returns 57 | Dict of config 58 | """ 59 | config = super(Sc2Policy, self).get_config() 60 | config['eps'] = self.eps 61 | config['testing'] = self.testing 62 | return config 63 | 64 | 65 | -------------------------------------------------------------------------------- /run_loop_xun.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A run loop for agent/environment interaction.""" 15 | 16 | # The code has been modified to resolve the issue for 2-players. xun, 2021 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import time 23 | import pdb 24 | 25 | 26 | 27 | def run_loop_xun(agents, env, max_frames=0, max_episodes=0): 28 | """A run loop to have agents and an environment interact.""" 29 | 30 | total_frames = 0 31 | total_episodes = 0 32 | start_time = time.time() 33 | #pdb.set_trace() 34 | observation_spec = env.observation_spec() 35 | action_spec = env.action_spec() 36 | for agent, obs_spec, act_spec in zip(agents, observation_spec, action_spec): 37 | agent.setup(obs_spec, act_spec) 38 | 39 | try: 40 | while not max_episodes or total_episodes < max_episodes: 41 | total_episodes += 1 42 | #pdb.set_trace() 43 | timesteps = env.reset() 44 | for a in agents: 45 | a.reset() 46 | while True: 47 | total_frames += 1 48 | 49 | actions = [agent.step(timestep) for agent, timestep in zip(agents, timesteps)] 50 | #pdb.set_trace() 51 | if max_frames and total_frames >= max_frames: 52 | return 53 | 54 | if timesteps[0].last(): 55 | #pdb.set_trace() 56 | break 57 | 58 | #pdb.set_trace() 59 | # The key step 60 | timesteps = env.step(actions) 61 | 62 | except KeyboardInterrupt: 63 | pass 64 | finally: 65 | elapsed_time = time.time() - start_time 66 | print("Took %.3f seconds for %s steps: %.3f fps" % ( 67 | elapsed_time, total_frames, total_frames / elapsed_time)) 68 | -------------------------------------------------------------------------------- /TestScripted_V2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Xun's fist script code, updated from pysc2/tests code. June 16, 2021. 3 | # 4 | # (1) Specifically for the map entitled FindAndDefeatZergling2 (single player case). 
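# Before this test can load the new maps, they have to be registered with pysc2, as
# the Tips section of README.md describes. A minimal sketch, assuming a stock
# pysc2 3.0.0 install (the list names below are the ones used by
# pysc2/maps/mini_games.py and pysc2/maps/melee.py):
#
#   # pysc2/maps/mini_games.py -- append the single-player map:
#   mini_games = [
#       # ... existing entries ...
#       "FindAndDefeatDrones",
#   ]
#
#   # pysc2/maps/melee.py -- append the two-player map:
#   melee_maps = [
#       # ... existing entries ...
#       "FindAndDefeatDronesAA",
#   ]
#
# The .SC2Map files themselves go into the Maps/mini_games and Maps/Melee folders of
# the StarCraft II installation, per the README.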
5 | # 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from pysc2.agents import random_agent 12 | #from pysc2.env import run_loop 13 | from pysc2.env import sc2_env 14 | from pysc2.tests import utils 15 | from pysc2.lib import actions as sc2_actions 16 | from pysc2.lib import features, units 17 | 18 | from absl import flags 19 | from absl.testing import absltest 20 | import sys 21 | 22 | import pdb 23 | import time 24 | 25 | # Imported by Xun 26 | import run_loop1 27 | import random_agent1 28 | import scripted_agent1 29 | 30 | 31 | 32 | _NO_OP = sc2_actions.FUNCTIONS.no_op.id 33 | _PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index 34 | 35 | 36 | FLAGS = flags.FLAGS 37 | FLAGS(sys.argv) 38 | 39 | class TestScripted(utils.TestCase): 40 | steps = 2880 41 | step_mul = 8 42 | episodes = 10 43 | 44 | 45 | def test_defeat_zerglings(self): 46 | agent_format = sc2_env.AgentInterfaceFormat( 47 | feature_dimensions=sc2_env.Dimensions( 48 | screen=(32,32), 49 | minimap=(32,32), 50 | ), 51 | use_raw_units=True, 52 | use_feature_units=True 53 | ) 54 | #,sc2_env.Bot(sc2_env.Race.zerg,sc2_env.Difficulty.very_hard) 55 | #sc2_env.Bot(sc2_env.Race.zerg)], # 56 | with sc2_env.SC2Env( 57 | map_name="FindAndDefeatDrones", 58 | players=[sc2_env.Agent(sc2_env.Race.terran)], 59 | step_mul=self.step_mul, 60 | disable_fog=True, #False, # 61 | visualize=False, #True, 62 | agent_interface_format=[agent_format], 63 | game_steps_per_episode=2880)as env: # self.steps * self.step_mul) as env: 64 | 65 | 66 | obs = env.step(actions=[sc2_actions.FunctionCall(_NO_OP, [])]) 67 | player_relative = obs[0].observation["feature_screen"][_PLAYER_RELATIVE] 68 | 69 | # Break Point!! 70 | #pdb.set_trace() 71 | print(player_relative) 72 | 73 | #agent = random_agent1.RandomAgent1() Enable random agent 74 | # Instead, enable scripted agent 75 | # agent=scripted_agent1.FindAndDefeatZergling() 76 | agent=scripted_agent1.FindAndDefeatZergling_4() # the script for void_ray 77 | #agent =random_agent.RandomAgent 78 | agent2=random_agent.RandomAgent 79 | 80 | #pdb.set_trace() 81 | run_loop1.run_loop1([agent], env, self.steps, self.episodes) #agent,agent2] 82 | 83 | #pdb.set_trace() 84 | #self.tearDown() 85 | # self.assertEqual(agent.steps, self.steps) 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | if __name__ == "__main__": 94 | absltest.main() 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /TestScripted_V1.py: -------------------------------------------------------------------------------- 1 | # 2 | # Xun's fist script code, updated from pysc2/tests code. June 16, 2021. 3 | # 4 | # (1) Specifically for the map entitled FindAndDefeatZergling (single player case). 
5 | # 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from pysc2.agents import random_agent 12 | #from pysc2.env import run_loop 13 | from pysc2.env import sc2_env 14 | from pysc2.tests import utils 15 | from pysc2.lib import actions as sc2_actions 16 | from pysc2.lib import features, units 17 | 18 | from absl import flags 19 | from absl.testing import absltest 20 | import sys 21 | 22 | import pdb 23 | import time 24 | 25 | # Imported by Xun 26 | import run_loop1 27 | import random_agent1 28 | import scripted_agent1 29 | 30 | 31 | 32 | _NO_OP = sc2_actions.FUNCTIONS.no_op.id 33 | _PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index 34 | 35 | 36 | FLAGS = flags.FLAGS 37 | FLAGS(sys.argv) 38 | 39 | class TestScripted(utils.TestCase): 40 | steps = 2880 41 | step_mul = 8 42 | episodes = 10 43 | 44 | 45 | def test_defeat_zerglings(self): 46 | agent_format = sc2_env.AgentInterfaceFormat( 47 | feature_dimensions=sc2_env.Dimensions( 48 | screen=(32,32), 49 | minimap=(32,32), 50 | ), 51 | use_raw_units=True, 52 | use_feature_units=True 53 | ) 54 | #pdb.set_trace() #"FindAndDefeatZerglings", #"Empty_xun1" ,#"FindAndDefeatZerglings2", 55 | #,sc2_env.Bot(sc2_env.Race.zerg,sc2_env.Difficulty.very_hard) 56 | #sc2_env.Bot(sc2_env.Race.zerg)], # 57 | with sc2_env.SC2Env( 58 | map_name="FindAndDefeatZerglings", 59 | players=[sc2_env.Agent(sc2_env.Race.terran)], 60 | step_mul=self.step_mul, 61 | disable_fog=True, #False, #True, 62 | visualize=False, #True, 63 | agent_interface_format=[agent_format], 64 | game_steps_per_episode=2880)as env: # self.steps * self.step_mul) as env: 65 | 66 | 67 | obs = env.step(actions=[sc2_actions.FunctionCall(_NO_OP, [])]) 68 | player_relative = obs[0].observation["feature_screen"][_PLAYER_RELATIVE] 69 | 70 | # Break Point!! 71 | #pdb.set_trace() 72 | print(player_relative) 73 | 74 | #agent = random_agent1.RandomAgent1() Enable random agent 75 | 76 | # Instead, enable scripted agent 77 | # agent=scripted_agent1.FindAndDefeatZergling() 78 | 79 | agent=scripted_agent1.FindAndDefeatZergling_4() 80 | #agent =random_agent.RandomAgent 81 | agent2=random_agent.RandomAgent 82 | 83 | #pdb.set_trace() 84 | run_loop1.run_loop1([agent], env, self.steps, self.episodes) #agent,agent2] 85 | 86 | #pdb.set_trace() 87 | #self.tearDown() 88 | # self.assertEqual(agent.steps, self.steps) 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | if __name__ == "__main__": 97 | absltest.main() 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAAC-StarCraft-Adversary-Agent-Challenge 2 | 3 | # Paper 4 | A draft that explains the detailed design of the code is arXiv:submit/3894432 (https://arxiv.org/submit/3894432/view), "Adversary-Agent Reinforcement Learning for Pursuit–Evasion" by Prof. Xun Huang, Aug 2021. If you have used this code in your research work, please cite the paper. 5 | 6 | 7 | 8 | # Introduction 9 | A reinforcement learning environment with adversary agents is proposed in this work for pursuit--evasion game in the presence of fog of war, which is of both scientific significance and practical importance in aerospace applications. One of the most popular learning environments, StarCraft, is adopted here and the associated mini-games are analyzed to identify the current limitation for training adversary agents. 
The key contributions are an analysis of the potential performance of an agent, obtained by incorporating control and differential game theory into this specific reinforcement learning environment, and the development of an adversary-agent challenge (SAAC) environment that extends the current StarCraft mini-games. The subsequent study showcases the use of this learning environment and the effectiveness of an adversary agent for evasion units. Overall, the proposed SAAC environment should benefit pursuit-evasion studies with rapidly emerging reinforcement learning technologies. Last but not least, the corresponding tutorial code can be found in this GitHub repository. 10 | 11 | 12 | 13 | # Installation steps: 14 | 1. conda (Python 3.8, 3.7, and 3.6 all work) 15 | 2. pip install tensorflow-gpu (2.5.0; running without a GPU is OK too) 16 | 3. keras (2.4.3) 17 | 4. PySC2 (3.0.0) 18 | 5. baselines (0.1.6) 19 | 6. Battle.net client + downloaded StarCraft II maps 20 | 7. Download the files from this repository 21 | * The above steps have been tested on macOS, Windows 10, and Ubuntu. If you run into installation problems, please search online for solutions. 22 | 23 | ## Tips 24 | 1. Currently, three tests and two mini-game maps are provided. 25 | Map 1: FindAndDefeatDrones.SC2Map, a single-player game. Please copy this file to the mini_games folder of the StarCraft II installation directory on your computer. 26 | Map 2: FindAndDefeatDronesAA.SC2Map, a two-player game. Please copy this file to the Melee folder of the StarCraft II installation directory on your computer. 27 | 2. Then open the files mini_games.py and melee.py in your pysc2/maps folder and add the names of these two maps there (see the registration sketch near the top of TestScripted_V2.py). The SC2 environment will then be able to load the new games. 28 | 3. You can edit these two maps with the StarCraft II map editor and follow the above steps to run your own games. 29 | 30 | 31 | # Running cases: 32 | 1. python TestScripted_V1.py: 33 | Runs the classical FindAndDefeatZerglings mini-game with fog of war disabled (only so that readers/users gain a better understanding of what is going on). The pursuit agent is scripted, with traversal coordinates obtained from an earlier learning run. The mean score achieved is around 40, which already beats a couple of AI agents from other works. 34 | 35 | 2. python TestScripted_V2.py: 36 | Runs the new FindAndDefeatDrones mini-game, to verify that the game/unit set-ups are correct. 37 | 38 | 3. python exec_2agents.py: 39 | Runs the mini-game that supports two agents/interfaces. 40 | 41 | # Code issues: 42 | It is well known in the StarCraft programming community that the current PySC2 interface can produce websocket errors during the low-level message passing between multiple agent interfaces. A thorough debugging effort was therefore conducted in this work to identify the corresponding code, and a temporary fix has been adopted to work around the issue until an official fix is available from DeepMind. 43 | 44 | The details are: 45 | 1. An episode is 2880 game steps, but it ends incorrectly when the multi-interface mode is chosen. Hence, change 2880 to 2872 in the code to bypass the issue. 46 | 2. In agent.py, include the following fix: 47 | try: 48 | observation = deepcopy(env.reset()) 49 | except protocol.ConnectionError: 50 | # pdb.set_trace() 51 | env.close() 52 | observation = deepcopy(env.start()) 53 | 3.
In env.py, include the following fix: 54 | try: 55 | observation = self.env.step(actions) 56 | except protocol.ConnectionError: 57 | #pdb.set_trace() 58 | self.close() 59 | #self.start() 60 | observation = self.start() 61 | 62 | 63 | -------------------------------------------------------------------------------- /prioReplayBuffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree 4 | 5 | 6 | class ReplayBuffer(object): 7 | def __init__(self, size): 8 | """Create Replay buffer. 9 | 10 | Parameters 11 | ---------- 12 | size: int 13 | Max number of transitions to store in the buffer. When the buffer 14 | overflows the old memories are dropped. 15 | """ 16 | self._storage = [] 17 | self._maxsize = size 18 | self._next_idx = 0 19 | 20 | def __len__(self): 21 | return len(self._storage) 22 | 23 | def add(self, obs_t, action, reward, obs_tp1, done): 24 | data = (obs_t, action, reward, obs_tp1, done) 25 | 26 | if self._next_idx >= len(self._storage): 27 | self._storage.append(data) 28 | else: 29 | self._storage[self._next_idx] = data 30 | self._next_idx = (self._next_idx + 1) % self._maxsize 31 | 32 | # modded: not converting anything to a numpy array before appending to output-arrays 33 | def _encode_sample(self, idxes): 34 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 35 | for i in idxes: 36 | data = self._storage[i] 37 | obs_t, action, reward, obs_tp1, done = data 38 | obses_t.append(obs_t) 39 | actions.append(action) 40 | rewards.append(reward) 41 | obses_tp1.append(obs_tp1) 42 | dones.append(done) 43 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 44 | 45 | def sample(self, batch_size): 46 | """Sample a batch of experiences. 47 | 48 | Parameters 49 | ---------- 50 | batch_size: int 51 | How many transitions to sample. 52 | 53 | Returns 54 | ------- 55 | obs_batch: np.array 56 | batch of observations 57 | act_batch: np.array 58 | batch of actions executed given obs_batch 59 | rew_batch: np.array 60 | rewards received as results of executing act_batch 61 | next_obs_batch: np.array 62 | next set of observations seen after executing act_batch 63 | done_mask: np.array 64 | done_mask[i] = 1 if executing act_batch[i] resulted in 65 | the end of an episode and 0 otherwise. 66 | """ 67 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 68 | return self._encode_sample(idxes) 69 | 70 | 71 | class PrioritizedReplayBuffer(ReplayBuffer): 72 | def __init__(self, size, alpha): 73 | """Create Prioritized Replay buffer. 74 | 75 | Parameters 76 | ---------- 77 | size: int 78 | Max number of transitions to store in the buffer. When the buffer 79 | overflows the old memories are dropped. 
80 | alpha: float 81 | how much prioritization is used 82 | (0 - no prioritization, 1 - full prioritization) 83 | 84 | See Also 85 | -------- 86 | ReplayBuffer.__init__ 87 | """ 88 | super(PrioritizedReplayBuffer, self).__init__(size) 89 | assert alpha >= 0 90 | self._alpha = alpha 91 | 92 | it_capacity = 1 93 | while it_capacity < size: 94 | it_capacity *= 2 95 | 96 | self._it_sum = SumSegmentTree(it_capacity) 97 | self._it_min = MinSegmentTree(it_capacity) 98 | self._max_priority = 1.0 99 | 100 | def add(self, *args, **kwargs): 101 | """See ReplayBuffer.store_effect""" 102 | idx = self._next_idx 103 | super().add(*args, **kwargs) 104 | self._it_sum[idx] = self._max_priority ** self._alpha 105 | self._it_min[idx] = self._max_priority ** self._alpha 106 | 107 | def _sample_proportional(self, batch_size): 108 | res = [] 109 | p_total = self._it_sum.sum(0, len(self._storage) - 1) 110 | every_range_len = p_total / batch_size 111 | for i in range(batch_size): 112 | mass = random.random() * every_range_len + i * every_range_len 113 | idx = self._it_sum.find_prefixsum_idx(mass) 114 | res.append(idx) 115 | return res 116 | 117 | def sample(self, batch_size, beta): 118 | """Sample a batch of experiences. 119 | 120 | compared to ReplayBuffer.sample 121 | it also returns importance weights and idxes 122 | of sampled experiences. 123 | 124 | 125 | Parameters 126 | ---------- 127 | batch_size: int 128 | How many transitions to sample. 129 | beta: float 130 | To what degree to use importance weights 131 | (0 - no corrections, 1 - full correction) 132 | 133 | Returns 134 | ------- 135 | obs_batch: np.array 136 | batch of observations 137 | act_batch: np.array 138 | batch of actions executed given obs_batch 139 | rew_batch: np.array 140 | rewards received as results of executing act_batch 141 | next_obs_batch: np.array 142 | next set of observations seen after executing act_batch 143 | done_mask: np.array 144 | done_mask[i] = 1 if executing act_batch[i] resulted in 145 | the end of an episode and 0 otherwise. 146 | weights: np.array 147 | Array of shape (batch_size,) and dtype np.float32 148 | denoting importance weight of each sampled transition 149 | idxes: np.array 150 | Array of shape (batch_size,) and dtype np.int32 151 | idexes in buffer of sampled experiences 152 | """ 153 | assert beta > 0 154 | 155 | idxes = self._sample_proportional(batch_size) 156 | 157 | weights = [] 158 | p_min = self._it_min.min() / self._it_sum.sum() 159 | max_weight = (p_min * len(self._storage)) ** (-beta) 160 | 161 | for idx in idxes: 162 | p_sample = self._it_sum[idx] / self._it_sum.sum() 163 | weight = (p_sample * len(self._storage)) ** (-beta) 164 | weights.append(weight / max_weight) 165 | weights = np.array(weights) 166 | encoded_sample = self._encode_sample(idxes) 167 | return tuple(list(encoded_sample) + [weights, idxes]) 168 | 169 | def update_priorities(self, idxes, priorities): 170 | """Update priorities of sampled transitions. 171 | 172 | sets priority of transition at index idxes[i] in buffer 173 | to priorities[i]. 174 | 175 | Parameters 176 | ---------- 177 | idxes: [int] 178 | List of idxes of sampled transitions 179 | priorities: [float] 180 | List of updated priorities corresponding to 181 | transitions at the sampled idxes denoted by 182 | variable `idxes`. 
183 | """ 184 | assert len(idxes) == len(priorities) 185 | for idx, priority in zip(idxes, priorities): 186 | if priority == 0: 187 | priority = .00001 188 | assert priority > 0 189 | assert 0 <= idx < len(self._storage) 190 | self._it_sum[idx] = priority ** self._alpha 191 | self._it_min[idx] = priority ** self._alpha 192 | 193 | self._max_priority = max(self._max_priority, priority) 194 | -------------------------------------------------------------------------------- /exec_2agents.py: -------------------------------------------------------------------------------- 1 | # Simlified DQN for pysc2 mini-game, by Xun, Oct 2020. 2 | # 1 vs 1. Both are controlled by agents. 3 | # 4 | 5 | import importlib 6 | 7 | import numpy 8 | import traceback 9 | import os 10 | import json 11 | import random 12 | from absl import app 13 | from absl import flags 14 | import pdb 15 | 16 | # own classes 17 | from env import SC2_Env_xun2 #Sc2Env2Outputs #Sc2Env1Output, SC2_Env_xun2 18 | from sc2Processor import Sc2Processor 19 | from sc2Policy import Sc2Policy #, Sc2PolicyD 20 | from sc2DqnAgent import Sc2DqnAgent_v5 21 | 22 | from prioReplayBuffer import PrioritizedReplayBuffer, ReplayBuffer #remove by xun 23 | 24 | # framework classes 25 | from pysc2.env import sc2_env 26 | # By xun, to use Plasdml and AMD GPU. 27 | #import plaidml 28 | # plaidml.keras.install_backend() 29 | #import os 30 | #os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" 31 | 32 | import keras.backend as K 33 | from keras.models import Sequential, Model 34 | from keras.layers import Dense, Activation, Flatten, Convolution2D, Permute, Input, Conv2D 35 | from keras.optimizers import Adam 36 | from keras.layers import MaxPooling2D, Conv2DTranspose 37 | from keras.layers.merge import concatenate, add 38 | 39 | from rl.agents.dqn import DQNAgent 40 | from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy 41 | from rl.memory import SequentialMemory 42 | from rl.callbacks import FileLogger, ModelIntervalCheckpoint 43 | 44 | # for the debug of all process costs. 45 | from pysc2.lib import stopwatch 46 | 47 | 48 | # To display the relations between functions, xun 2021. 49 | #from pycallgraph import PyCallGraph 50 | #from pycallgraph.output import GraphvizOutput 51 | 52 | # MiniGame 1: MoveToBeacon 53 | # MiniGame 2: CollectMineralShards 54 | _ENV_NAME = "FindAndDefeatDronesAA" #"PursuitEvasion1" #"FindAndDefeatZerglings" 55 | _SCREEN = 32 56 | _MINIMAP = 16 57 | 58 | _VISUALIZE = False# Simlified this code by Xun, Oct 2020. 
59 | _TEST = False 60 | 61 | _profile=True 62 | 63 | 64 | FLAGS = flags.FLAGS 65 | flags.DEFINE_string("agent", "scripted_agent1.FindAndDefeatZergling_4", 66 | "Which agent to run, as a python path to an Agent class.") 67 | flags.DEFINE_string("agent2", "pysc2.agents.random_agent.RandomAgent", 68 | "Which agent to run, as a python path to an Agent class.") 69 | 70 | 71 | def __main__(unused_argv): 72 | agent_name = "Xun_test" 73 | run_number = 1 74 | 75 | 76 | # graphviz = GraphvizOutput() 77 | # graphviz.output_file = 'basic.png' 78 | 79 | 80 | results_dir = "weights/{}/{}/{}".format(_ENV_NAME, agent_name, run_number) 81 | 82 | agent_classes = [] 83 | 84 | agent_module, agent_name = FLAGS.agent.rsplit(".", 1) 85 | agent_cls = getattr(importlib.import_module(agent_module), agent_name) 86 | agent_classes.append(agent_cls) 87 | agent_module, agent_name = FLAGS.agent2.rsplit(".", 1) 88 | agent_cls = getattr(importlib.import_module(agent_module), agent_name) 89 | agent_classes.append(agent_cls) 90 | agents = [agent_cls() for agent_cls in agent_classes] 91 | 92 | #pdb.set_trace() 93 | 94 | 95 | # with PyCallGraph(output=graphviz): 96 | fully_conf_v_10(results_dir, agents) 97 | 98 | 99 | 100 | 101 | # Prepare the network 102 | def fully_conf_v_10(a_dir, agents): 103 | try: 104 | seed = random.randint(1, 324234) 105 | 106 | env = SC2_Env_xun2(screen=_SCREEN, visualize=_VISUALIZE, env_name=_ENV_NAME, 107 | training=not _TEST, agents = agents) 108 | 109 | env.seed(seed) 110 | numpy.random.seed(seed) 111 | 112 | nb_actions = 3 113 | 114 | prio_replay = True #False #True # modified by xun to avoid using new lib 115 | multi_step_size = 3 116 | 117 | # HyperParameter 118 | action_repetition = 1 119 | gamma = .99 120 | memory_size = 200000 121 | learning_rate = .0001 122 | warm_up_steps = 4000 123 | train_interval = 4 124 | 125 | bad_prio_replay = False #True # modified by xun to avoid using new lib 126 | prio_replay_alpha = 0.6 127 | prio_replay_beta = (0.5, 1.0, 200000) # (beta_start, beta_end, 128 | 129 | eps_start = 1 # modified by xun from 1 to the current value 130 | eps_end = 0 131 | eps_steps = 4000 132 | 133 | #Prepare the directory 134 | directory = a_dir 135 | if not os.path.exists(directory): 136 | os.makedirs(directory) 137 | 138 | weights_filename = directory + '/dqn_weights.h5f' 139 | checkpoint_weights_filename = directory + '/dqn_weights_{step}.h5f' 140 | log_filename = directory + '/dqn_log.json' 141 | log_filename_gpu = directory + '/dqn_log_gpu.json' 142 | log_interval = 8000 143 | 144 | #Prepare the network 145 | kernel_size=7 146 | n_filters=16 147 | main_input = Input(shape=(3, env.screen, env.screen), name='main_input') 148 | permuted_input = Permute((2, 3, 1))(main_input) 149 | 150 | # Normal deep network 151 | x = Conv2D(16, (5, 5), padding='same', activation='relu')(permuted_input) 152 | branch = Conv2D(32, (3, 3), padding='same', activation='relu')(x) 153 | coord_out = Conv2D(1, (1, 1), padding='same', activation='linear')(branch) 154 | act_out = Flatten()(branch) 155 | #act_out = Dense(256, activation='relu')(act_out) 156 | act_out = Dense(256, activation='relu')(act_out) 157 | act_out = Dense(nb_actions, activation='linear')(act_out) 158 | 159 | 160 | full_conv_sc2 = Model(main_input, [act_out, coord_out]) 161 | 162 | memory = PrioritizedReplayBuffer(memory_size, prio_replay_alpha) 163 | policy = LinearAnnealedPolicy(Sc2Policy(env=env,nb_actions=nb_actions), attr='eps', value_max=eps_start, value_min=eps_end, 164 | value_test=eps_end, nb_steps=eps_steps) 165 | 166 | 
test_policy = Sc2Policy(env=env, eps=eps_end) 167 | processor = Sc2Processor(screen=env._SCREEN) 168 | #pdb.set_trace() 169 | dqn = Sc2DqnAgent_v5(model=full_conv_sc2, nb_actions=nb_actions, screen_size=env._SCREEN, 170 | memory=memory, processor=processor,gamma=gamma, 171 | nb_steps_warmup=warm_up_steps,multi_step_size=multi_step_size, 172 | policy=policy, test_policy=test_policy, target_model_update=10000, 173 | train_interval=train_interval, delta_clip=1.) 174 | dqn.compile(Adam(lr=learning_rate), metrics=['mae']) 175 | 176 | # if _profile: 177 | # stopwatch.sw.enable() 178 | dqn.fit(env, nb_steps=300000, nb_max_start_steps=0, 179 | action_repetition=action_repetition) 180 | 181 | dqn.save_weights(weights_filename, overwrite=True) 182 | 183 | except KeyboardInterrupt: 184 | exit(0) 185 | pass 186 | 187 | except Exception as e: 188 | print(e) 189 | traceback.print_exc() 190 | pass 191 | 192 | 193 | 194 | def conv2d_block(input_tensor, n_filters, kernel_size=7): 195 | x = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), kernel_initializer="he_normal", 196 | padding="same")(input_tensor) 197 | x = Activation("relu")(x) 198 | 199 | x = Conv2D(filters=n_filters*2, kernel_size=(kernel_size-2, kernel_size-2), kernel_initializer="he_normal", 200 | padding="same")(x) 201 | x = Activation("relu")(x) 202 | return x 203 | 204 | 205 | 206 | if __name__ == '__main__': 207 | app.run(__main__) 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | #---------------------------- 2 | # The environment for 1 vs 1 with 2 agents, one is scripted and the other is from DQN. 3 | # 4 | # This code is only for research purpose. 5 | # This code is modified from keras-rl and dqn-pysc2. The class action_to_sc2 requires further modifications 6 | # to improve the DQN performance. 7 | # 8 | # Xun Huang, Jul 29, 2021 9 | #---------------------------- 10 | 11 | from rl.core import Env 12 | #from pysc2.env import sc2_env 13 | import sc2_env_xun 14 | from pysc2.lib import features 15 | from pysc2.lib import actions 16 | import numpy as np 17 | import pdb 18 | 19 | from pysc2.lib import protocol # Xun, for protocol.ConnectionError 20 | 21 | FUNCTIONS = actions.FUNCTIONS 22 | 23 | 24 | class SC2_Env_xun2(Env): 25 | last_obs = None # observation, used for deep learning 26 | last_obs1= None # observation, used for deep learning 27 | env = None # used by close subroutine 28 | agents=[] 29 | _SCREEN = None 30 | _MINIMAP = None 31 | _ENV_NAME = None 32 | _TRAINING = None 33 | 34 | def __init__(self, screen=16, visualize=False, env_name="MoveToBeacon", training=False, agents=[]): 35 | print("init SC2") 36 | 37 | self._SCREEN = screen 38 | self._MINIMAP = screen 39 | self._VISUALIZE = visualize 40 | self._ENV_NAME = env_name 41 | self._TRAINING = training 42 | self.env = sc2_env_xun.SC2Env_xun( 43 | map_name=self._ENV_NAME, 44 | players=[sc2_env_xun.Agent(sc2_env_xun.Race.protoss),sc2_env_xun.Agent(sc2_env_xun.Race.zerg)], 45 | agent_interface_format=features.AgentInterfaceFormat( 46 | feature_dimensions=features.Dimensions( 47 | screen=self._SCREEN, 48 | minimap=self._MINIMAP 49 | ), 50 | use_feature_units=True 51 | ), 52 | step_mul=8, 53 | game_steps_per_episode=2880, 54 | visualize=self._VISUALIZE 55 | ) 56 | 57 | #pdb.set_trace() 58 | self.agents=agents 59 | 60 | 61 | # Actions assigned here. 
Xun, 2021 62 | def action_to_sc2(self, act): 63 | #pdb.set_trace() 64 | real_action = FUNCTIONS.no_op() 65 | real_action1= FUNCTIONS.no_op() 66 | 67 | if act.action == 1: #Move camera & move or attack 68 | # Must move camera first, then attack. Otherwise, the attacking coordinates would be incorrect. 69 | if FUNCTIONS.move_camera.id in self.last_obs.observation.available_actions: # By Xun 70 | real_action1 = FUNCTIONS.move_camera((act.coords[1], act.coords[0])) 71 | 72 | # if FUNCTIONS.Attack_screen.id in self.last_obs.observation.available_actions: # By Xun 73 | # real_action = FUNCTIONS.Attack_screen("now", (act.coords[1], act.coords[0])) 74 | if FUNCTIONS.Move_screen.id in self.last_obs.observation.available_actions: # By Xun 75 | real_action = FUNCTIONS.Move_screen("now", (act.coords[1], act.coords[0])) 76 | # else: 77 | # print('cannot attack') 78 | real_action=[real_action1,real_action] 79 | 80 | 81 | elif act.action == 2: #Select army or worker 82 | # if FUNCTIONS.select_army.id in self.last_obs.observation.available_actions: 83 | # real_action = FUNCTIONS.select_army("select") 84 | #pdb.set_trace() 85 | if FUNCTIONS.select_idle_worker.id in self.last_obs.observation.available_actions: 86 | real_action = FUNCTIONS.select_idle_worker("select_all") 87 | 88 | player_y, player_x = (self.last_obs.observation.feature_minimap.player_relative == features.PlayerRelative.SELF).nonzero() 89 | 90 | xmean=ymean=0 91 | if(len(player_y)==0): 92 | pass 93 | # pdb.set_trace() 94 | else: 95 | xmean=player_x.mean() 96 | ymean=player_y.mean() 97 | 98 | if FUNCTIONS.move_camera.id in self.last_obs.observation.available_actions: 99 | real_action1 = FUNCTIONS.move_camera((xmean, ymean)) 100 | 101 | real_action=[real_action,real_action1] 102 | 103 | 104 | elif act.action == 0: # 105 | pass # do nothing to continue the former action 106 | #real_action = FUNCTIONS.select_point("toggle", (act.coords[1], act.coords[0])) 107 | else: 108 | # pass 109 | assert False 110 | return real_action 111 | 112 | 113 | 114 | # User should edit this subroutine to obtain the requred observation and rewards. xun, Aug 2021 115 | def step(self, action): 116 | 117 | observation_spec = self.env.observation_spec() 118 | action_spec = self.env.action_spec() 119 | for agent, obs_spec, act_spec in zip(self.agents, observation_spec, action_spec): 120 | agent.setup(obs_spec, act_spec) 121 | 122 | real_action = self.action_to_sc2(action) # Action from the DQN 123 | obs=self.last_obs1 124 | 125 | #actions = [agent.step(obs0) for agent, obs0 in zip(self.agents,obs)] # enable this line will restore two two agents: 1 scripted vs 1 random. 126 | agent=self.agents[0] 127 | obs0=obs[0] 128 | action0 = agent.step(obs0) # Action from the agent[0] for the voidray 129 | actions = [action0, real_action] 130 | 131 | #pdb.set_trace() 132 | 133 | # fix from websocket timeout issue... Xun, Aug 2021 134 | try: 135 | observation = self.env.step(actions) 136 | except protocol.ConnectionError: 137 | #pdb.set_trace() 138 | self.close() 139 | #self.start() 140 | observation = self.start() 141 | 142 | 143 | # Observation[0:1], 0 for agent 0, 1 for agent 1, the latter is for evasion part here. 
xun 144 | self.last_obs = observation[1] 145 | self.last_obs1 = observation 146 | # small_observation = observation[0].observation.feature_screen.unit_density 147 | act_obs=np.zeros([32,32]) 148 | #pdb.set_trace() 149 | act_obs[(action.coords[1],action.coords[0])]=action.action 150 | small_observation = [observation[0].observation.feature_screen.player_relative, 151 | observation[0].observation.feature_screen.selected, 152 | act_obs] 153 | #observation[0].observation.feature_screen.visibility_map] 154 | #pdb.set_trace() 155 | # Modified by Xun 156 | #return small_observation, observation[0].reward, observation[0].last(), {} 157 | reward=observation[1][3].score_by_category[0][2] # the unit number of the evasion part, xun 158 | return small_observation, reward, observation[0].last(), {} 159 | 160 | # fix from websocket timeout issue... Xun, Aug 2021 161 | def start(self): 162 | self.env = sc2_env_xun.SC2Env_xun( 163 | map_name=self._ENV_NAME, 164 | players=[sc2_env_xun.Agent(sc2_env_xun.Race.protoss),sc2_env_xun.Agent(sc2_env_xun.Race.zerg)], 165 | agent_interface_format=features.AgentInterfaceFormat( 166 | feature_dimensions=features.Dimensions( 167 | screen=self._SCREEN, 168 | minimap=self._MINIMAP 169 | ), 170 | use_feature_units=True 171 | ), 172 | step_mul=8, 173 | game_steps_per_episode=2880, 174 | visualize=self._VISUALIZE 175 | ) 176 | return self.reset() 177 | 178 | 179 | def reset(self): 180 | #pdb.set_trace() 181 | observation = self.env.reset() 182 | 183 | if self._TRAINING and np.random.random_integers(0, 1) == 4: 184 | ys, xs = np.where(observation[0].observation.feature_screen.player_relative == 1) 185 | observation = self.env.step(actions=(FUNCTIONS.select_point("toggle", (xs[0], ys[0])),)) 186 | 187 | # observation = self.env.step(actions=(FUNCTIONS.select_army(0),)) 188 | # Select all for both sides, xun 189 | #pdb.set_trace() 190 | observation = self.env.step(actions=(FUNCTIONS.select_army(0),)) 191 | #observation = self.env.step(actions=(FUNCTIONS.select_army(0),FUNCTIONS.select_army(0))) # modified by xun 192 | 193 | self.last_obs = observation[1] 194 | self.last_obs1= observation 195 | # small_observation = [observation[0].observation.feature_screen.player_relative, 196 | # observation[0].observation.feature_screen.selected] 197 | small_observation = [observation[0].observation.feature_screen.player_relative, 198 | observation[0].observation.feature_screen.selected, 199 | observation[0].observation.feature_screen.visibility_map] 200 | 201 | 202 | return small_observation 203 | 204 | def render(self, mode: str = 'human', close: bool = False): 205 | pass 206 | 207 | 208 | def close(self): 209 | if self.env: 210 | self.env.close() 211 | 212 | def seed(self, seed=None): 213 | if seed: 214 | self.env._random_seed = seed 215 | 216 | def set_env_name(self, name: str): 217 | self._ENV_NAME = name 218 | 219 | def set_screen(self, screen: int): 220 | self._SCREEN = screen 221 | 222 | def set_visualize(self, visualize: bool): 223 | self._VISUALIZE = visualize 224 | 225 | def set_minimap(self, minimap: int): 226 | self._MINIMAP = minimap 227 | 228 | @property 229 | def screen(self): 230 | return self._SCREEN 231 | 232 | 233 | """ 234 | def configure(self, *args, **kwargs): 235 | 236 | switcher = { 237 | '_ENV_NAME': self.set_env_name, 238 | '_SCREEN': self.set_screen, 239 | '_MINIMAP': self.set_minimap, 240 | '_VISUALIZE': self.set_visualize, 241 | } 242 | 243 | if kwargs is not None: 244 | for key, value in kwargs: 245 | func = switcher.get(key, lambda: print) 246 | func(value) 
247 | """ 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | #---------------------------- 2 | # The DQN agent, simplified and modified for 1 vs 1 case. 3 | # 4 | # This code is modified from keras-rl and dqn-pysc2. 5 | # Most, if not all, modifications were explicitly pointed out in the code. 6 | # 7 | # Xun Huang, Jul 29, 2021 8 | #---------------------------- 9 | 10 | import warnings 11 | from copy import deepcopy 12 | import numpy as np 13 | import pdb 14 | 15 | from keras.callbacks import History 16 | from rl.callbacks import ( 17 | CallbackList, 18 | TestLogger, 19 | TrainEpisodeLogger, 20 | TrainIntervalLogger, 21 | Visualizer 22 | ) 23 | 24 | 25 | class Agent3(object): 26 | """Modified Version Keras-rl core/Agent 27 | """ 28 | def __init__(self, processor=None): 29 | self.processor = processor 30 | self.training = False 31 | self.step = 0 32 | 33 | def get_config(self): 34 | """Configuration of the agent for serialization. 35 | """ 36 | return {} 37 | 38 | def fit(self, env, nb_steps, action_repetition=1, callbacks=None, verbose=1, 39 | visualize=False, nb_max_start_steps=0, start_step_policy=None, log_interval=10000, 40 | nb_max_episode_steps=None): 41 | """Trains the agent on the given environment. 42 | """ 43 | self.training = True 44 | 45 | callbacks = [] if not callbacks else callbacks[:] 46 | 47 | history = History() 48 | callbacks += [history] 49 | callbacks = CallbackList(callbacks) 50 | if hasattr(callbacks, 'set_model'): 51 | callbacks.set_model(self) 52 | else: 53 | callbacks._set_model(self) 54 | callbacks._set_env(env) 55 | 56 | params = { 57 | 'nb_steps': nb_steps, 58 | } 59 | 60 | 61 | if hasattr(callbacks, 'set_params'): 62 | callbacks.set_params(params) 63 | else: 64 | callbacks._set_params(params) 65 | 66 | callbacks.on_train_begin() 67 | 68 | episode = np.int16(0) 69 | self.step = np.int16(0) 70 | observation = None 71 | episode_reward = None 72 | episode_step = None 73 | did_abort = False 74 | try: 75 | while self.step < nb_steps: 76 | #pdb.set_trace() 77 | if observation is None: # start of a new episode 78 | # print('if observation is None ...') # debuging code by xunger 79 | callbacks.on_episode_begin(episode) 80 | episode_step = np.int16(0) 81 | episode_reward = np.float32(0) 82 | 83 | # Obtain the initial observation by resetting the environment. 84 | #self.reset_states() 85 | #pdb.set_trace() 86 | try: 87 | observation = deepcopy(env.reset()) 88 | except protocol.ConnectionError: 89 | # pdb.set_trace() 90 | env.close() 91 | observation = deepcopy(env.start()) 92 | 93 | 94 | if self.processor is not None: 95 | observation = self.processor.process_observation(observation) 96 | assert observation is not None 97 | 98 | 99 | # At this point, we expect to be fully initialized. 100 | assert episode_reward is not None 101 | assert episode_step is not None 102 | assert observation is not None 103 | 104 | #pdb.set_trace() 105 | # Run a single step. 106 | #callbacks.on_step_begin(episode_step) 107 | # ******************************************************************************** 108 | # !!!!! 109 | # !!!!! 110 | # !!!!! 111 | # This is where all of the work happens. We first perceive and compute the action 112 | # (first step) and then use the reward to improve (backward step). 113 | # !!!!! 114 | # !!!!! 115 | # !!!!! 
116 | # ******************************************************************************** 117 | # if self.step%5000==0: 118 | # pdb.set_trace() 119 | action = self.forward(observation) 120 | 121 | 122 | reward = np.float32(0) 123 | accumulated_info = {} 124 | done = False 125 | for _ in range(action_repetition): 126 | # print('agent step 2') 127 | # callbacks.on_action_begin(action) 128 | observation, r, done, info = env.step(action) 129 | observation = deepcopy(observation) 130 | 131 | #pdb.set_trace() #rewards_history.append(reward) 132 | reward += r 133 | if done: 134 | break 135 | 136 | if nb_max_episode_steps and episode_step >= nb_max_episode_steps - 1: 137 | # Force a terminal state. 138 | done = True 139 | 140 | metrics = self.backward(reward, terminal=done, observation_1=observation) 141 | episode_reward += reward 142 | 143 | episode_step += 1 144 | self.step += 1 145 | 146 | if done: 147 | # We are in a terminal state but the agent hasn't yet seen it. We therefore 148 | # perform one more q-backward call and simply ignore the action before 149 | # resetting the environment. We need to pass in `terminal=False` here since 150 | # the *next* state, that is the state of the newly reset environment, is 151 | # always non-terminal by convention. 152 | # BUT: I disagree, and don't wanna damage my backward call, cause memory is different 153 | # anyways now.... 154 | 155 | # Note: this part is different from the keras-rl.core. Xun 156 | 157 | 158 | self.forward(observation) 159 | 160 | self.backward(0., terminal=True, observation_1=observation) 161 | #pdb.set_trace() 162 | 163 | # This episode is finished, report and reset. 164 | episode_logs = { 165 | 'episode_reward': episode_reward, 166 | 'nb_episode_steps': episode_step, 167 | 'nb_steps': self.step, 168 | } 169 | callbacks.on_episode_end(episode, episode_logs) 170 | 171 | episode += 1 172 | observation = None 173 | 174 | for _ in range(self.recent.maxlen): 175 | self.recent.append(None) 176 | episode_step = None 177 | episode_reward = None 178 | except KeyboardInterrupt: 179 | # We catch keyboard interrupts here so that training can be be safely abortedself. 180 | did_abort = True 181 | callbacks.on_train_end(logs={'did_abort': did_abort}) 182 | self._on_train_end() 183 | 184 | return history 185 | 186 | 187 | 188 | 189 | def reset_states(self): 190 | """Resets all internally kept states after an episode is completed. 191 | """ 192 | pass 193 | 194 | def forward(self, observation): 195 | """Takes the an observation from the environment and returns the action to be taken next. 196 | If the policy is implemented by a neural network, this corresponds to a forward (inference) pass. 197 | 198 | # Argument 199 | observation (object): The current observation from the environment. 200 | 201 | # Returns 202 | The next action to be executed in the environment. 203 | """ 204 | raise NotImplementedError() 205 | 206 | # observation_1 was included here, different from the keras-rl.core code!!! xun 207 | def backward(self, reward, terminal, observation_1): 208 | """Updates the agent after having executed the action returned by `forward`. 209 | If the policy is implemented by a neural network, this corresponds to a weight update using back-prop. 210 | 211 | # Argument 212 | reward (float): The observed reward after executing the action returned by `forward`. 213 | terminal (boolean): `True` if the new state of the environment is terminal. 
214 | 215 | # Returns 216 | List of metrics values 217 | """ 218 | raise NotImplementedError() 219 | 220 | def compile(self, optimizer, metrics=[]): 221 | """Compiles an agent and the underlaying models to be used for training and testing. 222 | 223 | # Arguments 224 | optimizer (`keras.optimizers.Optimizer` instance): The optimizer to be used during training. 225 | metrics (list of functions `lambda y_true, y_pred: metric`): The metrics to run during training. 226 | """ 227 | raise NotImplementedError() 228 | 229 | def load_weights(self, filepath): 230 | """Loads the weights of an agent from an HDF5 file. 231 | 232 | # Arguments 233 | filepath (str): The path to the HDF5 file. 234 | """ 235 | raise NotImplementedError() 236 | 237 | def save_weights(self, filepath, overwrite=False): 238 | """Saves the weights of an agent as an HDF5 file. 239 | 240 | # Arguments 241 | filepath (str): The path to where the weights should be saved. 242 | overwrite (boolean): If `False` and `filepath` already exists, raises an error. 243 | """ 244 | raise NotImplementedError() 245 | 246 | @property 247 | def layers(self): 248 | """Returns all layers of the underlying model(s). 249 | 250 | If the concrete implementation uses multiple internal models, 251 | this method returns them in a concatenated list. 252 | 253 | # Returns 254 | A list of the model's layers 255 | """ 256 | raise NotImplementedError() 257 | 258 | @property 259 | def metrics_names(self): 260 | """The human-readable names of the agent's metrics. Must return as many names as there 261 | are metrics (see also `compile`). 262 | 263 | # Returns 264 | A list of metric's names (string) 265 | """ 266 | return [] 267 | 268 | 269 | def _on_train_end(self): 270 | """Callback that is called after training ends." 271 | """ 272 | pass 273 | 274 | def _on_test_begin(self): 275 | """Callback that is called before testing begins." 276 | """ 277 | pass 278 | 279 | def _on_test_end(self): 280 | """Callback that is called after testing ends." 281 | """ 282 | pass 283 | 284 | 285 | 286 | -------------------------------------------------------------------------------- /scripted_agent1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is a working agent code that can achieve max = 40. xun, June 29, 2021. 
17 | """Scripted agents.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy 24 | 25 | from pysc2.agents import base_agent 26 | from pysc2.lib import actions 27 | from pysc2.lib import features 28 | 29 | _PLAYER_SELF = features.PlayerRelative.SELF 30 | _PLAYER_NEUTRAL = features.PlayerRelative.NEUTRAL # beacon/minerals 31 | _PLAYER_ENEMY = features.PlayerRelative.ENEMY 32 | 33 | FUNCTIONS = actions.FUNCTIONS 34 | RAW_FUNCTIONS = actions.RAW_FUNCTIONS 35 | 36 | import pdb 37 | 38 | 39 | # V1, actually adopt a random agent, for testing purpose only. Xun, Jun 18, 2021 40 | class FindAndDefeatZergling_1(base_agent.BaseAgent): 41 | """A random agent for starcraft.""" 42 | 43 | def step(self, obs): 44 | super(FindAndDefeatZergling_1, self).step(obs) 45 | #pdb.set_trace() 46 | function_id = numpy.random.choice(obs.observation.available_actions) 47 | args = [[numpy.random.randint(0, size) for size in arg.sizes] 48 | for arg in self.action_spec.functions[function_id].args] 49 | return actions.FunctionCall(function_id, args) 50 | 51 | 52 | 53 | 54 | 55 | # V2, scripted attacking. Xun, Jun 18, 2021 56 | # Only goes to (0,0) when no enemy. 57 | class FindAndDefeatZergling_2(base_agent.BaseAgent): 58 | """A random agent for starcraft.""" 59 | 60 | def step(self, obs): 61 | super(FindAndDefeatZergling_2, self).step(obs) 62 | #pdb.set_trace() 63 | coords=(0, 0) 64 | 65 | 66 | if FUNCTIONS.Attack_screen.id in obs.observation.available_actions: 67 | pdb.set_trace() 68 | 69 | player_relative = obs.observation.feature_screen.player_relative 70 | zergling = _xy_locs(player_relative == _PLAYER_ENEMY) 71 | # no visible enemy, explore by attacking (0, 0) 72 | if not zergling: 73 | return FUNCTIONS.Attack_screen("now", (0,0)) #FUNCTIONS.no_op() 74 | # Visible enemy, find the zergling with max y coord. 75 | target = zergling[numpy.argmax(numpy.array(zergling)[:, 1])] 76 | #coords= (target[1], target[0]) # (x, y) 77 | #target=coords 78 | return FUNCTIONS.Attack_screen("now", target) 79 | 80 | else: 81 | if FUNCTIONS.select_army.id in obs.observation.available_actions: 82 | return FUNCTIONS.select_army("select") 83 | else: 84 | pass 85 | 86 | return FUNCTIONS.no_op() 87 | 88 | 89 | # V3, scripted attacking with improved performance. Xun, Jun 18, 2021 90 | # Here the marines do a large Z route, for defeating zergling 91 | class FindAndDefeatZergling_3(base_agent.BaseAgent): 92 | """An agent is developed from CollectMineralShardsFeatureUnits class. 
93 | """ 94 | id=0 95 | def setup(self, obs_spec, action_spec): 96 | super(FindAndDefeatZergling_3, self).setup(obs_spec, action_spec) 97 | #pdb.set_trace() 98 | if "feature_units" not in obs_spec: 99 | raise Exception("This agent requires the feature_units observation.") 100 | 101 | #self.coords=[(26,25),(32,0),(0,32),(32,32)] 102 | #self.coords=[(5,10),(26,10),(5,25),(26,25)] #self.coords=[(4.9,9.9),(15,9.9),(5,17),(15,17),(5,25),(15,25),(15,10),(26,10),(15,17),(26,17),(15,25),(26,25)] 103 | # Obtain the following coordinates by testing, xun 104 | self.coords=[(4.9,10.5),(26,14.2),(4.9,18.5),(26.1,19.5),(4.9,22.5),(25,25), (26.1,9.9)] 105 | # self.coords=[(4.8,10.5),(26,14.2),(4.9,18.5),(26.1,19.5),(4.8,22.5),(25,25), (26.1,9.9)] 106 | #self.coords=[(5,10),(26,10),(26,19),(5,19),(5,25),(26,25), (5,10)] 107 | #self.coords=[(14,16),(50,16),(50,28),(14,28),(14,40),(50,40), (14,16)] 108 | 109 | 110 | self.xmean_old = 0 111 | self.ymean_old = 0 112 | 113 | def reset(self): 114 | super(FindAndDefeatZergling_3, self).reset() 115 | self._marine_selected = False 116 | self._previous_mineral_xy = [-1, -1] 117 | self.id = 0 # trace the settled points 118 | self.xmean_old = 0 119 | self.ymean_old = 0 120 | 121 | def step(self, obs): 122 | super(FindAndDefeatZergling_3, self).step(obs) 123 | 124 | #pdb.set_trace() 125 | #obs.observation['raw_units'] 126 | real_action = FUNCTIONS.no_op() 127 | real_action1 = FUNCTIONS.no_op() 128 | 129 | player_y, player_x = (obs.observation.feature_minimap.player_relative == features.PlayerRelative.SELF).nonzero() 130 | xmean=player_x.mean() 131 | ymean=player_y.mean() 132 | #real_action1 = FUNCTIONS.move_camera((xmean, ymean)) 133 | 134 | #pdb.set_trace() 135 | real_action1 = FUNCTIONS.move_camera(self.coords[self.id]) 136 | if FUNCTIONS.Attack_screen.id in obs.observation.available_actions: 137 | real_action = FUNCTIONS.Attack_screen("now", self.coords[self.id]) 138 | else: 139 | if FUNCTIONS.Move_screen.id in obs.observation.available_actions: 140 | real_action = FUNCTIONS.Move_screen("now", self.coords[self.id]) 141 | 142 | 143 | # Find the distance to the taeget position 144 | distances = numpy.linalg.norm(numpy.array([xmean,ymean]) - numpy.array(self.coords[self.id])) 145 | # if self.id == 1: 146 | # print('distances: ', distances, ', xmean:', xmean, ', ymean:', ymean, ', coords:', self.coords[self.id]) 147 | #pdb.set_trace() 148 | 149 | if abs(distances) < 3: # 2 is a prescribed value, less than which we can say the target position is arrived 150 | self.id += 1 151 | self.id = self.id%7 # here I only define 4 corners. xun 152 | #pdb.set_trace() 153 | 154 | 155 | real_action = [real_action, real_action1] 156 | return real_action 157 | 158 | 159 | 160 | 161 | 162 | 163 | # V4, scripted attacking with improved performance. Xun, Jun 18, 2021 164 | # Here the void-ray conduct a simple route because its sight range is larger. 165 | class FindAndDefeatZergling_4(base_agent.BaseAgent): 166 | """An agent is developed from CollectMineralShardsFeatureUnits class. 
167 | """ 168 | id=0 169 | def setup(self, obs_spec, action_spec): 170 | super(FindAndDefeatZergling_4, self).setup(obs_spec, action_spec) 171 | #pdb.set_trace() 172 | if "feature_units" not in obs_spec: 173 | raise Exception("This agent requires the feature_units observation.") 174 | 175 | #self.coords=[(26,25),(32,0),(0,32),(32,32)] 176 | #self.coords=[(5,10),(26,10),(5,25),(26,25)] #self.coords=[(4.9,9.9),(15,9.9),(5,17),(15,17),(5,25),(15,25),(15,10),(26,10),(15,17),(26,17),(15,25),(26,25)] 177 | # Obtain the following coordinates by testing, xun 178 | #self.coords=[(4.9,10.5),(26,14.2),(4.9,18.5),(26.1,19.5),(4.9,22.5),(25,25), (26.1,9.9)] 179 | #self.coords=[(5,12),(26,10),(26,19),(5,19),(5,25),(26,25), (5,10)] 180 | self.coords=[(5,12),(25,13),(26,19),(5,17),(5,24),(26,25),(26,19),(5,12)] 181 | #self.coords=[(14,16),(50,16),(50,28),(14,28),(14,40),(50,40), (14,16)] 182 | 183 | self.xmean_old = 0 184 | self.ymean_old = 0 185 | 186 | def reset(self): 187 | super(FindAndDefeatZergling_4, self).reset() 188 | self._marine_selected = False 189 | self._previous_mineral_xy = [-1, -1] 190 | self.id = 0 # trace the settled points 191 | self.xmean_old = 0 192 | self.ymean_old = 0 193 | 194 | def step(self, obs): 195 | super(FindAndDefeatZergling_4, self).step(obs) 196 | 197 | #pdb.set_trace() 198 | #obs.observation['raw_units'] 199 | real_action = FUNCTIONS.no_op() 200 | real_action1 = FUNCTIONS.no_op() 201 | 202 | player_y, player_x = (obs.observation.feature_minimap.player_relative == features.PlayerRelative.SELF).nonzero() 203 | xmean=player_x.mean() 204 | ymean=player_y.mean() 205 | #real_action1 = FUNCTIONS.move_camera((xmean, ymean)) 206 | 207 | #pdb.set_trace() 208 | real_action1 = FUNCTIONS.move_camera(self.coords[self.id]) 209 | if FUNCTIONS.Attack_screen.id in obs.observation.available_actions: 210 | real_action = FUNCTIONS.Attack_screen("now", self.coords[self.id]) 211 | else: 212 | if FUNCTIONS.Move_screen.id in obs.observation.available_actions: 213 | real_action = FUNCTIONS.Move_screen("now", self.coords[self.id]) 214 | 215 | 216 | # Find the distance to the taeget position 217 | distances = numpy.linalg.norm(numpy.array([xmean,ymean]) - numpy.array(self.coords[self.id])) 218 | # if self.id == 1: 219 | # print('distances: ', distances, ', xmean:', xmean, ', ymean:', ymean, ', coords:', self.coords[self.id]) 220 | #pdb.set_trace() 221 | 222 | if abs(distances) < 3: # 2 is a prescribed value, less than which we can say the target position is arrived 223 | self.id += 1 224 | self.id = self.id%8 # here I only define 4 corners. 
xun 225 | #pdb.set_trace() 226 | 227 | 228 | real_action = [real_action, real_action1] 229 | return real_action 230 | 231 | 232 | 233 | 234 | 235 | 236 | def _xy_locs(mask): 237 | """Mask should be a set of bools from comparison with a feature layer.""" 238 | y, x = mask.nonzero() 239 | #pdb.set_trace() 240 | return list(zip(x, y)) 241 | 242 | 243 | class MoveToBeacon(base_agent.BaseAgent): 244 | """An agent specifically for solving the MoveToBeacon map.""" 245 | 246 | def step(self, obs): 247 | super(MoveToBeacon, self).step(obs) 248 | if FUNCTIONS.Move_screen.id in obs.observation.available_actions: 249 | player_relative = obs.observation.feature_screen.player_relative 250 | beacon = _xy_locs(player_relative == _PLAYER_NEUTRAL) 251 | if not beacon: 252 | return FUNCTIONS.no_op() 253 | beacon_center = numpy.mean(beacon, axis=0).round() 254 | return FUNCTIONS.Move_screen("now", beacon_center) 255 | else: 256 | return FUNCTIONS.select_army("select") 257 | 258 | 259 | class CollectMineralShards(base_agent.BaseAgent): 260 | """An agent specifically for solving the CollectMineralShards map.""" 261 | 262 | def step(self, obs): 263 | super(CollectMineralShards, self).step(obs) 264 | if FUNCTIONS.Move_screen.id in obs.observation.available_actions: 265 | player_relative = obs.observation.feature_screen.player_relative 266 | minerals = _xy_locs(player_relative == _PLAYER_NEUTRAL) 267 | if not minerals: 268 | return FUNCTIONS.no_op() 269 | marines = _xy_locs(player_relative == _PLAYER_SELF) 270 | marine_xy = numpy.mean(marines, axis=0).round() # Average location. 271 | distances = numpy.linalg.norm(numpy.array(minerals) - marine_xy, axis=1) 272 | closest_mineral_xy = minerals[numpy.argmin(distances)] 273 | return FUNCTIONS.Move_screen("now", closest_mineral_xy) 274 | else: 275 | return FUNCTIONS.select_army("select") 276 | 277 | 278 | class CollectMineralShardsFeatureUnits(base_agent.BaseAgent): 279 | """An agent for solving the CollectMineralShards map with feature units. 280 | 281 | Controls the two marines independently: 282 | - select marine 283 | - move to nearest mineral shard that wasn't the previous target 284 | - swap marine and repeat 285 | """ 286 | 287 | def setup(self, obs_spec, action_spec): 288 | super(CollectMineralShardsFeatureUnits, self).setup(obs_spec, action_spec) 289 | if "feature_units" not in obs_spec: 290 | raise Exception("This agent requires the feature_units observation.") 291 | 292 | def reset(self): 293 | super(CollectMineralShardsFeatureUnits, self).reset() 294 | self._marine_selected = False 295 | self._previous_mineral_xy = [-1, -1] 296 | 297 | def step(self, obs): 298 | super(CollectMineralShardsFeatureUnits, self).step(obs) 299 | marines = [unit for unit in obs.observation.feature_units 300 | if unit.alliance == _PLAYER_SELF] 301 | if not marines: 302 | return FUNCTIONS.no_op() 303 | marine_unit = next((m for m in marines 304 | if m.is_selected == self._marine_selected), marines[0]) 305 | marine_xy = [marine_unit.x, marine_unit.y] 306 | 307 | if not marine_unit.is_selected: 308 | # Nothing selected or the wrong marine is selected. 309 | self._marine_selected = True 310 | return FUNCTIONS.select_point("select", marine_xy) 311 | 312 | if FUNCTIONS.Move_screen.id in obs.observation.available_actions: 313 | # Find and move to the nearest mineral. 
314 | minerals = [[unit.x, unit.y] for unit in obs.observation.feature_units 315 | if unit.alliance == _PLAYER_NEUTRAL] 316 | 317 | if self._previous_mineral_xy in minerals: 318 | # Don't go for the same mineral shard as other marine. 319 | minerals.remove(self._previous_mineral_xy) 320 | 321 | if minerals: 322 | # Find the closest. 323 | distances = numpy.linalg.norm( 324 | numpy.array(minerals) - numpy.array(marine_xy), axis=1) 325 | closest_mineral_xy = minerals[numpy.argmin(distances)] 326 | 327 | # Swap to the other marine. 328 | self._marine_selected = False 329 | self._previous_mineral_xy = closest_mineral_xy 330 | return FUNCTIONS.Move_screen("now", closest_mineral_xy) 331 | 332 | return FUNCTIONS.no_op() 333 | 334 | 335 | class CollectMineralShardsRaw(base_agent.BaseAgent): 336 | """An agent for solving CollectMineralShards with raw units and actions. 337 | 338 | Controls the two marines independently: 339 | - move to nearest mineral shard that wasn't the previous target 340 | - swap marine and repeat 341 | """ 342 | 343 | def setup(self, obs_spec, action_spec): 344 | super(CollectMineralShardsRaw, self).setup(obs_spec, action_spec) 345 | if "raw_units" not in obs_spec: 346 | raise Exception("This agent requires the raw_units observation.") 347 | 348 | def reset(self): 349 | super(CollectMineralShardsRaw, self).reset() 350 | self._last_marine = None 351 | self._previous_mineral_xy = [-1, -1] 352 | 353 | def step(self, obs): 354 | super(CollectMineralShardsRaw, self).step(obs) 355 | marines = [unit for unit in obs.observation.raw_units 356 | if unit.alliance == _PLAYER_SELF] 357 | if not marines: 358 | return RAW_FUNCTIONS.no_op() 359 | marine_unit = next((m for m in marines if m.tag != self._last_marine)) 360 | marine_xy = [marine_unit.x, marine_unit.y] 361 | 362 | minerals = [[unit.x, unit.y] for unit in obs.observation.raw_units 363 | if unit.alliance == _PLAYER_NEUTRAL] 364 | 365 | if self._previous_mineral_xy in minerals: 366 | # Don't go for the same mineral shard as other marine. 367 | minerals.remove(self._previous_mineral_xy) 368 | 369 | if minerals: 370 | # Find the closest. 371 | distances = numpy.linalg.norm( 372 | numpy.array(minerals) - numpy.array(marine_xy), axis=1) 373 | closest_mineral_xy = minerals[numpy.argmin(distances)] 374 | 375 | self._last_marine = marine_unit.tag 376 | self._previous_mineral_xy = closest_mineral_xy 377 | return RAW_FUNCTIONS.Move_pt("now", marine_unit.tag, closest_mineral_xy) 378 | 379 | return RAW_FUNCTIONS.no_op() 380 | 381 | 382 | class DefeatRoaches(base_agent.BaseAgent): 383 | """An agent specifically for solving the DefeatRoaches map.""" 384 | 385 | def step(self, obs): 386 | super(DefeatRoaches, self).step(obs) 387 | if FUNCTIONS.Attack_screen.id in obs.observation.available_actions: 388 | player_relative = obs.observation.feature_screen.player_relative 389 | roaches = _xy_locs(player_relative == _PLAYER_ENEMY) 390 | if not roaches: 391 | return FUNCTIONS.no_op() 392 | 393 | # Find the roach with max y coord. 
394 | target = roaches[numpy.argmax(numpy.array(roaches)[:, 1])] 395 | return FUNCTIONS.Attack_screen("now", target) 396 | 397 | if FUNCTIONS.select_army.id in obs.observation.available_actions: 398 | return FUNCTIONS.select_army("select") 399 | 400 | return FUNCTIONS.no_op() 401 | 402 | 403 | class DefeatRoachesRaw(base_agent.BaseAgent): 404 | """An agent specifically for solving DefeatRoaches using raw actions.""" 405 | 406 | def setup(self, obs_spec, action_spec): 407 | super(DefeatRoachesRaw, self).setup(obs_spec, action_spec) 408 | if "raw_units" not in obs_spec: 409 | raise Exception("This agent requires the raw_units observation.") 410 | 411 | def step(self, obs): 412 | super(DefeatRoachesRaw, self).step(obs) 413 | marines = [unit.tag for unit in obs.observation.raw_units 414 | if unit.alliance == _PLAYER_SELF] 415 | roaches = [unit for unit in obs.observation.raw_units 416 | if unit.alliance == _PLAYER_ENEMY] 417 | 418 | if marines and roaches: 419 | # Find the roach with max y coord. 420 | target = sorted(roaches, key=lambda r: r.y)[0].tag 421 | return RAW_FUNCTIONS.Attack_unit("now", marines, target) 422 | 423 | return FUNCTIONS.no_op() 424 | -------------------------------------------------------------------------------- /sc2DqnAgent.py: -------------------------------------------------------------------------------- 1 | #---------------------------- 2 | # The DQN agent, simplified and modified for 1 vs 1 case. 3 | # 4 | # This code is only for research purpose. 5 | # 6 | # This code is modified from keras-rl (https://github.com/keras-rl/keras-rl/blob/master/rl/agents/dqn.py) and dqn-pysc2. 7 | # Most, if not all, modifications were explicitly pointed out in the code. 8 | # 9 | # Xun Huang, Jul 29, 2021 10 | #---------------------------- 11 | 12 | 13 | from __future__ import division 14 | import warnings 15 | 16 | # framework imports 17 | from keras.layers import Lambda, Input, Dense, Conv2D, Flatten 18 | from rl.memory import RingBuffer 19 | #from rl.agents.dqn import Agent 20 | from rl.policy import EpsGreedyQPolicy, GreedyQPolicy 21 | from rl.util import * 22 | from baselines.common.schedules import LinearSchedule 23 | from agent import Agent3 24 | import pdb 25 | 26 | 27 | 28 | 29 | class Sc2Action: 30 | # default: noop 31 | def __init__(self, act=0, x=0, y=0): 32 | self.coords = (x, y) 33 | self.action = act 34 | 35 | 36 | 37 | class AbstractSc2DQNAgent3(Agent3): 38 | def __init__(self, nb_actions, screen_size, memory, gamma=.99, batch_size=32, nb_steps_warmup=1000, 39 | train_interval=1, memory_interval=1, target_model_update=10000, screen=32, 40 | delta_range=None, delta_clip=np.inf, custom_model_objects={}, **kwargs): 41 | super(AbstractSc2DQNAgent3, self).__init__(**kwargs) 42 | 43 | # Soft vs hard target model updates. 44 | if target_model_update < 0: 45 | raise ValueError('`target_model_update` must be >= 0.') 46 | elif target_model_update >= 1: 47 | # Hard update every `target_model_update` steps. 48 | target_model_update = int(target_model_update) 49 | else: 50 | # Soft update with `(1 - target_model_update) * old + target_model_update * new`. 51 | target_model_update = float(target_model_update) 52 | 53 | if delta_range is not None: 54 | warnings.warn( 55 | '`delta_range` is deprecated. Please use `delta_clip` instead, which takes a single scalar. For now we\'re falling back to `delta_range[1] = {}`'.format( 56 | delta_range[1])) 57 | delta_clip = delta_range[1] 58 | 59 | # Parameters. 
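        # Standard DQN hyper-parameters stored on the agent: discount factor
        # (gamma), replay batch size, number of warm-up steps before training,
        # train/memory intervals, and the hard-or-soft target-network update
        # rate normalised above. The screen size is kept separately so that
        # state batches can be reshaped without going through a Processor.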
60 | self._SCREEN = screen #included by xun 61 | 62 | self.nb_actions = nb_actions 63 | self.screen_size = screen_size 64 | self.gamma = gamma 65 | self.batch_size = batch_size 66 | self.nb_steps_warmup = nb_steps_warmup 67 | self.train_interval = train_interval 68 | self.memory_interval = memory_interval 69 | self.target_model_update = target_model_update 70 | self.delta_clip = delta_clip 71 | self.custom_model_objects = custom_model_objects 72 | 73 | # Related objects. 74 | self.memory = memory 75 | 76 | # State. 77 | self.compiled = False 78 | 79 | # This code looks ridiculous to me. xun 80 | def process_state_batch(self, batch): 81 | batch = np.array(batch) 82 | if self.processor is None: 83 | return batch 84 | return self.processor.process_state_batch(batch) 85 | 86 | def compute_batch_q_values(self, state_batch): 87 | batch = self.process_state_batch(state_batch) 88 | # print('debug step sc2agent 1') 89 | #pdb.set_trace() 90 | q_values = self.model.predict_on_batch(batch) 91 | # assert q_values.shape == (len(state_batch), self.nb_actions) (len(state_batch), 2) 92 | return q_values 93 | 94 | def compute_q_values(self, state): 95 | # pdb.set_trace() 96 | # q_values = self.compute_batch_q_values([state]) 97 | #Modify by Xun to avoid unnecessary function calls 98 | batch0=[state] 99 | batch0 = np.array(batch0) 100 | size_first_dim = len(batch0) 101 | size_second_dim = len(batch0[0,0]) 102 | batch=np.reshape(batch0, (size_first_dim, size_second_dim, self._SCREEN, self._SCREEN)) 103 | #pdb.set_trace() 104 | q_values = self.model.predict_on_batch(batch) 105 | return q_values 106 | 107 | 108 | 109 | def get_config(self): 110 | return { 111 | 'nb_actions': self.nb_actions, 112 | 'screen_size': self.screen_size, 113 | 'gamma': self.gamma, 114 | 'batch_size': self.batch_size, 115 | 'nb_steps_warmup': self.nb_steps_warmup, 116 | 'train_interval': self.train_interval, 117 | 'memory_interval': self.memory_interval, 118 | 'target_model_update': self.target_model_update, 119 | 'delta_clip': self.delta_clip, 120 | 'memory': get_object_config(self.memory), 121 | } 122 | 123 | 124 | 125 | class Sc2DqnAgent_v5(AbstractSc2DQNAgent3): 126 | def __init__(self, model, policy=None, test_policy=None, 127 | prio_replay=True, prio_replay_beta=(0.5, 1.0, 200000), 128 | bad_prio_replay=False, multi_step_size=3, *args, **kwargs): 129 | super(Sc2DqnAgent_v5, self).__init__(*args, **kwargs) 130 | 131 | # Validate (important) input. Falls man sein Model falsch definiert hat ( ^: 132 | if hasattr(model.output, '__len__') and len(model.output) != 2: 133 | raise ValueError( 134 | 'Model "{}" has more or less than two outputs. DQN expects a model that has exactly 2 outputs.'.format( 135 | model)) 136 | 137 | # Parameters. 138 | self.prio_replay = True #prio_replay Set to true by Xun but don't know why 139 | self.prio_replay_beta = prio_replay_beta 140 | self.bad_prio_replay = bad_prio_replay 141 | self.multi_step_size = multi_step_size 142 | 143 | # Related objects. 144 | self.model = model 145 | assert policy is not None 146 | if test_policy is None: 147 | test_policy = policy 148 | self.policy = policy 149 | self.test_policy = test_policy 150 | 151 | # if self.prio_replay: 152 | assert len(prio_replay_beta) == 3 153 | self.beta_schedule = LinearSchedule(prio_replay_beta[2], 154 | initial_p=prio_replay_beta[0], 155 | final_p=prio_replay_beta[1]) 156 | 157 | self.recent = RingBuffer(maxlen=multi_step_size) 158 | # RingBuffer für Rewards 159 | self.recent_r = RingBuffer(maxlen=multi_step_size) 160 | 161 | # State. 
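        # reset_states() below clears the most recent action/observation and,
        # once the agent has been compiled, also resets the states of the
        # online and target Keras models.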
162 | self.reset_states() 163 | 164 | def get_config(self): 165 | config = super(Sc2DqnAgent_v5, self).get_config() 166 | config['model'] = get_object_config(self.model) 167 | config['policy'] = get_object_config(self.policy) 168 | config['test_policy'] = get_object_config(self.test_policy) 169 | if self.compiled: 170 | config['target_model'] = get_object_config(self.target_model) 171 | return config 172 | 173 | def compile(self, optimizer, metrics=[]): 174 | metrics += [mean_q] # register default metrics 175 | 176 | # We never train the target model, hence we can set the optimizer and loss arbitrarily. 177 | 178 | self.target_model = clone_model(self.model, self.custom_model_objects) 179 | print("custom_model_objects: ", self.custom_model_objects) 180 | self.target_model.compile(optimizer='sgd', loss='mse') 181 | self.model.compile(optimizer='sgd', loss='mse') 182 | 183 | 184 | # Lambda-Layer, welche den Loss des Netzwerks berechnet! 185 | def clipped_masked_error(args): 186 | y_true_a, y_true_b, y_pred_a, y_pred_b, mask_a, mask_b = args 187 | loss = [huber_loss(y_true_a, y_pred_a, self.delta_clip), 188 | huber_loss(y_true_b, y_pred_b, self.delta_clip)] 189 | loss[0] *= mask_a # apply element-wise mask 190 | loss[1] *= mask_b # apply element-wise mask 191 | sum_loss_a = K.sum(loss[0]) 192 | sum_loss_b = K.sum(loss[1]) 193 | return K.sum([sum_loss_a, sum_loss_b], axis=-1) 194 | 195 | 196 | y_pred = self.model.output 197 | 198 | y_true_a = Input(name='y_true_a', shape=(self.nb_actions,)) 199 | y_true_b = Input(name='y_true_b', shape=(self.screen_size, self.screen_size, 1)) 200 | mask_a = Input(name='mask_a', shape=(self.nb_actions,)) 201 | mask_b = Input(name='mask_b', shape=(self.screen_size, self.screen_size, 1)) 202 | 203 | loss_out = Lambda(clipped_masked_error, output_shape=(1,), name='loss')( 204 | [y_true_a, y_true_b, y_pred[0], y_pred[1], mask_a, mask_b]) 205 | ins = [self.model.input] if type(self.model.input) is not list else self.model.input 206 | 207 | 208 | trainable_model = Model(inputs=ins + [y_true_a, y_true_b, mask_a, mask_b], 209 | outputs=[loss_out, y_pred[0], y_pred[1]]) 210 | print(trainable_model.summary()) 211 | 212 | losses = [ 213 | lambda y_true, y_pred: y_pred, # loss is computed in Lambda layer 214 | lambda y_true, y_pred: K.zeros_like(y_pred), # we only include this for the metrics 215 | lambda y_true, y_pred: K.zeros_like(y_pred), # we only include this for the metrics 216 | ] 217 | trainable_model.compile(optimizer=optimizer, loss=losses) # metrics=combined_metrics 218 | self.trainable_model = trainable_model 219 | 220 | self.compiled = True 221 | 222 | def load_weights(self, filepath): 223 | self.model.load_weights(filepath) 224 | self.update_target_model_hard() 225 | 226 | def save_weights(self, filepath, overwrite=False): 227 | self.model.save_weights(filepath, overwrite=overwrite) 228 | 229 | def reset_states(self): 230 | self.recent_action = None 231 | self.recent_observation = None 232 | if self.compiled: 233 | self.model.reset_states() 234 | self.target_model.reset_states() 235 | 236 | def update_target_model_hard(self): 237 | self.target_model.set_weights(self.model.get_weights()) 238 | 239 | def forward(self, observation): 240 | # Select an action. 
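        # Estimate Q-values for the current observation and let the training
        # policy (or the test policy during evaluation) pick the action; the
        # (observation, action) pair is appended to the n-step ring buffer so
        # that backward() can build multi-step returns.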
241 | state = [observation] 242 | # print('debug step sc2agent 0') 243 | #pdb.set_trace() 244 | q_values = self.compute_q_values(state) 245 | 246 | if self.training: 247 | action = self.policy.select_action(q_values=q_values) 248 | else: 249 | action = self.test_policy.select_action(q_values=q_values) 250 | 251 | # Book-keeping. 252 | self.recent.append((observation, action)) 253 | 254 | return action 255 | 256 | 257 | 258 | # Compared to keras-rl.core, here we have one new input, observation_1. xun 259 | def backward(self, reward, terminal, observation_1): 260 | # RingBuffer. 261 | self.recent_r.append(reward) 262 | 263 | # Store most recent experience in memory. (s_t, a_t, r_t1 + gamma*r_t2, s_t2, ter2) 264 | # ??? I don't get the meaning of the following code, differnt from keras-rl code. Xun 265 | if self.step % self.memory_interval == 0: 266 | # some resetting after terminal/done stuff to not save cross episodes. 267 | if self.recent.__len__() == self.recent.maxlen: 268 | if self.recent.__getitem__(0) is not None: 269 | acc_r = 0 270 | for i in range(self.recent_r.maxlen): 271 | acc_r += self.recent_r.__getitem__(i) * (self.gamma ** i) 272 | 273 | rec_0 = self.recent.__getitem__(0) 274 | obs_0 = rec_0[0] 275 | act_0 = rec_0[1] 276 | 277 | self.memory.add(obs_0, act_0, acc_r, observation_1, terminal) 278 | 279 | metrics = [np.nan for _ in self.metrics_names] 280 | if not self.training: 281 | # We're done here. No need to update the experience memory since we only use the working memory to obtain the state over the most recent observations. 282 | return metrics 283 | 284 | # Train the network on a single stochastic batch. 285 | if self.step > self.nb_steps_warmup and self.step % self.train_interval == 0: 286 | 287 | experiences = self.memory.sample(self.batch_size, self.beta_schedule.value(self.step)) 288 | assert len(experiences[0]) == self.batch_size 289 | 290 | # Start by extracting the necessary parameters (we use a vectorized implementation). 291 | state0_batch = [] 292 | action_batch = [] 293 | reward_batch = [] 294 | state2_batch = [] 295 | terminal2_batch = [] 296 | if self.prio_replay: 297 | prio_weights_batch = [] 298 | id_batch = [] 299 | 300 | if self.prio_replay: 301 | experiences = zip(experiences[0], experiences[1], experiences[2], experiences[3], experiences[4], 302 | experiences[5], experiences[6]) 303 | else: 304 | experiences = zip(experiences[0], experiences[1], experiences[2], experiences[3], experiences[4]) 305 | 306 | for e in experiences: 307 | state0_batch.append(e[0]) 308 | action_batch.append(e[1]) 309 | reward_batch.append(e[2]) 310 | state2_batch.append(e[3]) 311 | terminal2_batch.append(0. if e[4] else 1.) 312 | if self.prio_replay: 313 | prio_weights_batch.append(e[5]) 314 | id_batch.append(e[6]) 315 | 316 | # Prepare and validate parameters. 
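            # The targets built below implement the multi-step DQN update: the
            # stored n-step return (rewards were already discounted with gamma
            # when the transition was added to memory) plus
            # gamma**multi_step_size * max_a Q_target(s_{t+n}, a), masked so
            # that only the chosen discrete action and screen coordinate
            # contribute to the loss, weighted by the prioritized-replay
            # importance weights.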
317 | state0_batch = self.process_state_batch(state0_batch) 318 | state2_batch = self.process_state_batch(state2_batch) 319 | terminal2_batch = np.array(terminal2_batch) 320 | reward_batch = np.array(reward_batch) 321 | if self.prio_replay: 322 | prio_weights_batch = np.array(prio_weights_batch) 323 | else: 324 | prio_weights_batch = np.ones(reward_batch.shape) 325 | assert reward_batch.shape == (self.batch_size,) 326 | assert terminal2_batch.shape == reward_batch.shape 327 | assert len(action_batch) == len(reward_batch) 328 | 329 | target_q2_values = self.target_model.predict_on_batch(state2_batch) 330 | q_batch_a = np.max(target_q2_values[0], axis=-1) 331 | q_batch_b = np.max(target_q2_values[1], axis=(1, 2))[:, 0] 332 | q_batch_a = np.array(q_batch_a) 333 | q_batch_b = np.array(q_batch_b) 334 | 335 | targets_a = np.zeros((self.batch_size, self.nb_actions,)) 336 | targets_b = np.zeros((self.batch_size, self.screen_size, self.screen_size, 1)) 337 | 338 | masks_a = np.zeros((self.batch_size, self.nb_actions,)) 339 | masks_b = np.zeros((self.batch_size, self.screen_size, self.screen_size, 1)) 340 | 341 | # Compute r_t+n (included discounting) + gamma^n * max_a Q(s_t+n, a) and update the targets accordingly, 342 | # but only for the affected output units (as given by action_batch). (Called Rs_a and Rs_b) 343 | discounted_reward_batch_a = (self.gamma ** self.multi_step_size) * q_batch_a 344 | discounted_reward_batch_b = (self.gamma ** self.multi_step_size) * q_batch_b 345 | # Set discounted reward to zero for all states that were terminal. 346 | discounted_reward_batch_a = discounted_reward_batch_a * terminal2_batch[:] 347 | discounted_reward_batch_b = discounted_reward_batch_b * terminal2_batch[:] 348 | Rs_a = reward_batch[:] + discounted_reward_batch_a 349 | Rs_b = reward_batch[:] + discounted_reward_batch_b 350 | 351 | for idx, (target_a, target_b, mask_a, mask_b, R_a, R_b, action, prio_weight) in \ 352 | enumerate(zip(targets_a, targets_b, masks_a, masks_b, Rs_a, Rs_b, action_batch, prio_weights_batch)): 353 | target_a[action.action] = R_a # update action with estimated accumulated reward 354 | target_b[action.coords] = R_b # update action with estimated accumulated reward 355 | 356 | mask_a[action.action] = prio_weight # enable loss for this specific action 357 | mask_b[action.coords] = prio_weight # enable loss for this specific action 358 | 359 | targets_a = np.array(targets_a).astype('float32') 360 | targets_b = np.array(targets_b).astype('float32') 361 | masks_a = np.array(masks_a).astype('float32') 362 | masks_b = np.array(masks_b).astype('float32') 363 | 364 | # Finally, perform a single update on the entire batch. We use a dummy target since 365 | # the actual loss is computed in a Lambda layer that needs more complex input. However, 366 | # it is still useful to know the actual target to compute metrics properly. 367 | ins = [state0_batch] if type(self.model.input) is not list else state0_batch 368 | 369 | metrics = self.trainable_model.train_on_batch(ins + [targets_a, targets_b, masks_a, masks_b], 370 | [np.zeros(self.batch_size), targets_a, targets_b]) 371 | 372 | metrics = [metric for idx, metric in enumerate(metrics) if 373 | idx not in (1, 2)] # throw away individual losses 374 | 375 | if self.prio_replay: 376 | pred = self.trainable_model.predict_on_batch(ins + [targets_a, targets_b, masks_a, masks_b]) 377 | 378 | # update priority batch 379 | if self.prio_replay: 380 | prios = [] 381 | 382 | # Richtige Implementierung. 
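            # "Richtige Implementierung" = "correct implementation": the loop
            # below recomputes each sample's TD error from the network
            # predictions, divides the importance weight back out of the
            # masks, and uses the absolute summed error as the new priority
            # for the sampled transitions.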
383 | for (pre_a, pre_b, target_a, target_b, mask_a, mask_b, prio_weight) \ 384 | in zip(pred[1], pred[2], targets_a, targets_b, masks_a, masks_b, prio_weights_batch): 385 | # need to remove prio weight from masks 386 | mask_a = mask_a / prio_weight 387 | mask_b = mask_b / prio_weight 388 | loss = [pre_a - target_a, 389 | pre_b - target_b] 390 | loss[0] *= mask_a # apply element-wise mask 391 | loss[1] *= mask_b # apply element-wise mask 392 | sum_loss_a = np.sum(loss[0]) 393 | sum_loss_b = np.sum(loss[1]) 394 | prios.append(np.abs(np.sum([sum_loss_a, sum_loss_b]))) 395 | 396 | self.memory.update_priorities(id_batch, prios) 397 | 398 | metrics += self.policy.metrics 399 | if self.processor is not None: 400 | metrics += self.processor.metrics 401 | 402 | if self.target_model_update >= 1 and self.step % self.target_model_update == 0: 403 | self.update_target_model_hard() 404 | 405 | return metrics 406 | 407 | @property 408 | def layers(self): 409 | return self.model.layers[:] 410 | 411 | @property 412 | def metrics_names(self): 413 | # Throw away individual losses and replace output name since this is hidden from the user. 414 | assert len(self.trainable_model.output_names) == 3 415 | dummy_output_name = self.trainable_model.output_names[1] 416 | model_metrics = [name for idx, name in enumerate(self.trainable_model.metrics_names) if idx not in (1, 2)] 417 | model_metrics = [name.replace(dummy_output_name + '_', '') for name in model_metrics] 418 | 419 | names = model_metrics + self.policy.metrics_names[:] 420 | if self.processor is not None: 421 | names += self.processor.metrics_names[:] 422 | return names 423 | 424 | @property 425 | def policy(self): 426 | return self.__policy 427 | 428 | @policy.setter 429 | def policy(self, policy): 430 | self.__policy = policy 431 | self.__policy._set_agent(self) 432 | 433 | @property 434 | def test_policy(self): 435 | return self.__test_policy 436 | 437 | @test_policy.setter 438 | def test_policy(self, policy): 439 | self.__test_policy = policy 440 | self.__test_policy._set_agent(self) 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | def mean_q(y_true, y_pred): 449 | mean_a = K.mean(K.max(y_pred[0], axis=-1)) 450 | mean_b = K.mean(K.max(y_pred[1], axis=(1, 2))) 451 | return K.mean(mean_a, mean_b) 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | -------------------------------------------------------------------------------- /sc2_env_xun.py: -------------------------------------------------------------------------------- 1 | #---------------------------- 2 | # The environment for 1 vs 1 with 2 agents, one is scripted and the other is from DQN. 3 | # 4 | # This code is only for research purpose. 5 | # The code is developed based on the sc2_env from pysc2. 6 | # Most, if not all, modifications were explicitly pointed out in the code. 7 | # 8 | # Xun Huang, Jul 29, 2021 9 | #---------------------------- 10 | 11 | # Copyright 2017 Google Inc. All Rights Reserved. 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS-IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 
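# A minimal construction sketch (an illustrative assumption, not taken from the
# repository's run scripts; the map name is a placeholder):
#
#   from pysc2.lib import features
#   import sc2_env_xun
#
#   aif = features.AgentInterfaceFormat(
#       feature_dimensions=features.Dimensions(screen=32, minimap=32),
#       use_feature_units=True)
#   env = sc2_env_xun.SC2Env_xun(
#       map_name="FindAndDefeatDrones",          # placeholder mini-game name
#       players=[sc2_env_xun.Agent(sc2_env_xun.Race.protoss, "learner"),
#                sc2_env_xun.Agent(sc2_env_xun.Race.zerg, "scripted")],
#       agent_interface_format=[aif, aif],
#       step_mul=8,
#       visualize=False)
#   timesteps = env.reset()
#   # ... step both agents each frame, then env.close()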
24 | """A Starcraft II environment.""" 25 | # pylint: disable=g-complex-comprehension 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | import collections 32 | from absl import logging 33 | import random 34 | import time 35 | 36 | import pdb 37 | 38 | import enum 39 | 40 | from pysc2 import maps 41 | from pysc2 import run_configs 42 | from pysc2.env import environment 43 | from pysc2.lib import actions as actions_lib 44 | from pysc2.lib import features 45 | from pysc2.lib import metrics 46 | from pysc2.lib import portspicker 47 | from pysc2.lib import renderer_human 48 | from pysc2.lib import run_parallel 49 | from pysc2.lib import stopwatch 50 | 51 | from s2clientprotocol import common_pb2 as sc_common 52 | from s2clientprotocol import sc2api_pb2 as sc_pb 53 | 54 | import pdb 55 | 56 | sw = stopwatch.sw 57 | 58 | possible_results = { 59 | sc_pb.Victory: 1, 60 | sc_pb.Defeat: -1, 61 | sc_pb.Tie: 0, 62 | sc_pb.Undecided: 0, 63 | } 64 | 65 | 66 | class Race(enum.IntEnum): 67 | random = sc_common.Random 68 | protoss = sc_common.Protoss 69 | terran = sc_common.Terran 70 | zerg = sc_common.Zerg 71 | 72 | 73 | class Difficulty(enum.IntEnum): 74 | """Bot difficulties.""" 75 | very_easy = sc_pb.VeryEasy 76 | easy = sc_pb.Easy 77 | medium = sc_pb.Medium 78 | medium_hard = sc_pb.MediumHard 79 | hard = sc_pb.Hard 80 | harder = sc_pb.Harder 81 | very_hard = sc_pb.VeryHard 82 | cheat_vision = sc_pb.CheatVision 83 | cheat_money = sc_pb.CheatMoney 84 | cheat_insane = sc_pb.CheatInsane 85 | 86 | 87 | class BotBuild(enum.IntEnum): 88 | """Bot build strategies.""" 89 | random = sc_pb.RandomBuild 90 | rush = sc_pb.Rush 91 | timing = sc_pb.Timing 92 | power = sc_pb.Power 93 | macro = sc_pb.Macro 94 | air = sc_pb.Air 95 | 96 | 97 | # Re-export these names to make it easy to construct the environment. 98 | ActionSpace = actions_lib.ActionSpace # pylint: disable=invalid-name 99 | Dimensions = features.Dimensions # pylint: disable=invalid-name 100 | AgentInterfaceFormat = features.AgentInterfaceFormat # pylint: disable=invalid-name 101 | parse_agent_interface_format = features.parse_agent_interface_format 102 | 103 | 104 | def to_list(arg): 105 | return arg if isinstance(arg, list) else [arg] 106 | 107 | 108 | def get_default(a, b): 109 | return b if a is None else a 110 | 111 | 112 | class Agent(collections.namedtuple("Agent", ["race", "name"])): 113 | """Define an Agent. It can have a single race or a list of races.""" 114 | 115 | def __new__(cls, race, name=None): 116 | return super(Agent, cls).__new__(cls, to_list(race), name or "") 117 | 118 | 119 | class Bot(collections.namedtuple("Bot", ["race", "difficulty", "build"])): 120 | """Define a Bot. It can have a single or list of races or builds.""" 121 | 122 | def __new__(cls, race, difficulty, build=None): 123 | return super(Bot, cls).__new__( 124 | cls, to_list(race), difficulty, to_list(build or BotBuild.random)) 125 | 126 | 127 | _DelayedAction = collections.namedtuple( 128 | "DelayedAction", ["game_loop", "action"]) 129 | 130 | REALTIME_GAME_LOOP_SECONDS = 1 / 22.4 131 | MAX_STEP_COUNT = 524000 # The game fails above 2^19=524288 steps. 132 | NUM_ACTION_DELAY_BUCKETS = 10 133 | 134 | 135 | class SC2Env_xun(environment.Base): 136 | """A Starcraft II environment. 
137 | 138 | The implementation details of the action and observation specs are in 139 | lib/features.py 140 | """ 141 | 142 | def __init__(self, # pylint: disable=invalid-name 143 | _only_use_kwargs=None, 144 | map_name=None, 145 | battle_net_map=False, 146 | players=None, 147 | agent_interface_format=None, 148 | discount=1., 149 | discount_zero_after_timeout=False, 150 | visualize=False, 151 | step_mul=None, 152 | realtime=False, 153 | save_replay_episodes=0, 154 | replay_dir='replay', #None, 155 | replay_prefix=None, 156 | game_steps_per_episode=None, 157 | score_index=None, 158 | score_multiplier=None, 159 | random_seed=None, 160 | disable_fog=False, 161 | ensure_available_actions=True, 162 | version=None): 163 | """Create a SC2 Env. 164 | 165 | You must pass a resolution that you want to play at. You can send either 166 | feature layer resolution or rgb resolution or both. If you send both you 167 | must also choose which to use as your action space. Regardless of which you 168 | choose you must send both the screen and minimap resolutions. 169 | 170 | For each of the 4 resolutions, either specify size or both width and 171 | height. If you specify size then both width and height will take that value. 172 | 173 | Args: 174 | _only_use_kwargs: Don't pass args, only kwargs. 175 | map_name: Name of a SC2 map. Run bin/map_list to get the full list of 176 | known maps. Alternatively, pass a Map instance. Take a look at the 177 | docs in maps/README.md for more information on available maps. Can 178 | also be a list of map names or instances, in which case one will be 179 | chosen at random per episode. 180 | battle_net_map: Whether to use the battle.net versions of the map(s). 181 | players: A list of Agent and Bot instances that specify who will play. 182 | agent_interface_format: A sequence containing one AgentInterfaceFormat 183 | per agent, matching the order of agents specified in the players list. 184 | Or a single AgentInterfaceFormat to be used for all agents. 185 | discount: Returned as part of the observation. 186 | discount_zero_after_timeout: If True, the discount will be zero 187 | after the `game_steps_per_episode` timeout. 188 | visualize: Whether to pop up a window showing the camera and feature 189 | layers. This won't work without access to a window manager. 190 | step_mul: How many game steps per agent step (action/observation). None 191 | means use the map default. 192 | realtime: Whether to use realtime mode. In this mode the game simulation 193 | automatically advances (at 22.4 gameloops per second) rather than 194 | being stepped manually. The number of game loops advanced with each 195 | call to step() won't necessarily match the step_mul specified. The 196 | environment will attempt to honour step_mul, returning observations 197 | with that spacing as closely as possible. Game loops will be skipped 198 | if they cannot be retrieved and processed quickly enough. 199 | save_replay_episodes: Save a replay after this many episodes. Default of 0 200 | means don't save replays. 201 | replay_dir: Directory to save replays. Required with save_replay_episodes. 202 | replay_prefix: An optional prefix to use when saving replays. 203 | game_steps_per_episode: Game steps per episode, independent of the 204 | step_mul. 0 means no limit. None means use the map default. 205 | score_index: -1 means use the win/loss reward, >=0 is the index into the 206 | score_cumulative with 0 being the curriculum score. None means use 207 | the map default. 
208 | score_multiplier: How much to multiply the score by. Useful for negating. 209 | random_seed: Random number seed to use when initializing the game. This 210 | lets you run repeatable games/tests. 211 | disable_fog: Whether to disable fog of war. 212 | ensure_available_actions: Whether to throw an exception when an 213 | unavailable action is passed to step(). 214 | version: The version of SC2 to use, defaults to the latest. 215 | 216 | Raises: 217 | ValueError: if no map is specified. 218 | ValueError: if wrong number of players are requested for a map. 219 | ValueError: if the resolutions aren't specified correctly. 220 | """ 221 | if _only_use_kwargs: 222 | raise ValueError("All arguments must be passed as keyword arguments.") 223 | 224 | if not players: 225 | raise ValueError("You must specify the list of players.") 226 | 227 | 228 | for p in players: 229 | if not isinstance(p, (Agent, Bot)): 230 | raise ValueError( 231 | "Expected players to be of type Agent or Bot. Got: %s." % p) 232 | 233 | 234 | num_players = len(players) 235 | self._num_agents = sum(1 for p in players if isinstance(p, Agent)) 236 | self._players = players 237 | 238 | if not 1 <= num_players <= 2 or not self._num_agents: 239 | raise ValueError("Only 1 or 2 players with at least one agent is " 240 | "supported at the moment.") 241 | 242 | if not map_name: 243 | raise ValueError("Missing a map name.") 244 | 245 | self._battle_net_map = battle_net_map 246 | self._maps = [maps.get(name) for name in to_list(map_name)] 247 | min_players = min(m.players for m in self._maps) 248 | max_players = max(m.players for m in self._maps) 249 | if self._battle_net_map: 250 | for m in self._maps: 251 | if not m.battle_net: 252 | raise ValueError("%s isn't known on Battle.net" % m.name) 253 | 254 | #pdb.set_trace() 255 | if max_players == 1: 256 | if self._num_agents != 1: 257 | raise ValueError("Single player maps require exactly one Agent.") 258 | elif not 2 <= num_players <= min_players: 259 | raise ValueError( 260 | "Maps support 2 - %s players, but trying to join with %s" % ( 261 | min_players, num_players)) 262 | 263 | if save_replay_episodes and not replay_dir: 264 | raise ValueError("Missing replay_dir") 265 | 266 | self._realtime = realtime 267 | self._last_step_time = None 268 | self._save_replay_episodes = save_replay_episodes 269 | self._replay_dir = replay_dir 270 | self._replay_prefix = replay_prefix 271 | self._random_seed = random_seed 272 | self._disable_fog = disable_fog 273 | self._ensure_available_actions = ensure_available_actions 274 | self._discount = discount 275 | self._discount_zero_after_timeout = discount_zero_after_timeout 276 | self._default_step_mul = step_mul 277 | self._default_score_index = score_index 278 | self._default_score_multiplier = score_multiplier 279 | self._default_episode_length = game_steps_per_episode 280 | self._run_config = run_configs.get(version=version) 281 | self._parallel = run_parallel.RunParallel() # Needed for multiplayer. 
282 | self._game_info = None 283 | 284 | if agent_interface_format is None: 285 | raise ValueError("Please specify agent_interface_format.") 286 | 287 | if isinstance(agent_interface_format, AgentInterfaceFormat): 288 | agent_interface_format = [agent_interface_format] * self._num_agents 289 | 290 | if len(agent_interface_format) != self._num_agents: 291 | raise ValueError( 292 | "The number of entries in agent_interface_format should " 293 | "correspond 1-1 with the number of agents.") 294 | 295 | self._action_delay_fns = [aif.action_delay_fn 296 | for aif in agent_interface_format] 297 | 298 | self._interface_formats = agent_interface_format 299 | self._interface_options = [ 300 | self._get_interface(interface_format, require_raw=visualize and i == 0) 301 | for i, interface_format in enumerate(agent_interface_format)] 302 | 303 | self._launch_game() 304 | self._create_join() 305 | 306 | self._finalize(visualize) 307 | 308 | def _finalize(self, visualize): 309 | self._delayed_actions = [collections.deque() 310 | for _ in self._action_delay_fns] 311 | 312 | if visualize: 313 | self._renderer_human = renderer_human.RendererHuman() 314 | self._renderer_human.init( 315 | self._controllers[0].game_info(), 316 | self._controllers[0].data()) 317 | else: 318 | self._renderer_human = None 319 | 320 | self._metrics = metrics.Metrics(self._map_name) 321 | self._metrics.increment_instance() 322 | 323 | self._last_score = None 324 | self._total_steps = 0 325 | self._episode_steps = 0 326 | self._episode_count = 0 327 | self._obs = [None] * self._num_agents 328 | self._agent_obs = [None] * self._num_agents 329 | self._state = environment.StepType.LAST # Want to jump to `reset`. 330 | logging.info("Environment is ready") 331 | 332 | @staticmethod 333 | def _get_interface(agent_interface_format, require_raw): 334 | aif = agent_interface_format 335 | interface = sc_pb.InterfaceOptions( 336 | raw=(aif.use_feature_units or 337 | aif.use_unit_counts or 338 | aif.use_raw_units or 339 | require_raw), 340 | show_cloaked=aif.show_cloaked, 341 | show_burrowed_shadows=aif.show_burrowed_shadows, 342 | show_placeholders=aif.show_placeholders, 343 | raw_affects_selection=True, 344 | raw_crop_to_playable_area=aif.raw_crop_to_playable_area, 345 | score=True) 346 | 347 | if aif.feature_dimensions: 348 | interface.feature_layer.width = aif.camera_width_world_units 349 | aif.feature_dimensions.screen.assign_to( 350 | interface.feature_layer.resolution) 351 | aif.feature_dimensions.minimap.assign_to( 352 | interface.feature_layer.minimap_resolution) 353 | interface.feature_layer.crop_to_playable_area = aif.crop_to_playable_area 354 | interface.feature_layer.allow_cheating_layers = aif.allow_cheating_layers 355 | 356 | if aif.rgb_dimensions: 357 | aif.rgb_dimensions.screen.assign_to(interface.render.resolution) 358 | aif.rgb_dimensions.minimap.assign_to(interface.render.minimap_resolution) 359 | 360 | return interface 361 | 362 | def _launch_game(self): 363 | # Reserve a whole bunch of ports for the weird multiplayer implementation. 364 | if self._num_agents > 1: 365 | self._ports = portspicker.pick_unused_ports(self._num_agents * 2) 366 | logging.info("Ports used for multiplayer: %s", self._ports) 367 | else: 368 | self._ports = [] 369 | 370 | # Actually launch the game processes. 
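    # One SC2 process (and its controller) is started per agent interface;
    # RGB rendering is only requested for interfaces that asked for it.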
371 | self._sc2_procs = [ 372 | self._run_config.start(extra_ports=self._ports, 373 | want_rgb=interface.HasField("render")) 374 | for interface in self._interface_options] 375 | self._controllers = [p.controller for p in self._sc2_procs] 376 | 377 | if self._battle_net_map: 378 | available_maps = self._controllers[0].available_maps() 379 | available_maps = set(available_maps.battlenet_map_names) 380 | unavailable = [m.name for m in self._maps 381 | if m.battle_net not in available_maps] 382 | if unavailable: 383 | raise ValueError("Requested map(s) not in the battle.net cache: %s" 384 | % ",".join(unavailable)) 385 | 386 | def _create_join(self): 387 | """Create the game, and join it.""" 388 | map_inst = random.choice(self._maps) 389 | self._map_name = map_inst.name 390 | 391 | self._step_mul = max(1, self._default_step_mul or map_inst.step_mul) 392 | self._score_index = get_default(self._default_score_index, 393 | map_inst.score_index) 394 | self._score_multiplier = get_default(self._default_score_multiplier, 395 | map_inst.score_multiplier) 396 | self._episode_length = get_default(self._default_episode_length, 397 | map_inst.game_steps_per_episode) 398 | 399 | if self._episode_length <= 0 or self._episode_length > MAX_STEP_COUNT: 400 | self._episode_length = MAX_STEP_COUNT 401 | 402 | # pdb.set_trace() 403 | # Create the game. Set the first instance as the host. 404 | create = sc_pb.RequestCreateGame( 405 | disable_fog=self._disable_fog, 406 | realtime=self._realtime) 407 | 408 | if self._battle_net_map: 409 | create.battlenet_map_name = map_inst.battle_net 410 | else: 411 | create.local_map.map_path = map_inst.path 412 | map_data = map_inst.data(self._run_config) 413 | if self._num_agents == 1: 414 | create.local_map.map_data = map_data 415 | else: 416 | # Save the maps so they can access it. Don't do it in parallel since SC2 417 | # doesn't respect tmpdir on windows, which leads to a race condition: 418 | # https://github.com/Blizzard/s2client-proto/issues/102 419 | for c in self._controllers: 420 | c.save_map(map_inst.path, map_data) 421 | if self._random_seed is not None: 422 | create.random_seed = self._random_seed 423 | for p in self._players: 424 | if isinstance(p, Agent): 425 | create.player_setup.add(type=sc_pb.Participant) 426 | else: 427 | create.player_setup.add( 428 | type=sc_pb.Computer, race=random.choice(p.race), 429 | difficulty=p.difficulty, ai_build=random.choice(p.build)) 430 | self._controllers[0].create_game(create) 431 | 432 | # Create the join requests. 433 | agent_players = [p for p in self._players if isinstance(p, Agent)] 434 | sanitized_names = crop_and_deduplicate_names(p.name for p in agent_players) 435 | join_reqs = [] 436 | for p, name, interface in zip(agent_players, sanitized_names, 437 | self._interface_options): 438 | join = sc_pb.RequestJoinGame(options=interface) 439 | join.race = random.choice(p.race) 440 | join.player_name = name 441 | if self._ports: 442 | join.shared_port = 0 # unused 443 | join.server_ports.game_port = self._ports[0] 444 | join.server_ports.base_port = self._ports[1] 445 | for i in range(self._num_agents - 1): 446 | join.client_ports.add(game_port=self._ports[i * 2 + 2], 447 | base_port=self._ports[i * 2 + 3]) 448 | join_reqs.append(join) 449 | 450 | # Join the game. This must be run in parallel because Join is a blocking 451 | # call to the game that waits until all clients have joined. 
452 | #pdb.set_trace() 453 | self._parallel.run((c.join_game, join) 454 | for c, join in zip(self._controllers, join_reqs)) 455 | 456 | self._game_info = self._parallel.run(c.game_info for c in self._controllers) 457 | for g, interface in zip(self._game_info, self._interface_options): 458 | if g.options.render != interface.render: 459 | logging.warning( 460 | "Actual interface options don't match requested options:\n" 461 | "Requested:\n%s\n\nActual:\n%s", interface, g.options) 462 | 463 | self._features = [ 464 | features.features_from_game_info( 465 | game_info=g, agent_interface_format=aif, map_name=self._map_name) 466 | for g, aif in zip(self._game_info, self._interface_formats)] 467 | 468 | @property 469 | def map_name(self): 470 | return self._map_name 471 | 472 | @property 473 | def game_info(self): 474 | """A list of ResponseGameInfo, one per agent.""" 475 | return self._game_info 476 | 477 | def static_data(self): 478 | return self._controllers[0].data() 479 | 480 | def observation_spec(self): 481 | """Look at Features for full specs.""" 482 | return tuple(f.observation_spec() for f in self._features) 483 | 484 | def action_spec(self): 485 | """Look at Features for full specs.""" 486 | return tuple(f.action_spec() for f in self._features) 487 | 488 | def action_delays(self): 489 | """In realtime we track the delay observation -> action executed. 490 | 491 | Returns: 492 | A list per agent of action delays, where action delays are a list where 493 | the index in the list corresponds to the delay in game loops, the value 494 | at that index the count over the course of an episode. 495 | 496 | Raises: 497 | ValueError: If called when not in realtime mode. 498 | """ 499 | if not self._realtime: 500 | raise ValueError("This method is only supported in realtime mode") 501 | 502 | return self._action_delays 503 | 504 | def _restart(self): 505 | if (len(self._players) == 1 and len(self._players[0].race) == 1 and 506 | len(self._maps) == 1): 507 | # Need to support restart for fast-restart of mini-games. 508 | self._controllers[0].restart() 509 | else: 510 | if len(self._controllers) > 1: 511 | self._parallel.run(c.leave for c in self._controllers) 512 | self._create_join() 513 | 514 | @sw.decorate 515 | def reset(self): 516 | """Start a new episode.""" 517 | self._episode_steps = 0 518 | if self._episode_count: 519 | # No need to restart for the first episode. 520 | self._restart() 521 | 522 | self._episode_count += 1 523 | races = [Race(r).name 524 | for _, r in sorted(self._features[0].requested_races.items())] 525 | logging.info("Starting episode %s: [%s] on %s", 526 | self._episode_count, ", ".join(races), self._map_name) 527 | self._metrics.increment_episode() 528 | 529 | self._last_score = [0] * self._num_agents 530 | self._state = environment.StepType.FIRST 531 | if self._realtime: 532 | self._last_step_time = time.time() 533 | self._last_obs_game_loop = None 534 | self._action_delays = [[0] * NUM_ACTION_DELAY_BUCKETS] * self._num_agents 535 | 536 | return self._observe(target_game_loop=0) 537 | 538 | @sw.decorate("step_env") 539 | def step(self, actions, step_mul=None): 540 | """Apply actions, step the world forward, and return observations. 541 | 542 | Args: 543 | actions: A list of actions meeting the action spec, one per agent, or a 544 | list per agent. Using a list allows multiple actions per frame, but 545 | will still check that they're valid, so disabling 546 | ensure_available_actions is encouraged. 
547 | step_mul: If specified, use this rather than the environment's default. 548 | 549 | Returns: 550 | A tuple of TimeStep namedtuples, one per agent. 551 | """ 552 | if self._state == environment.StepType.LAST: 553 | return self.reset() 554 | 555 | skip = not self._ensure_available_actions 556 | actions = [[f.transform_action(o.observation, a, skip_available=skip) 557 | for a in to_list(acts)] 558 | for f, o, acts in zip(self._features, self._obs, actions)] 559 | 560 | if not self._realtime: 561 | actions = self._apply_action_delays(actions) 562 | 563 | self._parallel.run((c.actions, sc_pb.RequestAction(actions=a)) 564 | for c, a in zip(self._controllers, actions)) 565 | 566 | 567 | # if self._episode_count >11: 568 | # pdb.set_trace() 569 | 570 | self._state = environment.StepType.MID 571 | return self._step(step_mul) 572 | 573 | def _step(self, step_mul=None): 574 | step_mul = step_mul or self._step_mul 575 | if step_mul <= 0: 576 | raise ValueError("step_mul should be positive, got {}".format(step_mul)) 577 | 578 | target_game_loop = self._episode_steps + step_mul 579 | if not self._realtime: 580 | # Send any delayed actions that were scheduled up to the target game loop. 581 | current_game_loop = self._send_delayed_actions( 582 | up_to_game_loop=target_game_loop, 583 | current_game_loop=self._episode_steps) 584 | 585 | self._step_to(game_loop=target_game_loop, 586 | current_game_loop=current_game_loop) 587 | 588 | return self._observe(target_game_loop=target_game_loop) 589 | 590 | def _apply_action_delays(self, actions): 591 | """Apply action delays to the requested actions, if configured to.""" 592 | assert not self._realtime 593 | actions_now = [] 594 | for actions_for_player, delay_fn, delayed_actions in zip( 595 | actions, self._action_delay_fns, self._delayed_actions): 596 | actions_now_for_player = [] 597 | 598 | for action in actions_for_player: 599 | delay = delay_fn() if delay_fn else 1 600 | if delay > 1 and action.ListFields(): # Skip no-ops. 601 | game_loop = self._episode_steps + delay - 1 602 | 603 | # Randomized delays mean that 2 delay actions can be reversed. 604 | # Make sure that doesn't happen. 605 | if delayed_actions: 606 | game_loop = max(game_loop, delayed_actions[-1].game_loop) 607 | 608 | # Don't send an action this frame. 609 | delayed_actions.append(_DelayedAction(game_loop, action)) 610 | else: 611 | actions_now_for_player.append(action) 612 | actions_now.append(actions_now_for_player) 613 | 614 | return actions_now 615 | 616 | def _send_delayed_actions(self, up_to_game_loop, current_game_loop): 617 | """Send any delayed actions scheduled for up to the specified game loop.""" 618 | assert not self._realtime 619 | while True: 620 | if not any(self._delayed_actions): # No queued actions 621 | return current_game_loop 622 | 623 | act_game_loop = min(d[0].game_loop for d in self._delayed_actions if d) 624 | if act_game_loop > up_to_game_loop: 625 | return current_game_loop 626 | 627 | self._step_to(act_game_loop, current_game_loop) 628 | current_game_loop = act_game_loop 629 | if self._controllers[0].status_ended: 630 | # We haven't observed and may have hit game end. 
631 | return current_game_loop 632 | 633 | actions = [] 634 | for d in self._delayed_actions: 635 | if d and d[0].game_loop == current_game_loop: 636 | delayed_action = d.popleft() 637 | actions.append(delayed_action.action) 638 | else: 639 | actions.append(None) 640 | self._parallel.run((c.act, a) for c, a in zip(self._controllers, actions)) 641 | 642 | def _step_to(self, game_loop, current_game_loop): 643 | step_mul = game_loop - current_game_loop 644 | if step_mul < 0: 645 | raise ValueError("We should never need to step backwards") 646 | if step_mul > 0: 647 | with self._metrics.measure_step_time(step_mul): 648 | if not self._controllers[0].status_ended: # May already have ended. 649 | self._parallel.run((c.step, step_mul) for c in self._controllers) 650 | 651 | def _get_observations(self, target_game_loop): 652 | # Transform in the thread so it runs while waiting for other observations. 653 | def parallel_observe(c, f): 654 | obs = c.observe(target_game_loop=target_game_loop) 655 | agent_obs = f.transform_obs(obs) 656 | return obs, agent_obs 657 | 658 | with self._metrics.measure_observation_time(): 659 | self._obs, self._agent_obs = zip(*self._parallel.run( 660 | (parallel_observe, c, f) 661 | for c, f in zip(self._controllers, self._features))) 662 | 663 | game_loop = self._agent_obs[0].game_loop[0] 664 | if (game_loop < target_game_loop and 665 | not any(o.player_result for o in self._obs)): 666 | raise ValueError( 667 | ("The game didn't advance to the expected game loop. " 668 | "Expected: %s, got: %s") % (target_game_loop, game_loop)) 669 | elif game_loop > target_game_loop and target_game_loop > 0: 670 | logging.warn("Received observation %d step(s) late: %d rather than %d.", 671 | game_loop - target_game_loop, game_loop, target_game_loop) 672 | 673 | if self._realtime: 674 | # Track delays on executed actions. 675 | # Note that this will underestimate e.g. action sent, new observation 676 | # taken before action executes, action executes, observation taken 677 | # with action. This is difficult to avoid without changing the SC2 678 | # binary - e.g. send the observation game loop with each action, 679 | # return them in the observation action proto. 680 | if self._last_obs_game_loop is not None: 681 | for i, obs in enumerate(self._obs): 682 | for action in obs.actions: 683 | if action.HasField("game_loop"): 684 | delay = action.game_loop - self._last_obs_game_loop 685 | if delay > 0: 686 | num_slots = len(self._action_delays[i]) 687 | delay = min(delay, num_slots - 1) # Cap to num buckets. 688 | self._action_delays[i][delay] += 1 689 | break 690 | self._last_obs_game_loop = game_loop 691 | 692 | def _observe(self, target_game_loop): 693 | self._get_observations(target_game_loop) 694 | 695 | #pdb.set_trace() 696 | # TODO(tewalds): How should we handle more than 2 agents and the case where 697 | # the episode can end early for some agents? 698 | outcome = [0] * self._num_agents 699 | discount = self._discount 700 | episode_complete = any(o.player_result for o in self._obs) 701 | 702 | if episode_complete: 703 | self._state = environment.StepType.LAST 704 | discount = 0 705 | for i, o in enumerate(self._obs): 706 | player_id = o.observation.player_common.player_id 707 | for result in o.player_result: 708 | if result.player_id == player_id: 709 | outcome[i] = possible_results.get(result.result, 0) 710 | 711 | if self._score_index >= 0: # Game score, not win/loss reward. 
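      # The reward is the per-step increment of the selected cumulative score
      # entry; the first step of an episode always yields 0.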
712 | cur_score = [o["score_cumulative"][self._score_index] 713 | for o in self._agent_obs] 714 | if _episode_steps == 0: # First reward is always 0. 715 | reward = [0] * self._num_agents 716 | else: 717 | reward = [cur - last for cur, last in zip(cur_score, self._last_score)] 718 | self._last_score = cur_score 719 | else: 720 | reward = outcome 721 | 722 | if self._renderer_human: 723 | self._renderer_human.render(self._obs[0]) 724 | cmd = self._renderer_human.get_actions( 725 | self._run_config, self._controllers[0]) 726 | if cmd == renderer_human.ActionCmd.STEP: 727 | pass 728 | elif cmd == renderer_human.ActionCmd.RESTART: 729 | self._state = environment.StepType.LAST 730 | elif cmd == renderer_human.ActionCmd.QUIT: 731 | raise KeyboardInterrupt("Quit?") 732 | 733 | self._total_steps += self._agent_obs[0].game_loop[0] - self._episode_steps 734 | self._episode_steps = self._agent_obs[0].game_loop[0] 735 | 736 | if self._episode_steps >= self._episode_length-8: #modified by xun 737 | 738 | self._state = environment.StepType.LAST 739 | if self._discount_zero_after_timeout: 740 | discount = 0.0 741 | if self._episode_steps >= MAX_STEP_COUNT: 742 | logging.info("Cut short to avoid SC2's max step count of 2^19=524288.") 743 | 744 | 745 | if self._state == environment.StepType.LAST: 746 | 747 | if (self._save_replay_episodes > 0 and 748 | self._episode_count % self._save_replay_episodes == 0): 749 | self.save_replay(self._replay_dir, self._replay_prefix) 750 | 751 | # logging.info(("Episode %s finished after %s game steps. " 752 | # "Outcome: %s, reward: %s, score: %s"), 753 | # self._episode_count, self._episode_steps, outcome, reward, 754 | # [o["score_cumulative"][0] for o in self._agent_obs]) 755 | #pdb.set_trace() 756 | # Modified by xun, only works for the drone. 2021 757 | logging.info(("Episode %s finished after %s game steps. " 758 | "Outcome: %s, reward: %s, kill_number_units: %s"), 759 | self._episode_count, self._episode_steps, outcome, reward, 760 | [o["score_cumulative"][5]/50 for o in self._agent_obs]) 761 | 762 | 763 | def zero_on_first_step(value): 764 | return 0.0 if self._state == environment.StepType.FIRST else value 765 | return tuple(environment.TimeStep( 766 | step_type=self._state, 767 | reward=zero_on_first_step(r * self._score_multiplier), 768 | discount=zero_on_first_step(discount), 769 | observation=o) for r, o in zip(reward, self._agent_obs)) 770 | 771 | def send_chat_messages(self, messages, broadcast=True): 772 | """Useful for logging messages into the replay.""" 773 | self._parallel.run( 774 | (c.chat, 775 | message, 776 | sc_pb.ActionChat.Broadcast if broadcast else sc_pb.ActionChat.Team) 777 | for c, message in zip(self._controllers, messages)) 778 | 779 | def save_replay(self, replay_dir, prefix=None): 780 | pdb.set_trace() 781 | if prefix is None: 782 | prefix = self._map_name 783 | replay_path = self._run_config.save_replay( 784 | self._controllers[0].save_replay(), replay_dir, prefix) 785 | logging.info("Wrote replay to: %s", replay_path) 786 | return replay_path 787 | 788 | def close(self): 789 | logging.info("Environment Close") 790 | if hasattr(self, "_metrics") and self._metrics: 791 | self._metrics.close() 792 | self._metrics = None 793 | if hasattr(self, "_renderer_human") and self._renderer_human: 794 | self._renderer_human.close() 795 | self._renderer_human = None 796 | 797 | # Don't use parallel since it might be broken by an exception. 
798 | if hasattr(self, "_controllers") and self._controllers: 799 | for c in self._controllers: 800 | c.quit() 801 | self._controllers = None 802 | if hasattr(self, "_sc2_procs") and self._sc2_procs: 803 | for p in self._sc2_procs: 804 | p.close() 805 | self._sc2_procs = None 806 | 807 | if hasattr(self, "_ports") and self._ports: 808 | portspicker.return_ports(self._ports) 809 | self._ports = None 810 | 811 | self._game_info = None 812 | 813 | 814 | def crop_and_deduplicate_names(names): 815 | """Crops and de-duplicates the passed names. 816 | 817 | SC2 gets confused in a multi-agent game when agents have the same 818 | name. We check for name duplication to avoid this, but - SC2 also 819 | crops player names to a hard character limit, which can again lead 820 | to duplicate names. To avoid this we unique-ify names if they are 821 | equivalent after cropping. Ideally SC2 would handle duplicate names, 822 | making this unnecessary. 823 | 824 | TODO(b/121092563): Fix this in the SC2 binary. 825 | 826 | Args: 827 | names: List of names. 828 | 829 | Returns: 830 | De-duplicated names cropped to 32 characters. 831 | """ 832 | max_name_length = 32 833 | 834 | # Crop. 835 | cropped = [n[:max_name_length] for n in names] 836 | 837 | # De-duplicate. 838 | deduplicated = [] 839 | name_counts = collections.Counter(n for n in cropped) 840 | name_index = collections.defaultdict(lambda: 1) 841 | for n in cropped: 842 | if name_counts[n] == 1: 843 | deduplicated.append(n) 844 | else: 845 | deduplicated.append("({}) {}".format(name_index[n], n)) 846 | name_index[n] += 1 847 | 848 | # Crop again. 849 | recropped = [n[:max_name_length] for n in deduplicated] 850 | if len(set(recropped)) != len(recropped): 851 | raise ValueError("Failed to de-duplicate names") 852 | 853 | return recropped 854 | --------------------------------------------------------------------------------