├── tstarbot ├── __init__.py ├── bin │ ├── .gitignore │ ├── __init__.py │ └── eval_agent.py ├── act │ ├── __init__.py │ └── act_mgr.py ├── agents │ ├── __init__.py │ ├── vs_builtin_ai_config.py │ ├── dft_config.py │ ├── dancing_drones_agent.py │ ├── micro_defeat_roaches_agent.py │ └── zerg_agent.py ├── data │ ├── __init__.py │ ├── queue │ │ ├── __init__.py │ │ ├── scout_command_queue.py │ │ ├── command_queue_base.py │ │ ├── combat_command_queue.py │ │ └── build_command_queue.py │ ├── pool │ │ ├── pool_base.py │ │ ├── __init__.py │ │ ├── opponent_pool.py │ │ ├── building_pool.py │ │ ├── combat_pool.py │ │ ├── map_tool.py │ │ ├── enemy_pool.py │ │ └── scout_pool.py │ ├── data_context.py │ └── demo_dc.py ├── sandbox │ ├── __init__.py │ ├── bin │ │ ├── __init__.py │ │ ├── eval_micro.py │ │ └── demo.py │ ├── agents │ │ ├── __init__.py │ │ ├── demo_bot.py │ │ └── rule_micro_agent.py │ ├── building-adhoc-plan.docx │ ├── .gitignore │ ├── bot_base.py │ ├── act_executor.py │ ├── building_mgr.py │ ├── py_multiplayer.py │ └── resource_mgr.py ├── scout │ ├── __init__.py │ ├── tasks │ │ ├── __init__.py │ │ ├── cruise_task.py │ │ ├── scout_task.py │ │ ├── force_scout.py │ │ └── explor_task_rl.py │ ├── oppo_monitor.py │ └── scout_mgr.py ├── combat │ ├── micro │ │ ├── __init__.py │ │ ├── ravager_micro.py │ │ ├── corruptor_micro.py │ │ ├── lurker_micro.py │ │ ├── mutalisk_micro.py │ │ ├── queen_micro.py │ │ ├── infestor_micro.py │ │ ├── roach_micro.py │ │ ├── micro_mgr.py │ │ └── viper_micro.py │ ├── __init__.py │ ├── defeat_roaches_mgr.py │ └── combat_mgr.py ├── resource │ ├── __init__.py │ └── dacing_drones_resource_mgr.py ├── combat_strategy │ ├── __init__.py │ ├── army.py │ ├── renderer.py │ └── squad.py ├── production_strategy │ ├── __init__.py │ ├── build_cmd.py │ ├── production_mgr.py │ ├── prod_rush.py │ ├── util.py │ ├── prod_advarms.py │ └── prod_defandadv.py ├── util │ ├── __init__.py │ ├── geom.py │ ├── unit.py │ └── act.py └── building │ ├── __init__.py │ └── 
dancing_drones_mgr.py ├── .gitignore ├── setup.py ├── docs ├── examples_howtorun.md └── examples_evaluate.md └── README.md /tstarbot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/bin/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/act/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/sandbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/scout/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/combat/micro/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/data/queue/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tstarbot/resource/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/sandbox/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/scout/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/combat_strategy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/sandbox/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/production_strategy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tstarbot/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .act import * 2 | from .unit import * 3 | from .geom import * 4 | -------------------------------------------------------------------------------- /tstarbot/data/pool/pool_base.py: -------------------------------------------------------------------------------- 1 | class PoolBase(object): 2 | def update(self, obs): 3 | raise NotImplementedError 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | TStarBot.egg-info/ 3 | 
tstarbot/__pycache__/ 4 | agents/__pycache__/ 5 | __pycache__ 6 | *.pyc 7 | .DS_Store -------------------------------------------------------------------------------- /tstarbot/sandbox/building-adhoc-plan.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/TStarBot2/HEAD/tstarbot/sandbox/building-adhoc-plan.docx -------------------------------------------------------------------------------- /tstarbot/combat/__init__.py: -------------------------------------------------------------------------------- 1 | from .combat_mgr import BaseCombatMgr, ZergCombatMgr 2 | from .defeat_roaches_mgr import DefeatRoachesCombatMgr 3 | -------------------------------------------------------------------------------- /tstarbot/building/__init__.py: -------------------------------------------------------------------------------- 1 | from .building_mgr import BaseBuildingMgr, ZergBuildingMgr 2 | from .dancing_drones_mgr import DancingDronesMgr 3 | -------------------------------------------------------------------------------- /tstarbot/sandbox/.gitignore: -------------------------------------------------------------------------------- 1 | cmd_tmpl 2 | my_config.py 3 | my_rl_config.py 4 | model-50000.data-00000-of-00001 5 | model-50000.index 6 | model-50000.meta 7 | -------------------------------------------------------------------------------- /tstarbot/data/pool/__init__.py: -------------------------------------------------------------------------------- 1 | from .macro_def import WorkerState 2 | from .macro_def import AllianceType 3 | from .worker_pool import Worker 4 | from .worker_pool import WorkerPool 5 | -------------------------------------------------------------------------------- /tstarbot/sandbox/bot_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import 
print_function 4 | 5 | class PoolBase(object): 6 | def update(self, obs): 7 | raise NotImplementedError 8 | 9 | class ManagerBase(object): 10 | def execute(self): 11 | raise NotImplementedError 12 | 13 | -------------------------------------------------------------------------------- /tstarbot/data/queue/scout_command_queue.py: -------------------------------------------------------------------------------- 1 | from tstarbot.data.queue.command_queue_base import CommandQueueBase 2 | from enum import Enum, unique 3 | 4 | 5 | @unique 6 | class ScoutCommandType(Enum): 7 | MOVE = 0 8 | 9 | 10 | class ScoutCommandQueue(CommandQueueBase): 11 | def __init__(self): 12 | SCOUT_CMD_ID_BASE = 300000000 13 | super(ScoutCommandQueue, self).__init__(SCOUT_CMD_ID_BASE) 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='TStarBot', 5 | version='0.1', 6 | description='TStartBot', 7 | keywords='TStartBot', 8 | packages=[ 9 | 'tstarbot', 10 | 'tstarbot.bin', 11 | 'tstarbot.data', 12 | 'tstarbot.data.pool', 13 | 'tstarbot.data.queue', 14 | 'tstarbot.act', 15 | ], 16 | 17 | install_requires=[ 18 | 'pillow' 19 | ], 20 | ) 21 | -------------------------------------------------------------------------------- /tstarbot/data/pool/opponent_pool.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from enum import unique 3 | from tstarbot.data.pool.pool_base import PoolBase 4 | 5 | 6 | @unique 7 | class OppoOpeningTactics(Enum): 8 | UNKNOWN = 0 9 | BANELING_RUSH = 1 10 | ROACH_RUSH = 2 11 | ZERGROACH_RUSH = 3 12 | ROACH_SUPPRESS = 4 13 | 14 | 15 | class OppoPool(PoolBase): 16 | def __init__(self): 17 | super(OppoPool, self).__init__() 18 | self.opening_tactics = None 19 | 20 | def reset(self): 21 | self.opening_tactics = None 22 | 
# NOTE: only the ActMgr class on this dump line is reformatted; every other
# piece of dump text that shares the line is kept verbatim in comment lines.
# -------------------------------------------------------------------------------- /tstarbot/act/act_mgr.py: --------------------------------------------------------------------------------
# 1 | """Scripted Zerg agent.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function
from copy import deepcopy


class ActMgr(object):
    """Buffer of pending actions, filled by managers and drained once per step."""

    def __init__(self):
        # Actions pushed since the last pop_actions() call.
        self.cur_actions = []

    def push_actions(self, actions):
        """Queue one action or a list of actions.

        Args:
          actions: a list of actions (each element is queued) or any other
            object, which is queued as a single action.
        """
        # The original test was `type(actions == list)`, i.e. the type of a
        # bool, which is always truthy; a single action was therefore `+=`-ed
        # onto the list instead of appended.
        if isinstance(actions, list):
            self.cur_actions += actions
        else:
            self.cur_actions.append(actions)

    def pop_actions(self):
        """Return a deep copy of the queued actions and empty the buffer."""
        popped = deepcopy(self.cur_actions)
        self.cur_actions = []
        return popped


# -------------------------------------------------------------------------------- /tstarbot/agents/vs_builtin_ai_config.py: --------------------------------------------------------------------------------
# 1 | """ config for playing against builtin AI 2 | """ 3 | sleep_per_step = 0.0 4 | building_verbose = 0 5 | building_placer = 'hybrid_v2' # 'naive_predef' | 'hybrid_v2' | 'hybrid_v3' 6 | building_placer_verbose = 0 7 | resource_verbose = 0 8 | production_verbose = 0 9 | combat_verbose = 0 10 | scout_explore_version = 2 11 | explore_rl_support = False 12 | max_forced_scout_count = 0 # num of drones used to scout 13 | combat_strategy = 'HARASS' # 'REFORM' | 'HARASS' 14 | production_strategy = 'DEF_AND_ADV' # 'RUSH' | 'ADV_ARMS' | 'DEF_AND_ADV' 15 | default_micro_version = 1 16 | game_version = '4.3' # '3.16.1' | '4.3' 17 |
# -------------------------------------------------------------------------------- /tstarbot/sandbox/act_executor.py: --------------------------------------------------------------------------------
# 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | class ActExecutor(object): 6 | def __init__(self, env): 7 | self._env = env 8 | 9 | def exec_raw(self,
pb_actions): 10 | return self.exec_inner(pb_actions, True) 11 | 12 | def exec_sc2(self, sc2_actions): 13 | return self.exec_inner(sc2_actions, False) 14 | 15 | def exec_inner(self, actions, raw_flag): 16 | if raw_flag: 17 | self._env.set_raw() 18 | timesteps = self._env.step(actions) 19 | self._env.reset_raw() 20 | else: 21 | timesteps = self._env.step(actions) 22 | 23 | is_end = False 24 | if timesteps[0].last(): 25 | is_end = True 26 | return (timesteps, is_end) 27 | 28 | -------------------------------------------------------------------------------- /tstarbot/sandbox/bin/eval_micro.py: -------------------------------------------------------------------------------- 1 | from pysc2.env import sc2_env 2 | from tstarbot.sandbox.agents.rule_micro_agent import MicroAgent 3 | from absl import app 4 | import os 5 | 6 | os.environ["SC2PATH"] = "/home/psun/StarCraftII" 7 | 8 | 9 | def demo(unused_argv): 10 | env = sc2_env.SC2Env( 11 | map_name='DefeatRoaches', 12 | screen_size_px=(64, 64), 13 | minimap_size_px=(64, 64), 14 | agent_race='T', 15 | bot_race='Z', 16 | difficulty=None, 17 | step_mul=1, 18 | game_steps_per_episode=0, 19 | visualize=True) # visualize must be True to return unit information 20 | 21 | my_agent = MicroAgent(env) 22 | my_agent.setup() 23 | my_agent.run(1000000) 24 | 25 | 26 | if __name__ == '__main__': 27 | app.run(demo) 28 | -------------------------------------------------------------------------------- /tstarbot/sandbox/bin/demo.py: -------------------------------------------------------------------------------- 1 | from pysc2.env import sc2_env 2 | from tstarbot.sandbox.demo_bot import DemoBot 3 | from absl import app 4 | from absl import flags 5 | import os 6 | 7 | os.environ["SC2PATH"] = "/home/psun/StarCraftII" 8 | 9 | def demo(unused_argv): 10 | env = sc2_env.SC2Env(map_name='Simple64', 11 | screen_size_px=(64, 64), 12 | minimap_size_px=(64, 64), 13 | agent_race='Z', 14 | bot_race='Z', 15 | difficulty=None, 16 | step_mul=8, 17 | 
game_steps_per_episode=0, 18 | visualize=True) 19 | 20 | bot = DemoBot(env) 21 | bot.setup() 22 | bot.run(100) 23 | 24 | 25 | if __name__ == '__main__': 26 | app.run(demo) 27 | -------------------------------------------------------------------------------- /tstarbot/production_strategy/build_cmd.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | BuildCmdBuilding = namedtuple('build_cmd_building', ['base_tag', 'unit_type']) 5 | BuildCmdUnit = namedtuple('build_cmd_unit', ['base_tag', 'unit_type']) 6 | BuildCmdUpgrade = namedtuple('build_cmd_upgrade', ['building_tag', 7 | 'ability_id']) 8 | BuildCmdMorph = namedtuple('build_cmd_morph', ['unit_tag', 9 | 'ability_id']) 10 | BuildCmdExpand = namedtuple('build_cmd_expand', ['base_tag', 'pos', 11 | 'builder_tag']) 12 | BuildCmdHarvest = namedtuple('build_cmd_harvest', ['gas_first']) 13 | BuildCmdSpawnLarva = namedtuple('build_cmd_spawn_larva', 14 | ['base_tag', 'queen_tag']) -------------------------------------------------------------------------------- /tstarbot/agents/dft_config.py: -------------------------------------------------------------------------------- 1 | """ A default template config file for an agent. 2 | 3 | Treat it as a plain py file. 4 | Define the required configurations in "flat structure", e.g., 5 | var1 = value1 6 | var2 = value2 7 | ... 8 | 9 | Do NOT ABUSE it, do NOT define nested or complex data structure. 
# NOTE: only tstarbot/util/geom.py on this dump line is reformatted; the other
# dump text that shares the line is kept verbatim in comment lines.
# 10 | """ 11 | sleep_per_step = 0.0 12 | building_verbose = 0 13 | building_placer = 'hybrid_v2' # 'naive_predef' | 'hybrid_v2' | 'hybrid_v3' 14 | building_placer_verbose = 0 15 | resource_verbose = 0 16 | production_verbose = 0 17 | combat_verbose = 0 18 | scout_explore_version = 2 19 | explore_rl_support = False 20 | max_forced_scout_count = 0 # num of drones used to scout 21 | combat_strategy = 'HARASS' # 'REFORM' | 'HARASS' 22 | production_strategy = 'ADV_ARMS' # 'RUSH' | 'ADV_ARMS' | 'DEF_AND_ADV' 23 | default_micro_version = 1 24 | game_version = '4.3' # '3.16.1' | '4.3' 25 |
# -------------------------------------------------------------------------------- /tstarbot/util/geom.py: --------------------------------------------------------------------------------
""" geometry utilities """
from math import cos
from math import sin
from math import atan2
from math import sqrt


def polar_to_cart(rho, theta):
    """ return the Cartesian (x, y) of the polar point (rho, theta) """
    return rho * cos(theta), rho * sin(theta)


def cart_to_polar(x, y):
    """ return the polar (rho, theta) of the Cartesian point (x, y)

    rho is the Euclidean norm of (x, y); theta is atan2(y, x) in radians.
    """
    # The original computed sqrt(x * x + y * x), which is wrong whenever
    # x != y and raises ValueError when the radicand turns negative.
    return sqrt(x * x + y * y), atan2(y, x)


def dist(unit1, unit2):
    """ return Euclidean distance ||unit1 - unit2||

    Both arguments must expose float_attr.pos_x and float_attr.pos_y.
    """
    dx = unit1.float_attr.pos_x - unit2.float_attr.pos_x
    dy = unit1.float_attr.pos_y - unit2.float_attr.pos_y
    return (dx ** 2 + dy ** 2) ** 0.5


def dist_to_pos(unit, pos):
    """ return Euclidean distance ||unit - [x,y]||

    pos is indexed as pos[0] (x) and pos[1] (y).
    """
    dx = unit.float_attr.pos_x - pos[0]
    dy = unit.float_attr.pos_y - pos[1]
    return (dx ** 2 + dy ** 2) ** 0.5


def list_mean(l):
    """ return the arithmetic mean of l as a float, or None when l is empty """
    if not l:
        return None
    return sum(l) / float(len(l))


def mean_pos(units):
    """ return the mean (x, y) position of units, or () when units is empty """
    if not units:
        return ()
    xx = [u.float_attr.pos_x for u in units]
    yy = [u.float_attr.pos_y for u in units]
    return list_mean(xx), list_mean(yy)


# -------------------------------------------------------------------------------- /tstarbot/data/queue/command_queue_base.py:
-------------------------------------------------------------------------------- 1 | class CommandBase(object): 2 | def __init__(self): 3 | self.cmd_id = 0 # command id 4 | self.cmd_type = 0 5 | self.idx = 0 6 | self.param = {} 7 | 8 | def __str__(self): 9 | return "id {}, type {}, idx {}, param {}".format( 10 | self.cmd_id, self.cmd_type, self.idx, self.param) 11 | 12 | 13 | class CommandQueueBase(object): 14 | def __init__(self, cmd_id_base): 15 | self._cmd_dict = {} 16 | self._cmd_id = cmd_id_base 17 | 18 | def put(self, idx, cmd_type, param): 19 | cmd = CommandBase() 20 | cmd.cmd_id = self._cmd_id 21 | cmd.idx = idx 22 | cmd.cmd_type = cmd_type 23 | cmd.param = param 24 | 25 | if idx in self._cmd_dict: 26 | self._cmd_dict[idx].append(cmd) 27 | else: 28 | self._cmd_dict[idx] = [cmd] 29 | 30 | self._cmd_id += 1 31 | 32 | def get(self, idx): 33 | cmds = [] 34 | 35 | if idx in self._cmd_dict: 36 | cmds = self._cmd_dict[idx] 37 | del self._cmd_dict[idx] 38 | 39 | return cmds 40 | 41 | def clear_all(self): 42 | self._cmd_dict.clear() 43 | -------------------------------------------------------------------------------- /tstarbot/agents/dancing_drones_agent.py: -------------------------------------------------------------------------------- 1 | """Adopted from 2 | 3 | Demonstrate the per-unit-control via raw interface of s2client-proto. 4 | Adopted from demo_bot.py written by zhengyang. 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | from pysc2.agents import base_agent 11 | 12 | from tstarbot.building.dancing_drones_mgr import DancingDronesMgr 13 | from tstarbot.data.demo_dc import DancingDrones 14 | from tstarbot.act.act_mgr import ActMgr 15 | 16 | 17 | class DancingDronesAgent(base_agent.BaseAgent): 18 | """An agent that makes drones dancing. 
19 | 20 | Show how to send per-unit-control actions.""" 21 | 22 | def __init__(self): 23 | super(DancingDronesAgent, self).__init__() 24 | # self.dc = DataContext() 25 | self.dc = DancingDrones() 26 | self.am = ActMgr() 27 | 28 | self._mgr = DancingDronesMgr(self.dc) 29 | 30 | def step(self, timestep): 31 | super(DancingDronesAgent, self).step(timestep) 32 | # print(timestep) 33 | self.dc.update(timestep) 34 | self._mgr.update(self.dc, self.am) 35 | return self.am.pop_actions() 36 | -------------------------------------------------------------------------------- /tstarbot/production_strategy/production_mgr.py: -------------------------------------------------------------------------------- 1 | """Production Manager""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from tstarbot.production_strategy.prod_advarms import ZergProdAdvArms 7 | from tstarbot.production_strategy.prod_rush import ZergProdRush 8 | from tstarbot.production_strategy.prod_defandadv import ZergProdDefAndAdv 9 | 10 | 11 | class ZergProductionMgr(object): 12 | def __init__(self, dc): 13 | self.strategy = 'DEF_AND_ADV' 14 | if hasattr(dc, 'config'): 15 | if hasattr(dc.config, 'production_strategy'): 16 | self.strategy = dc.config.production_strategy 17 | if self.strategy == 'RUSH': 18 | self.c = ZergProdRush(dc) 19 | elif self.strategy == 'ADV_ARMS': 20 | self.c = ZergProdAdvArms(dc) 21 | elif self.strategy == 'DEF_AND_ADV': 22 | self.c = ZergProdDefAndAdv(dc) 23 | else: 24 | raise Exception('Unknow production_strategy combat_strategy: "%s"' % str(self.strategy)) 25 | 26 | def reset(self): 27 | return self.c.reset() 28 | 29 | def update(self, dc, am): 30 | return self.c.update(dc, am) 31 | -------------------------------------------------------------------------------- /tstarbot/agents/micro_defeat_roaches_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | A rule 
based multi-agent micro-management bot for the mini-game DefeatRoaches. 3 | Adopted from the code originally writen by lxhan(slivermoda). 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | from pysc2.agents import base_agent 10 | from pysc2.lib import stopwatch 11 | 12 | from tstarbot.combat import DefeatRoachesCombatMgr 13 | from tstarbot.data.demo_dc import DefeatRoaches 14 | from tstarbot.act.act_mgr import ActMgr 15 | 16 | sw = stopwatch.sw 17 | 18 | 19 | class MicroDefeatRoachesAgent(base_agent.BaseAgent): 20 | """An agent for the DefeatRoaches map.""" 21 | 22 | def __init__(self): 23 | super(MicroDefeatRoachesAgent, self).__init__() 24 | self.dc = DefeatRoaches() 25 | self.am = ActMgr() 26 | self._mgr = DefeatRoachesCombatMgr() 27 | 28 | def step(self, timestep): 29 | super(MicroDefeatRoachesAgent, self).step(timestep) 30 | return self.mystep(timestep) 31 | 32 | @sw.decorate 33 | def mystep(self, timestep): 34 | self.dc.update(timestep) 35 | self._mgr.update(self.dc, self.am) 36 | 37 | actions = self.am.pop_actions() 38 | return actions 39 | -------------------------------------------------------------------------------- /tstarbot/combat/defeat_roaches_mgr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.agents import base_agent 6 | 7 | from tstarbot.combat.micro.micro_mgr import MicroBase 8 | 9 | 10 | ROACH_ATTACK_RANGE = 5.0 11 | 12 | 13 | class DefeatRoachesCombatMgr(base_agent.BaseAgent): 14 | """ Combat Manager for the DefeatRoaches minimap""" 15 | 16 | def __init__(self): 17 | super(DefeatRoachesCombatMgr, self).__init__() 18 | self.marines = None 19 | self.roaches = None 20 | self.micro_base = MicroBase() 21 | 22 | def update(self, dc, am): 23 | self.marines = dc.marines 24 | self.roaches = 
dc.roaches 25 | 26 | actions = list() 27 | for m in self.marines: 28 | closest_roach = self.micro_base.find_closest_enemy(m, self.roaches) 29 | closest_enemy_dist = self.micro_base.dist_between_units(m, closest_roach) 30 | if closest_enemy_dist < ROACH_ATTACK_RANGE and \ 31 | (m.float_attr.health / m.float_attr.health_max) < 0.3 and \ 32 | self.micro_base.find_strongest_unit_hp(self.marines) > 0.9: 33 | action = self.micro_base.run_away_from_closest_enemy(m, closest_roach, 34 | 0.4) 35 | else: 36 | action = self.micro_base.attack_weakest_enemy(m, self.roaches) 37 | actions.append(action) 38 | 39 | am.push_actions(actions) 40 | -------------------------------------------------------------------------------- /tstarbot/building/dancing_drones_mgr.py: -------------------------------------------------------------------------------- 1 | """Resource Manager""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import random 6 | 7 | from s2clientprotocol import sc2api_pb2 as sc_pb 8 | 9 | from tstarbot.building import BaseBuildingMgr 10 | 11 | 12 | class DancingDronesMgr(BaseBuildingMgr): 13 | def __init__(self, dc): 14 | super(DancingDronesMgr, self).__init__(dc) 15 | self._range_high = 5 16 | self._range_low = -5 17 | self._move_ability = 1 18 | 19 | def update(self, dc, am): 20 | super(DancingDronesMgr, self).update(dc, am) 21 | 22 | drone_ids = dc.get_drones() 23 | pos = dc.get_hatcherys() 24 | 25 | print('pos=', pos) 26 | actions = self.move_drone_random_round_hatchery(drone_ids, pos[0]) 27 | 28 | am.push_actions(actions) 29 | 30 | def move_drone_random_round_hatchery(self, drone_ids, pos): 31 | actions = [] 32 | for drone in drone_ids: 33 | action = sc_pb.Action() 34 | action.action_raw.unit_command.ability_id = self._move_ability 35 | x = pos[0] + random.randint(self._range_low, self._range_high) 36 | y = pos[1] + random.randint(self._range_low, self._range_high) 37 | 
action.action_raw.unit_command.target_world_space_pos.x = x 38 | action.action_raw.unit_command.target_world_space_pos.y = y 39 | action.action_raw.unit_command.unit_tags.append(drone) 40 | actions.append(action) 41 | return actions 42 | -------------------------------------------------------------------------------- /tstarbot/resource/dacing_drones_resource_mgr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import random 5 | from s2clientprotocol import sc2api_pb2 as sc_pb 6 | from tstarbot.resource.resource_mgr import BaseResourceMgr 7 | 8 | 9 | class DancingDronesResourceMgr(BaseResourceMgr): 10 | def __init__(self): 11 | super(DancingDronesResourceMgr, self).__init__() 12 | self._range_high = 5 13 | self._range_low = -5 14 | self._move_ability = 1 15 | 16 | def update(self, dc, am): 17 | super(DancingDronesResourceMgr, self).update(dc, am) 18 | 19 | drone_ids = dc.get_drones() 20 | pos = dc.get_hatcherys() 21 | 22 | # print('pos=', pos) 23 | actions = self.move_drone_random_round_hatchery(drone_ids, pos[0]) 24 | 25 | am.push_actions(actions) 26 | 27 | def move_drone_random_round_hatchery(self, drone_ids, pos): 28 | actions = [] 29 | for drone in drone_ids: 30 | action = sc_pb.Action() 31 | action.action_raw.unit_command.ability_id = self._move_ability 32 | x = pos[0] + random.randint(self._range_low, self._range_high) 33 | y = pos[1] + random.randint(self._range_low, self._range_high) 34 | action.action_raw.unit_command.target_world_space_pos.x = x 35 | action.action_raw.unit_command.target_world_space_pos.y = y 36 | action.action_raw.unit_command.unit_tags.append(drone) 37 | actions.append(action) 38 | return actions 39 | -------------------------------------------------------------------------------- /tstarbot/data/queue/combat_command_queue.py: 
-------------------------------------------------------------------------------- 1 | """Combat Command Queue.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import collections 6 | from enum import Enum 7 | 8 | from tstarbot.combat_strategy.squad import Squad 9 | 10 | 11 | CombatCmdType = Enum('CombatCmdType', 12 | ('MOVE', 'ATTACK', 'DEFEND', 'RALLY', 'ROCK')) 13 | 14 | 15 | class CombatCommand(object): 16 | 17 | def __init__(self, type, squad, position): 18 | assert isinstance(type, CombatCmdType) 19 | assert isinstance(squad, Squad) 20 | assert isinstance(position, dict) 21 | assert 'x' in position and 'y' in position 22 | 23 | self._type = type 24 | self._squad = squad 25 | self._position = position 26 | 27 | def __repr__(self): 28 | return ('CombatCmd(Type(%s), Squad(%s), Position(%s))' % 29 | (self._type, self._squad, self._position)) 30 | 31 | @property 32 | def type(self): 33 | return self._type 34 | 35 | @property 36 | def squad(self): 37 | return self._squad 38 | 39 | @property 40 | def position(self): 41 | return self._position 42 | 43 | 44 | # TODO: update CommandQueueBase if hope to inherit it. 
# NOTE: only the Army class on these dump lines is reformatted; every other
# piece of dump text that shares the lines is kept verbatim in comment lines.
# 45 | class CombatCommandQueue(object): 46 | 47 | def __init__(self): 48 | self._queue = collections.deque() 49 | 50 | def push(self, cmd): 51 | assert isinstance(cmd, CombatCommand) 52 | self._queue.append(cmd) 53 | 54 | def pull(self): 55 | return [] if len(self._queue) == 0 else self._queue.pop() 56 | 57 | def clear(self): 58 | self._queue.clear() 59 |
# -------------------------------------------------------------------------------- /tstarbot/combat_strategy/army.py: --------------------------------------------------------------------------------
# 1 | """Army Class.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from tstarbot.combat_strategy.squad import Squad 7 | 8 |
class Army(object):
    """Collection of squads plus the set of units not assigned to any squad."""

    def __init__(self):
        self._squads = list()
        self._unsquaded_units = set()

    def update(self, combat_pool):
        """Refresh every squad and recompute the unsquaded units.

        Args:
          combat_pool: object passed to each squad's update(); its `units`
            attribute is iterated to find units that belong to no squad.
        """
        for squad in self._squads:
            squad.update(combat_pool)
        squaded_units = set.union(
            set(), *[set(squad.units) for squad in self._squads])
        self._unsquaded_units = set(u for u in combat_pool.units
                                    if u not in squaded_units)

    def create_squad(self, units, uniform=None):
        """Build a Squad from `units`, register it and return it.

        Each unit is removed from the unsquaded set first (set.remove raises
        KeyError for a unit that is not currently unsquaded). When `uniform`
        is not None it is stored on the new squad's `_uniform` attribute.
        """
        for u in units:
            self._unsquaded_units.remove(u)
        squad = Squad(units)
        if uniform is not None:
            squad._uniform = uniform
        self._squads.append(squad)
        return squad

    def delete_squad(self, squad):
        """Unregister `squad` and hand its units back to the unsquaded set.

        Raises:
          ValueError: if `squad` is not registered (from list.remove).
        """
        self._squads.remove(squad)
        # The original called set.union() and dropped the result, so the
        # units never came back; update() mutates the set in place.
        self._unsquaded_units.update(squad.units)

    @property
    def squads(self):
        return self._squads

    @property
    def unsquaded_units(self):
        return self._unsquaded_units

    @property
    def num_units(self):
        return sum([squad.num_units for squad in self._squads])

    @property
    def num_hydralisk_units(self):
        return sum([squad.num_hydralisk_units for squad in self._squads])

    @property
    def num_roach_units(self):
        return sum([squad.num_roach_units for squad in self._squads])

    @property
    def num_zergling_units(self):
        return sum([squad.num_zergling_units for squad in self._squads])


# 58 |
# -------------------------------------------------------------------------------- /tstarbot/data/queue/build_command_queue.py: --------------------------------------------------------------------------------
# 1 | from collections import deque 2 | from tstarbot.data.queue.command_queue_base import CommandQueueBase 3 | from enum import Enum, unique 4 | 5 | 6 | @unique 7 | class BuildCommandType(Enum): 8 | BUILD = 0 9 | EXPAND = 1 10 | CANCEL = 2 11 | 12 | 13 | class BuildCommandQueue(CommandQueueBase): 14 | def __init__(self): 15 | BUILD_CMD_ID_BASE = 100000000 16 | super(BuildCommandQueue, self).__init__(BUILD_CMD_ID_BASE) 17 | 18 | 19 | class BuildCommandQueueV2(object): 20 | """ A left-in-right-out queue (first in first out) """ 21 | 22 | def __init__(self): 23 | self._q = deque() 24 | 25 | def put(self, item): 26 | """ push an item at left """ 27 | self._q.appendleft(item) 28 | 29 | def get(self): 30 | """ pop an item from right """ 31 | return self._q.pop() 32 | 33 | def empty(self): 34 | return False if self._q else True 35 | 36 | def size(self): 37 | return len(self._q) 38 | 39 | 40 | if __name__ == "__main__": 41 | # usage of BuildCommandQueue 42 | q = BuildCommandQueue() 43 | q.put(0, BuildCommandType.BUILD, {'unit_id': 100, 'count': 1}) 44 | q.put(1, BuildCommandType.BUILD, {'unit_id': 200, 'count': 2}) 45 | # q.clear_all() 46 | 47 | cmds = q.get(1) 48 | for cmd in cmds: 49 | print(cmd) 50 | 51 | # usage of BuildCommandQueueV2 52 | from collections import namedtuple 53 | 54 | BuildCmdBuilding = namedtuple('build_cmd_building', ['base_tag']) 55 | BuildCmdUnit = namedtuple('build_cmd_unit', ['base_tag']) 56 | 57 | qq = BuildCommandQueueV2() 58 | qq.put(BuildCmdBuilding(base_tag=34567)) 59 | qq.put(BuildCmdUnit(base_tag=8879)) 60 | c1 = qq.get() 61 | print(c1) 62 |
qq.put(BuildCmdUnit(base_tag=445891)) 63 | c2 = qq.get() 64 | print(c2) 65 | c3 = qq.get() 66 | print(c3) 67 | -------------------------------------------------------------------------------- /tstarbot/combat/micro/ravager_micro.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pysc2.lib.typeenums import UNIT_TYPEID, ABILITY_ID, UPGRADE_ID 3 | from s2clientprotocol import sc2api_pb2 as sc_pb 4 | 5 | from tstarbot.combat.micro.micro_base import MicroBase 6 | from tstarbot.data.queue.combat_command_queue import CombatCmdType 7 | from tstarbot.data.pool.macro_def import COMBAT_ANTI_AIR_UNITS 8 | 9 | 10 | class RavagerMgr(MicroBase): 11 | """ A zvz Zerg combat manager """ 12 | 13 | def __init__(self): 14 | super(RavagerMgr, self).__init__() 15 | self.corrosive_range = 15 16 | self.corrosive_harm_range = 3 17 | 18 | @staticmethod 19 | def corrosive_attack_pos(u, pos): 20 | action = sc_pb.Action() 21 | action.action_raw.unit_command.ability_id = \ 22 | ABILITY_ID.EFFECT_CORROSIVEBILE.value 23 | action.action_raw.unit_command.target_world_space_pos.x = pos['x'] 24 | action.action_raw.unit_command.target_world_space_pos.y = pos['y'] 25 | action.action_raw.unit_command.unit_tags.append(u.tag) 26 | return action 27 | 28 | def find_densest_enemy_pos_in_range(self, u): 29 | targets = self.find_units_wihtin_range(u, self.enemy_combat_units, 30 | r=self.corrosive_range) 31 | if len(targets) == 0: 32 | return None 33 | target_density = list() 34 | for e in targets: 35 | target_density.append(len( 36 | self.find_units_wihtin_range(e, targets, r=self.corrosive_harm_range))) 37 | target_id = np.argmax(target_density) 38 | target = targets[target_id] 39 | target_pos = {'x': target.float_attr.pos_x, 40 | 'y': target.float_attr.pos_y} 41 | return target_pos 42 | 43 | def act(self, u, pos, mode): 44 | if u.float_attr.weapon_cooldown > 0: 45 | target_pos = self.find_densest_enemy_pos_in_range(u) 46 | if target_pos is 
class CorruptorMgr(MicroBase):
  """ Micro controller deciding the action of a single corruptor """

  def __init__(self):
    super(CorruptorMgr, self).__init__()
    # the nearest flying enemy is only attacked when not farther than this
    self.corruptor_range = 20

  @staticmethod
  def parasitic_bomb_attack_target(u, target):
    """ raw action: unit 'u' casts EFFECT_PARASITICBOMB on unit 'target' """
    action = sc_pb.Action()
    action.action_raw.unit_command.ability_id = \
      ABILITY_ID.EFFECT_PARASITICBOMB.value
    action.action_raw.unit_command.target_unit_tag = target.tag
    action.action_raw.unit_command.unit_tags.append(u.tag)
    return action

  def act(self, u, pos, mode):
    """ choose the action for corruptor 'u'.

    - no enemy flying combat unit: move to the center of own combat units
    - nearest flying enemy farther than 'corruptor_range': move onto the own
      ground unit picked by find_closest_units_in_battle (hold fire when
      there is no own ground unit)
    - otherwise: attack towards 'pos'

    'mode' is accepted for interface compatibility and is not used.
    """
    air_targets = [e for e in self.enemy_combat_units
                   if e.int_attr.unit_type in COMBAT_FLYING_UNITS]
    if len(air_targets) == 0:
      move_pos = self.get_center_of_units(self.self_combat_units)
      return self.move_pos(u, move_pos)

    closest_target = self.find_closest_enemy(u, air_targets)
    out_of_range = (self.dist_between_units(u, closest_target) >
                    self.corruptor_range)
    if not out_of_range:
      # attack
      return self.attack_pos(u, pos)

    # follow the ground unit
    # The loop variable must not be called 'u': under Python 2 a list
    # comprehension leaks its variable and would rebind the corruptor 'u'
    # that is used by the commands below.
    self_ground_units = [g for g in self.self_combat_units
                         if g.int_attr.unit_type not in COMBAT_FLYING_UNITS]
    if len(self_ground_units) == 0:
      print('no ground units')
      return self.hold_fire(u)
    self_most_dangerous_ground_unit = self.find_closest_units_in_battle(
      self_ground_units, closest_target)
    move_pos = {'x': self_most_dangerous_ground_unit.float_attr.pos_x,
                'y': self_most_dangerous_ground_unit.float_attr.pos_y}
    return self.move_pos(u, move_pos)
60 | 61 | First, run: 62 | ``` 63 | python -m pysc2.bin.play_vs_agent \ 64 | --human \ 65 | --map AbyssalReef \ 66 | --user_race zerg 67 | ``` 68 | to host a game. 69 | 70 | Then run the following command in another process (terminal) 71 | ``` 72 | python -m pysc2.bin.play_vs_agent \ 73 | --agent tstarbot.agents.zerg_agent.ZergAgent \ 74 | --agent_race zerg 75 | ``` 76 | to let the AI join the game. 77 | 78 | -------------------------------------------------------------------------------- /tstarbot/util/unit.py: -------------------------------------------------------------------------------- 1 | """ units collecting/finding utilities """ 2 | import numpy as np 3 | from tstarbot.util.geom import dist_to_pos 4 | 5 | from .geom import dist 6 | 7 | 8 | def collect_units_by_type_alliance(units, unit_type, alliance=1): 9 | """ return units with the specified unit type and alliance """ 10 | return [u for u in units 11 | if u.unit_type == unit_type and u.int_attr.alliance == alliance] 12 | 13 | 14 | def collect_units_by_tags(units, tags): 15 | uu = [] 16 | for tag in tags: 17 | u = find_by_tag(units, tag) 18 | if u: 19 | uu.append(u) 20 | return uu 21 | 22 | 23 | def find_by_tag(units, tag): 24 | for u in units: 25 | if u.tag == tag: 26 | return u 27 | return None 28 | 29 | 30 | def find_nearest_l1(units, unit): 31 | """ find the nearest one (in l1-norm) to 'unit' within the list 'units' """ 32 | if not units: 33 | return None 34 | x, y = unit.float_attr.pos_x, unit.float_attr.pos_y 35 | dd = np.asarray([ 36 | abs(u.float_attr.pos_x - x) + abs(u.float_attr.pos_y - y) for 37 | u in units]) 38 | return units[dd.argmin()] 39 | 40 | 41 | def find_nearest_to_pos(units, pos): 42 | """ find the nearest one to pos within the list 'units' """ 43 | if not units: 44 | return None 45 | dd = np.asarray([dist_to_pos(u, pos) for u in units]) 46 | return units[dd.argmin()] 47 | 48 | 49 | def find_nearest(units, unit): 50 | """ find the nearest one (in l2-norm) to 'unit' within the list 
class BuildingPool(PoolBase):
  """Pool of the player's own buildings, keyed by unit tag."""

  def __init__(self):
    # Name this class (not PoolBase) in super(): super(PoolBase, self)
    # starts the lookup AFTER PoolBase and so skips its initializer.
    super(BuildingPool, self).__init__()
    self._buildings = {}  # unit_tag -> Building

  def update(self, timestep):
    """Rebuild the pool from the units in `timestep.observation['units']`."""
    units = timestep.observation['units']
    self._update_building(units)

  def exist_any(self, unit_type):
    """Return True if at least one pooled building has type `unit_type`."""
    return any(b.unit().int_attr.unit_type == unit_type
               for b in self._buildings.values())

  def list_buildings(self, unit_type):
    """Return the raw units of all pooled buildings of type `unit_type`."""
    return [b.unit() for b in self._buildings.values()
            if b.unit().int_attr.unit_type == unit_type]

  def _update_building(self, units):
    """Sync the pool with `units`: refresh seen buildings, drop the rest."""
    # mark every known building as lost ...
    for b in self._buildings.values():
      b.set_lost(True)

    # ... then update / insert the own buildings that are still observed
    # (a fresh Building starts with lost == False)
    for u in units:
      if u.int_attr.unit_type in tm.BUILDING_UNITS \
          and u.int_attr.alliance == tm.AllianceType.SELF.value:
        self._buildings[u.int_attr.tag] = Building(u)

    # delete the buildings that were not observed this time; collect the
    # keys first because a dict must not change size while iterated
    del_keys = [k for k, b in self._buildings.items() if b.is_lost()]
    for k in del_keys:
      del self._buildings[k]
class LurkerMgr(MicroBase):
  """ Micro controller deciding the action of a single lurker """

  def __init__(self):
    super(LurkerMgr, self).__init__()
    self.lurker_range = 9

  @staticmethod
  def burrow_down(u):
    """ raw action: unit 'u' uses BURROWDOWN_LURKER """
    action = sc_pb.Action()
    cmd = action.action_raw.unit_command
    cmd.ability_id = ABILITY_ID.BURROWDOWN_LURKER.value
    cmd.unit_tags.append(u.tag)
    return action

  @staticmethod
  def burrow_up(u):
    """ raw action: unit 'u' uses BURROWUP_LURKER """
    action = sc_pb.Action()
    cmd = action.action_raw.unit_command
    cmd.ability_id = ABILITY_ID.BURROWUP_LURKER.value
    cmd.unit_tags.append(u.tag)
    return action

  def act(self, u, pos, mode):
    """ choose the action for lurker 'u' ('mode' is not used).

    An enemy counts as "in range" when the closest enemy unit whose type is
    not in AIR_UNITS is nearer than 'lurker_range'.
    - unburrowed: burrow down when an enemy is in range, else move to 'pos'
    - burrowed: attack towards 'pos' when an enemy is in range, else burrow up
    Any other unit type raises NotImplementedError.
    """
    attackable = [e for e in self.enemy_units
                  if e.int_attr.unit_type not in AIR_UNITS]

    if u.int_attr.unit_type == UNIT_TYPEID.ZERG_LURKERMP.value:
      burrowed = False
    elif u.int_attr.unit_type == UNIT_TYPEID.ZERG_LURKERMPBURROWED.value:
      burrowed = True
    else:
      print("Unrecognized lurker type: %s" % str(u.int_attr.unit_type))
      raise NotImplementedError

    enemy_in_range = False
    if len(attackable) > 0:
      nearest = self.find_closest_enemy(u, attackable)
      enemy_in_range = self.dist_between_units(u, nearest) < self.lurker_range

    if burrowed:
      if enemy_in_range:
        return self.attack_pos(u, pos)
      return self.burrow_up(u)
    if enemy_in_range:
      return self.burrow_down(u)
    return self.move_pos(u, pos)
class MutaliskMgr(MicroBase):
  """ Micro controller deciding the action of a single mutalisk """

  def __init__(self):
    super(MutaliskMgr, self).__init__()
    self.safe_harass_dist = 11
    # fixed map coordinates the unit retreats to
    retreat_xy = [(20, 4), (20, 145), (180, 145), (180, 4),
                  (20, 75), (90, 145), (180, 75), (110, 4)]
    self.safe_positions = [{'x': x, 'y': y} for x, y in retreat_xy]

  def act(self, u, pos, mode):
    """ choose the action for mutalisk 'u'.

    MOVE: move to 'pos'.
    ATTACK: retreat to the nearest entry of 'safe_positions' when an enemy
    anti-air unit is closer than 'safe_harass_dist' or health is below half
    of its maximum; otherwise attack towards 'pos'.
    Any other mode returns an empty list.
    """
    if mode == CombatCmdType.MOVE:
      return self.move_pos(u, pos)

    if mode == CombatCmdType.ATTACK:
      spore_type = UNIT_TYPEID.ZERG_SPORECRAWLER.value
      threats = []
      for e in self.enemy_units:
        e_type = e.int_attr.unit_type
        if e_type not in COMBAT_ANTI_AIR_UNITS:
          continue
        # a spore crawler only counts once its build progress reached 1
        if e_type == spore_type and not e.float_attr.build_progress == 1:
          continue
        threats.append(e)

      # keep a larger distance as soon as a ravager is among the threats
      ravager_type = UNIT_TYPEID.ZERG_RAVAGER.value
      has_ravager = any(e.int_attr.unit_type == ravager_type for e in threats)
      self.safe_harass_dist = 20 if has_ravager else 11

      nearest_threat_dist = 100000
      if len(threats) > 0:
        nearest_threat = self.find_closest_enemy(u, threats)
        nearest_threat_dist = self.dist_between_units(u, nearest_threat)

      if nearest_threat_dist < self.safe_harass_dist or \
          u.float_attr.health / u.float_attr.health_max < 0.5:
        here = {'x': u.float_attr.pos_x, 'y': u.float_attr.pos_y}
        dists = [self.dist_between_coordinates(here, safe_pos)
                 for safe_pos in self.safe_positions]
        return self.move_pos(u, self.safe_positions[np.argmin(dists)])
      return self.attack_pos(u, pos)

    return []
def act_build_by_self(builder_tag, ability_id): 8 | action = sc_pb.Action() 9 | action.action_raw.unit_command.ability_id = ability_id 10 | action.action_raw.unit_command.unit_tags.append(builder_tag) 11 | return action 12 | 13 | 14 | def act_build_by_tag(builder_tag, target_tag, ability_id): 15 | action = sc_pb.Action() 16 | action.action_raw.unit_command.ability_id = ability_id 17 | action.action_raw.unit_command.target_unit_tag = target_tag 18 | action.action_raw.unit_command.unit_tags.append(builder_tag) 19 | return action 20 | 21 | 22 | def act_build_by_pos(builder_tag, target_pos, ability_id): 23 | action = sc_pb.Action() 24 | action.action_raw.unit_command.ability_id = ability_id 25 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0] 26 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1] 27 | action.action_raw.unit_command.unit_tags.append(builder_tag) 28 | return action 29 | 30 | 31 | def act_move_to_pos(unit_tag, target_pos): 32 | action = sc_pb.Action() 33 | action.action_raw.unit_command.ability_id = ABILITY_ID.MOVE.value 34 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0] 35 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1] 36 | action.action_raw.unit_command.unit_tags.append(unit_tag) 37 | return action 38 | 39 | 40 | def act_worker_harvests_on_target(target_tag, worker_tag): 41 | # The CALLER should assure the target-worker is a reasonable pair 42 | # e.g., the target is an extractor, a mineral, 43 | # the worker should not be too far away to the target 44 | action = sc_pb.Action() 45 | action.action_raw.unit_command.ability_id = \ 46 | ABILITY_ID.HARVEST_GATHER_DRONE.value 47 | action.action_raw.unit_command.target_unit_tag = target_tag 48 | action.action_raw.unit_command.unit_tags.append(worker_tag) 49 | return action 50 | 51 | 52 | def act_rally_worker(target_tag, base_tag): 53 | action = sc_pb.Action() 54 | action.action_raw.unit_command.ability_id = \ 55 
| ABILITY_ID.RALLY_HATCHERY_WORKERS.value 56 | action.action_raw.unit_command.target_unit_tag = target_tag 57 | action.action_raw.unit_command.unit_tags.append(base_tag) 58 | return action 59 | 60 | 61 | def act_stop(unit_tag): 62 | action = sc_pb.Action() 63 | action.action_raw.unit_command.ability_id = ABILITY_ID.STOP.value 64 | action.action_raw.unit_command.unit_tags.append(unit_tag) 65 | return action 66 | -------------------------------------------------------------------------------- /tstarbot/agents/zerg_agent.py: -------------------------------------------------------------------------------- 1 | """ Scripted Zerg agent.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import importlib 6 | from time import sleep 7 | 8 | from pysc2.agents import base_agent 9 | 10 | from tstarbot.combat_strategy.combat_strategy_mgr import ZergStrategyMgr 11 | from tstarbot.production_strategy.production_mgr import ZergProductionMgr 12 | from tstarbot.building.building_mgr import ZergBuildingMgr 13 | from tstarbot.resource.resource_mgr import ZergResourceMgr 14 | from tstarbot.combat.combat_mgr import ZergCombatMgr 15 | from tstarbot.scout.scout_mgr import ZergScoutMgr 16 | from tstarbot.data.data_context import DataContext 17 | from tstarbot.act.act_mgr import ActMgr 18 | 19 | 20 | DFT_CONFIG_PATH = 'tstarbot.agents.dft_config' 21 | 22 | 23 | class ZergAgent(base_agent.BaseAgent): 24 | """A ZvZ Zerg agent for full game map.""" 25 | 26 | def __init__(self, **kwargs): 27 | super(ZergAgent, self).__init__() 28 | self._sleep_per_step = None 29 | 30 | config_path = DFT_CONFIG_PATH 31 | if kwargs.get('config_path'): # use the config file 32 | config_path = kwargs['config_path'] 33 | config = importlib.import_module(config_path) 34 | self._init_config(config) 35 | 36 | self.dc = DataContext(config) 37 | self.am = ActMgr() 38 | 39 | self.strategy_mgr = ZergStrategyMgr(self.dc) 40 | self.production_mgr = 
class Squad(object):
  """Group of own combat units that is handled as one entity."""

  def __init__(self, units, uniform=None):
    self._units = units  # combat_unit
    self._status = SquadStatus.IDLE
    self.combat_status = MutaliskSquadStatus.IDLE
    self._uniform = uniform
    # every member must belong to the player
    assert all(
        member.unit.int_attr.alliance == tm.AllianceType.SELF.value
        for member in units)

  def __repr__(self):
    counts = (self.num_units, self.num_roach_units, self.num_zergling_units)
    return 'Squad(Units(%d), Roaches(%d), Zerglings(%d))' % counts

  def update(self, combat_unit_pool):
    """Re-fetch the members by tag from `combat_unit_pool`.

    Members the pool no longer returns, or whose status is SCOUT, are
    dropped from the squad.
    """
    member_tags = [member.tag for member in self._units]
    self._units = list()
    for tag in member_tags:
      fresh = combat_unit_pool.get_by_tag(tag)
      if fresh is None:
        continue
      if fresh.status != CombatUnitStatus.SCOUT:
        self._units.append(fresh)

  def _units_of_type(self, type_id):
    """Members whose `type` equals `type_id`, as a new list."""
    return [member for member in self._units if member.type == type_id]

  @property
  def units(self):
    return self._units

  @property
  def uniform(self):
    return self._uniform

  @property
  def status(self):
    return self._status

  @status.setter
  def status(self, status):
    self._status = status

  @property
  def hydralisk_units(self):
    return self._units_of_type(UNIT_TYPEID.ZERG_HYDRALISK.value)

  @property
  def roach_units(self):
    return self._units_of_type(UNIT_TYPEID.ZERG_ROACH.value)

  @property
  def zergling_units(self):
    return self._units_of_type(UNIT_TYPEID.ZERG_ZERGLING.value)

  @property
  def num_units(self):
    return len(self._units)

  @property
  def num_hydralisk_units(self):
    return len(self.hydralisk_units)

  @property
  def num_roach_units(self):
    return len(self.roach_units)

  @property
  def num_zergling_units(self):
    return len(self.zergling_units)

  @property
  def centroid(self):
    """Mean member position as {'x', 'y'}; (0, 0) for an empty squad."""
    count = len(self._units)
    if count == 0:
      return {'x': 0, 'y': 0}
    mean_x = sum(member.position['x'] for member in self._units) / count
    mean_y = sum(member.position['y'] for member in self._units) / count
    return {'x': mean_x, 'y': mean_y}
class QueenMgr(MicroBase):
  """ Micro controller deciding the action of a single queen """

  def __init__(self):
    super(QueenMgr, self).__init__()
    self.cure_range = 15

  @staticmethod
  def cure_target(u, target):
    """ raw action: unit 'u' casts EFFECT_TRANSFUSION on unit 'target' """
    action = sc_pb.Action()
    cmd = action.action_raw.unit_command
    cmd.ability_id = ABILITY_ID.EFFECT_TRANSFUSION.value
    cmd.target_unit_tag = target.tag
    cmd.unit_tags.append(u.tag)
    return action

  def find_weakest_unit(self, u, units, dist):
    """ return the unit of 'units' with the lowest positive health that is
    closer than 'dist' to 'u' (None when there is no such unit; units with
    health of 10000 or more are never selected) """
    weakest = None
    lowest_hp = 10000
    for candidate in units:
      if not self.dist_between_units_with_radius(u, candidate) < dist:
        continue
      if 0 < candidate.float_attr.health < lowest_hp:
        weakest = candidate
        lowest_hp = candidate.float_attr.health
    return weakest

  def is_queen_run_away(self, u, closest_enemy):
    """ True when 'closest_enemy' is nearer than 'roach_attack_range' """
    enemy_dist = self.dist_between_units(u, closest_enemy)
    return True if enemy_dist < self.roach_attack_range else False

  def hit_and_run(self, u, pos):
    """ attack towards 'pos', but run away while the closest enemy combat
    unit is too near (see is_queen_run_away) """
    default_action = self.attack_pos(u, pos)
    if len(self.enemy_combat_units) == 0:
      return default_action
    nearest = self.find_closest_enemy(u, self.enemy_combat_units)
    if self.is_queen_run_away(u, nearest):
      return self.run_away_from_closest_enemy(u, nearest)
    return self.attack_pos(u, pos)

  def act(self, u, pos, mode):
    """ choose the action for queen 'u' ('mode' is not used).

    With more than 50 energy the queen cures the weakest own spine crawler
    within 'cure_range' that misses at least 125 health, else the weakest
    own roach under the same condition; otherwise it hits and runs.
    """
    fallback = self.hit_and_run(u, pos)
    if not u.float_attr.energy > 50:
      return fallback

    # spine crawlers are checked before roaches
    heal_order = (UNIT_TYPEID.ZERG_SPINECRAWLER.value,
                  UNIT_TYPEID.ZERG_ROACH.value)
    groups = [[a for a in self.self_combat_units
               if a.int_attr.unit_type == wanted_type]
              for wanted_type in heal_order]
    patients = [self.find_weakest_unit(u, group, self.cure_range)
                for group in groups]

    for patient in patients:
      if patient is None:
        continue
      missing_hp = patient.float_attr.health_max - patient.float_attr.health
      if missing_hp >= 125:
        return self.cure_target(u, patient)
    return fallback
class InfestorMgr(MicroBase):
    """Micro controller for infestors built around Fungal Growth."""

    def __init__(self):
        super(InfestorMgr, self).__init__()
        # Distance within which enemies are considered as cast targets.
        self.infestor_range = 20
        # Radius used to count how many targets surround a candidate.
        self.infestor_harm_range = 4

    @staticmethod
    def fungal_growth_attack_pos(u, pos):
        """Return the raw action casting Fungal Growth from `u` at `pos`."""
        action = sc_pb.Action()
        command = action.action_raw.unit_command
        command.ability_id = ABILITY_ID.EFFECT_FUNGALGROWTH.value
        command.target_world_space_pos.x = pos['x']
        command.target_world_space_pos.y = pos['y']
        command.unit_tags.append(u.tag)
        return action

    def find_densest_enemy_pos_in_range(self, u):
        """Return the position of the most crowded enemy ground unit in range.

        Flying units, spine crawlers and spore crawlers are ignored.  Returns
        a dict with 'x' and 'y' keys, or None when no target is in range.
        """
        crawler_types = [UNIT_TYPEID.ZERG_SPINECRAWLER.value,
                         UNIT_TYPEID.ZERG_SPORECRAWLER.value]
        enemy_ground_units = [
            e for e in self.enemy_combat_units
            if not (e.int_attr.unit_type in COMBAT_FLYING_UNITS or
                    e.int_attr.unit_type in crawler_types)
        ]
        targets = self.find_units_wihtin_range(u, enemy_ground_units,
                                               r=self.infestor_range)
        if len(targets) == 0:
            return None

        def crowd_size(enemy):
            return len(self.find_units_wihtin_range(
                enemy, targets, r=self.infestor_harm_range))

        # max() keeps the first of several equally crowded targets.
        target = max(targets, key=crowd_size)
        return {'x': target.float_attr.pos_x, 'y': target.float_attr.pos_y}

    def act(self, u, pos, mode):
        """Pick the action for infestor `u`; `mode` is accepted but not used."""
        ground_targets = [e for e in self.enemy_combat_units
                          if e.int_attr.unit_type not in COMBAT_FLYING_UNITS]
        if len(ground_targets) == 0:
            return self.move_pos(u, pos)

        closest_target = self.find_closest_enemy(u, ground_targets)
        if not self.dist_between_units(u, closest_target) < self.infestor_range:
            return self.move_pos(u, pos)

        if u.float_attr.energy > 75:
            target_pos = self.find_densest_enemy_pos_in_range(u)
            if target_pos is None:
                return self.attack_pos(u, pos)
            return self.fungal_growth_attack_pos(u, target_pos)

        # Not enough energy: fall back to the closest own base, or hold
        # position when no base is known.
        bases = self.dc.dd.base_pool.bases
        base_units = [bases[tag].unit for tag in bases]
        if len(base_units) == 0:
            return self.hold_fire(u)
        closest_base = self.find_closest_enemy(u, base_units)
        base_pos = {'x': closest_base.float_attr.pos_x,
                    'y': closest_base.float_attr.pos_y}
        return self.move_pos(u, base_pos)
CombatUnitStatus = Enum('CombatUnitStatus', ('IDLE', 'COMBAT', 'SCOUT'))


class CombatUnit(object):
    """A combat unit wrapped with an assignment status and a 'lost' flag."""

    def __init__(self, unit):
        self._unit = unit
        self._status = CombatUnitStatus.IDLE
        self._lost = False

    @property
    def unit(self):
        """The most recently stored raw unit."""
        return self._unit

    @property
    def tag(self):
        return self._unit.tag

    @property
    def type(self):
        return self._unit.unit_type

    @property
    def status(self):
        return self._status

    @property
    def position(self):
        """Unit position as a dict with 'x' and 'y' keys."""
        return {'x': self._unit.float_attr.pos_x,
                'y': self._unit.float_attr.pos_y}

    def set_lost(self, lost):
        self._lost = lost

    def is_lost(self):
        return self._lost

    def set_status(self, status):
        self._status = status

    def update(self, u):
        """Replace the stored unit with `u` when both carry the same tag."""
        if u.int_attr.tag == self._unit.int_attr.tag:  # is the same unit
            self._unit = u


class CombatUnitPool(PoolBase):
    """Pool of the player's own combat units, keyed by unit tag."""

    def __init__(self):
        # Name this class, not PoolBase: super(PoolBase, self) starts the
        # MRO lookup *after* PoolBase and so skips PoolBase.__init__.
        super(CombatUnitPool, self).__init__()
        self._units = dict()

    def update(self, timestep):
        """Synchronise the pool with `timestep.observation['units']`.

        Units present in the observation are added or refreshed; pooled
        units that no longer appear are removed.
        """
        units = timestep.observation['units']

        # Assume every pooled unit is gone until it is seen again below.
        for combat_unit in self._units.values():
            combat_unit.set_lost(True)

        for u in units:
            if self._is_combat_unit(u):
                self._add_unit(u)

        # Collect first, delete afterwards: a dict must not change size
        # while it is being iterated.
        lost_tags = [tag for tag, combat_unit in self._units.items()
                     if combat_unit.is_lost()]
        for tag in lost_tags:
            del self._units[tag]

    def get_by_tag(self, tag):
        """Return the pooled CombatUnit for `tag`, or None."""
        return self._units.get(tag, None)

    def employ_combat_unit(self, employ_status, unit_type):
        """Give the first idle unit of `unit_type` the status `employ_status`.

        Returns the raw unit that was employed, or None when no idle unit of
        that type exists.
        """
        for combat_unit in self.units:
            if (combat_unit.unit.int_attr.unit_type == unit_type and
                    combat_unit.status == CombatUnitStatus.IDLE):
                combat_unit.set_status(employ_status)
                return combat_unit.unit
        return None

    @property
    def num_units(self):
        return len(self._units)

    @property
    def units(self):
        return self._units.values()

    @staticmethod
    def _is_combat_unit(u):
        """Tell whether `u` is a combat-type unit owned by the player."""
        return (u.unit_type in tm.COMBAT_UNITS and
                u.int_attr.alliance == tm.AllianceType.SELF.value)

    def _add_unit(self, u):
        """Insert or refresh `u` and mark it as still alive."""
        # One key for lookup and storage.  The original tested membership
        # with u.int_attr.tag but indexed with u.tag, which only works when
        # both spellings hold the same value.
        tag = u.int_attr.tag
        combat_unit = self._units.get(tag)
        if combat_unit is None:
            combat_unit = CombatUnit(u)
            self._units[tag] = combat_unit
        else:
            combat_unit.update(u)
        combat_unit.set_lost(False)
tstarbot.agents.zerg_agent.ZergAgent \ 11 | --screen_resolution 64 \ 12 | --agent1_race Z \ 13 | --agent2 Bot \ 14 | --agent2_race Z \ 15 | --difficulty 8 16 | ``` 17 | 18 | One can pass in an agent config file (when supported by the agent): 19 | ``` 20 | python -m tstarbot.bin.eval_agent \ 21 | --max_agent_episodes 5 \ 22 | --map Simple64 \ 23 | --render \ 24 | --agent1 tstarbot.agents.zerg_agent.ZergAgent \ 25 | --agent1_config tstarbot.agents.dft_config \ 26 | --screen_resolution 64 \ 27 | --agent1_race Z \ 28 | --agent2 Bot \ 29 | --agent2_race Z \ 30 | --difficulty 5 \ 31 | --disable_fog 32 | ``` 33 | 34 | One can turn off rendering/visualization (on a headless server): 35 | ``` 36 | python -m tstarbot.bin.eval_agent \ 37 | --max_agent_episodes 5 \ 38 | --map AbyssalReef \ 39 | --difficulty 3 \ 40 | --norender \ 41 | --agent tstarbot.agents.zerg_agent.ZergAgent \ 42 | --screen_resolution 64 \ 43 | --agent_race Z \ 44 | --bot_race Z 45 | ``` 46 | 47 | ## Against Difficulty Level 10 (a.k.a Level A) Builtin Bot 48 | A well configured agent plays against Difficulty-A (cheat_insane) builtin bot: 49 | ``` 50 | python -m tstarbot.bin.eval_agent \ 51 | --max_agent_episodes 1 \ 52 | --step_mul 4 \ 53 | --map AbyssalReef \ 54 | --norender \ 55 | --agent1 tstarbot.agents.zerg_agent.ZergAgent \ 56 | --agent1_config tstarbot.agents.vs_builtin_ai_config \ 57 | --screen_resolution 64 \ 58 | --agent1_race Z \ 59 | --agent2 Bot \ 60 | --agent2_race Z \ 61 | --nodisable_fog \ 62 | --difficulty A 63 | ``` 64 | 65 | ## AI against AI 66 | Two well configured agents play against each other: 67 | ``` 68 | python -m tstarbot.bin.eval_agent \ 69 | --max_agent_episodes 1 \ 70 | --step_mul 4 \ 71 | --map AbyssalReef \ 72 | --norender \ 73 | --agent1 tstarbot.agents.zerg_agent.ZergAgent \ 74 | --agent1_config tstarbot.agents.dft_config \ 75 | --agent2 tstarbot.agents.zerg_agent.ZergAgent \ 76 | --agent2_config tstarbot.agents.dft_config \ 77 | --screen_resolution 64 \ 78 | --agent1_race Z 
\ 79 | --agent2_race Z \ 80 | --nodisable_fog 81 | ``` 82 | 83 | ## Mini Games 84 | One can also evaluate an agent on a mini game. Example: 85 | ``` 86 | python -m tstarbot.bin.eval_agent \ 87 | --max_agent_episodes 15 \ 88 | --map DefeatRoaches \ 89 | --norender \ 90 | --agent1 tstarbot.agents.micro_defeat_roaches_agent.MicroDefeatRoachesAgent \ 91 | --screen_resolution 64 \ 92 | --agent1_race T 93 | ``` 94 | which should achieve an almost 100% winning rate. 95 | 96 | As a comparison, the DeepMind baseline agent rarely wins: 97 | ``` 98 | python -m tstarbot.bin.eval_agent \ 99 | --max_agent_episodes 15 \ 100 | --map DefeatRoaches \ 101 | --norender \ 102 | --agent1 pysc2.agents.scripted_agent.DefeatRoaches \ 103 | --screen_resolution 64 \ 104 | --agent1_race T 105 | ``` -------------------------------------------------------------------------------- /tstarbot/data/data_context.py: -------------------------------------------------------------------------------- 1 | """data context""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from s2clientprotocol import sc2api_pb2 as sc_pb 8 | from pysc2.lib.data_raw import data_raw_3_16 9 | from pysc2.lib.data_raw import data_raw_4_0 10 | from pysc2.lib import TechTree 11 | 12 | from tstarbot.data.queue.build_command_queue import BuildCommandQueue 13 | from tstarbot.data.queue.build_command_queue import BuildCommandQueueV2 14 | from tstarbot.data.queue.combat_command_queue import CombatCommandQueue 15 | from tstarbot.data.queue.scout_command_queue import ScoutCommandQueue 16 | from tstarbot.data.pool.base_pool import BasePool 17 | from tstarbot.data.pool.building_pool import BuildingPool 18 | from tstarbot.data.pool.worker_pool import WorkerPool 19 | from tstarbot.data.pool.combat_pool import CombatUnitPool 20 | from tstarbot.data.pool.enemy_pool import EnemyPool 21 | from tstarbot.data.pool.scout_pool import ScoutPool 22 | from 
class StaticData(object):
    """Holds the latest timestep/observation and version-bound game data."""

    def __init__(self, config):
        self._obs = None
        self._timestep = None
        self._data_raw = data_raw_4_0

        # Use '3.16.1' unless the config names another game version.
        self.game_version = getattr(config, 'game_version', '3.16.1')
        self.TT = TechTree()
        self.TT.update_version(self.game_version)

    def update(self, timestep):
        """Remember `timestep` and its observation."""
        self._timestep = timestep
        self._obs = timestep.observation

    @property
    def obs(self):
        return self._obs

    @property
    def timestep(self):
        return self._timestep

    @property
    def data_raw(self):
        return self._data_raw


class DynamicData(object):
    """Owns the command queues and the unit pools shared by the managers."""

    def __init__(self, config):
        # command queues
        self.build_command_queue = BuildCommandQueueV2()
        self.combat_command_queue = CombatCommandQueue()
        self.scout_command_queue = ScoutCommandQueue()

        # pools; the last three before oppo_pool receive this object
        self.building_pool = BuildingPool()
        self.worker_pool = WorkerPool()
        self.combat_pool = CombatUnitPool()
        self.base_pool = BasePool(self)
        self.enemy_pool = EnemyPool(self)
        self.scout_pool = ScoutPool(self)
        self.oppo_pool = OppoPool()

    def update(self, timestep):
        """Feed `timestep` to the pools, in a fixed order."""
        # oppo_pool is not part of the per-step update.
        refreshed = (self.building_pool, self.worker_pool, self.combat_pool,
                     self.base_pool, self.enemy_pool, self.scout_pool)
        for pool in refreshed:
            pool.update(timestep)

    def reset(self):
        """Reset the base, scout, opponent and enemy pools, in that order."""
        for pool in (self.base_pool, self.scout_pool,
                     self.oppo_pool, self.enemy_pool):
            pool.reset()


class DataContext:
    """Facade bundling the dynamic (`dd`) and static (`sd`) data."""

    def __init__(self, config):
        self.config = config
        self._dynamic = DynamicData(config)
        self._static = StaticData(config)

    def update(self, timestep):
        """Update the dynamic data first, then the static data."""
        for part in (self._dynamic, self._static):
            part.update(timestep)

    def reset(self):
        """Reset the dynamic data only."""
        self._dynamic.reset()

    @property
    def dd(self):
        return self._dynamic

    @property
    def sd(self):
        return self._static
class RoachMgr(MicroBase):
    """Micro controller for roaches and burrowed roaches."""

    def __init__(self):
        super(RoachMgr, self).__init__()
        # A burrowed roach retreats when an enemy detector is closer than
        # this distance.
        self.overseer_sight = 15

    @staticmethod
    def burrow_down(u):
        """Return the raw action that burrows `u`."""
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = ABILITY_ID.BURROWDOWN.value
        action.action_raw.unit_command.unit_tags.append(u.tag)
        return action

    @staticmethod
    def burrow_up(u):
        """Return the raw action that unburrows `u`."""
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = ABILITY_ID.BURROWUP.value
        action.action_raw.unit_command.unit_tags.append(u.tag)
        return action

    def _has_upgrade(self, upgrade):
        """Tell whether `upgrade` is listed in the player's upgrade ids."""
        upgrade_ids = self.dc.sd.obs['raw_data'].player.upgrade_ids
        return upgrade.value in upgrade_ids

    def act(self, u, pos, mode):
        """Pick the action for roach `u`; `mode` is accepted but not used.

        Raises:
          NotImplementedError: `u` is neither a roach nor a burrowed roach.
        """
        unit_type = u.int_attr.unit_type
        if unit_type == UNIT_TYPEID.ZERG_ROACH.value:
            return self._act_unburrowed(u, pos)
        if unit_type == UNIT_TYPEID.ZERG_ROACHBURROWED.value:
            return self._act_burrowed(u)
        print("Unrecognized roach type: %s" % str(u.int_attr.unit_type))
        raise NotImplementedError

    def _act_unburrowed(self, u, pos):
        """Hit and burrow/run, then burrow to recover when idle."""
        closest_enemy_dist = 100000
        if len(self.enemy_combat_units) > 0:
            closest_enemy = self.find_closest_enemy(u, self.enemy_combat_units)
            closest_enemy_dist = self.dist_between_units(closest_enemy, u)

            if self.is_run_away(u, closest_enemy, self.self_combat_units):
                if self._has_upgrade(UPGRADE_ID.BURROW):
                    action = self.burrow_down(u)
                else:
                    action = self.run_away_from_closest_enemy(u, closest_enemy)
            else:
                action = self.attack_pos(u, pos)
        else:
            action = self.attack_pos(u, pos)

        # Burrow to recover when damaged and no enemy is close.
        # NOTE(review): nesting reconstructed from a whitespace-collapsed
        # source; placed at branch level because the distance test would be
        # vacuous inside the no-enemy branch -- confirm.
        if (u.float_attr.health / u.float_attr.health_max < 1 and
                closest_enemy_dist > 1.2 * self.roach_attack_range):
            if self._has_upgrade(UPGRADE_ID.BURROW):
                action = self.burrow_down(u)
        return action

    def _act_burrowed(self, u):
        """Unburrow at full health, otherwise hold; flee when threatened."""
        if u.float_attr.health / u.float_attr.health_max == 1:
            action = self.burrow_up(u)
        else:
            action = self.hold_fire(u)

        # NOTE(review): the two override sections below are assumed to sit
        # at branch level (nesting was lost in the collapsed source) --
        # confirm.
        if len(self.enemy_combat_units) > 0:
            closest_enemy = self.find_closest_enemy(u, self.enemy_combat_units)
            if (self.is_run_away(u, closest_enemy, self.self_combat_units) and
                    self._has_upgrade(UPGRADE_ID.TUNNELINGCLAWS)):
                action = self.run_away_from_closest_enemy(u, closest_enemy)

        # If a detector is near, run from the closest enemy combat unit.
        # The loop variable is deliberately not named `u`: on Python 2 a
        # list-comprehension variable leaks and would overwrite the roach.
        enemy_detect_units = [e for e in self.enemy_units
                              if e.int_attr.unit_type in UNITS_CAN_DETECT]
        if len(enemy_detect_units) > 0 and len(self.enemy_combat_units) > 0:
            closest_enemy_detect_unit = \
                self.find_closest_enemy(u, enemy_detect_units)
            closest_enemy_combat_unit = \
                self.find_closest_enemy(u, self.enemy_combat_units)
            if self.dist_between_units(closest_enemy_detect_unit,
                                       u) < self.overseer_sight:
                action = self.run_away_from_closest_enemy(
                    u, closest_enemy_combat_unit)
        return action
UNIT_TYPE_HATCHERY = 86
UNIT_TYPE_DRONE = 104


class DemoPool(ts.PoolBase):
    """Tracks drone tags and hatchery positions found in the observations."""

    def __init__(self):
        self._drone_ids = []
        self._hatcherys = []

    def update(self, obs):
        """Refresh the pool from `obs['units']`."""
        units = obs['units']
        self._locate_hatcherys(units)
        self._update_drone(units)

    def _locate_hatcherys(self, units):
        """Record the (x, y, z) position of every hatchery not yet known."""
        for u in units:
            if u.unit_type != UNIT_TYPE_HATCHERY:
                continue
            position = (u.float_attr.pos_x, u.float_attr.pos_y,
                        u.float_attr.pos_z)
            # update() runs every step; without this check the same
            # position was appended again on each call and the list grew
            # without bound.
            if position not in self._hatcherys:
                self._hatcherys.append(position)

    def _update_drone(self, units):
        """Replace the stored drone tags with those currently observed."""
        self._drone_ids = [u.tag for u in units
                           if u.unit_type == UNIT_TYPE_DRONE]

    def get_drones(self):
        return self._drone_ids

    def get_hatcherys(self):
        return self._hatcherys


class DemoManager(ts.ManagerBase):
    """Moves every drone to a random spot around the first known hatchery."""

    def __init__(self, pool):
        self._pool = pool
        # Bounds of the random per-axis offset added to the hatchery position.
        self._range_high = 5
        self._range_low = -5
        # Ability id written into every generated unit command.
        self._move_ability = 1

    def execute(self):
        """Return one move action per drone; empty when no hatchery is known."""
        drone_ids = self._pool.get_drones()
        pos = self._pool.get_hatcherys()
        print('pos=', pos)
        if not pos:
            # Nothing to circle around yet; indexing pos[0] would raise.
            return []
        return self.move_drone_random_round_hatchery(drone_ids, pos[0])

    def move_drone_random_round_hatchery(self, drone_ids, pos):
        """Build a move action towards a random point near `pos` per drone."""
        move_actions = []
        for drone in drone_ids:
            action = sc_pb.Action()
            command = action.action_raw.unit_command
            command.ability_id = self._move_ability
            x = pos[0] + random.randint(self._range_low, self._range_high)
            y = pos[1] + random.randint(self._range_low, self._range_high)
            command.target_world_space_pos.x = x
            command.target_world_space_pos.y = y
            command.unit_tags.append(drone)
            move_actions.append(action)
        return move_actions


class DemoBot:
    """A random agent for starcraft."""

    def __init__(self, env):
        self._pools = []
        self._managers = []
        self._env = env
        # Created by setup(), which must be called before run().
        self._executor = None

    def setup(self):
        """Create the demo pool, its manager and the action executor."""
        demo_pool = DemoPool()
        demo_manager = DemoManager(demo_pool)
        self._pools.append(demo_pool)
        self._managers.append(demo_manager)
        self._executor = ts.ActExecutor(self._env)

    def reset(self):
        """Reset the environment and return its initial timesteps."""
        timesteps = self._env.reset()
        return timesteps

    def run(self, n):
        return self._run_inner(n)

    def _run_inner(self, n):
        """Run one episode until it ends or the step budget is used up."""
        try:
            # episode loop
            step_num = 0
            timesteps = self.reset()
            while True:
                obs = timesteps[0].observation
                for pool in self._pools:
                    pool.update(obs)

                # Named to avoid shadowing the imported `actions` module.
                raw_actions = []
                for manager in self._managers:
                    raw_actions.extend(manager.execute())

                result = self._executor.exec_raw(raw_actions)
                if result[1]:
                    break
                timesteps = result[0]

                # NOTE(review): the counter is tested before it is
                # incremented, so up to n + 2 steps are executed -- confirm
                # this is the intended budget.
                if step_num > n:
                    break
                step_num += 1
        except KeyboardInterrupt:
            print("SC2Imp exception")
class BaseCombatMgr(object):
    """ Basic Combat Manager

    Common Utilities for combat are implemented here. """

    def __init__(self, dc):
        pass

    def reset(self):
        pass

    def update(self, dc, am):
        pass


class ZergCombatMgr(BaseCombatMgr):
    """ A zvz Zerg combat manager """

    def __init__(self, dc):
        super(ZergCombatMgr, self).__init__(dc)
        self.dc = dc
        self.micro_mgr = MicroMgr(dc)

    def reset(self):
        """Recreate the micro manager and drop the data-context reference."""
        self.micro_mgr = MicroMgr(self.dc)
        # update() stores a fresh data context before it is needed again.
        self.dc = None

    def update(self, dc, am):
        """Drain the combat command queue and push the resulting actions.

        Args:
          dc: data context; commands are pulled from
            `dc.dd.combat_command_queue` until `pull()` returns a falsy
            value.
          am: receiver of the actions through `push_actions(list)`.
        """
        super(ZergCombatMgr, self).update(dc, am)
        self.dc = dc

        actions = list()
        while True:
            cmd = dc.dd.combat_command_queue.pull()
            if not cmd:
                break
            actions.extend(self.exe_cmd(cmd.squad, cmd.position, cmd.type))
        am.push_actions(actions)

    def exe_cmd(self, squad, pos, mode):
        """Dispatch one command; unknown modes produce no actions."""
        if mode == CombatCmdType.ATTACK:
            return self.exe_attack(squad, pos)
        if mode == CombatCmdType.MOVE:
            return self.exe_move(squad, pos)
        if mode == CombatCmdType.DEFEND:
            return self.exe_defend(squad, pos)
        if mode == CombatCmdType.RALLY:
            return self.exe_rally(squad, pos)
        if mode == CombatCmdType.ROCK:
            return self.exe_rock(squad, pos)
        return []

    def _micro_squad(self, squad, pos, mode):
        """Run the micro manager once per squad member with `mode`."""
        return [self.exe_micro(combat_unit.unit, pos, mode=mode)
                for combat_unit in squad.units]

    def exe_attack(self, squad, pos):
        return self._micro_squad(squad, pos, CombatCmdType.ATTACK)

    def exe_defend(self, squad, pos):
        return self._micro_squad(squad, pos, CombatCmdType.DEFEND)

    def exe_move(self, squad, pos):
        """Move the squad to `pos`; burrowed lurkers are unburrowed instead."""
        actions = []
        for combat_unit in squad.units:
            u = combat_unit.unit
            if u.int_attr.unit_type == \
                    UNIT_TYPEID.ZERG_LURKERMPBURROWED.value:
                actions.append(LurkerMgr().burrow_up(u))
            else:
                actions.append(self.micro_mgr.move_pos(u, pos))
        return actions

    def exe_rally(self, squad, pos):
        """Attack-move every squad member to `pos`."""
        return [self.micro_mgr.attack_pos(combat_unit.unit, pos)
                for combat_unit in squad.units]

    def exe_rock(self, squad, pos):
        """Order the squad to attack the destructible rock located at `pos`.

        Returns an empty list when no matching rock lies within 0.1 of
        `pos`, instead of issuing attacks against a missing target.
        """
        rocks = [
            u for u in self.dc.sd.obs['units']
            if u.int_attr.unit_type ==
            UNIT_TYPEID.NEUTRAL_DESTRUCTIBLEROCKEX1DIAGONALHUGEBLUR.value
        ]
        target_rock = None
        for r in rocks:
            d = geom.dist_to_pos(r, (pos['x'], pos['y']))
            if d < 0.1:
                target_rock = r
                break
        if target_rock is None:
            return []
        # Pass the raw unit, as every other exe_* method does.
        return [self.micro_mgr.attack_target(combat_unit.unit, target_rock)
                for combat_unit in squad.units]

    def exe_micro(self, u, pos, mode):
        action = self.micro_mgr.exe(self.dc, u, pos, mode)
        return action
# Offsets of the four edge-adjacent neighbours, in the order they are explored.
_NEIGHBOR_OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))


def _neighbors4(x, y, nx, ny):
    """Yield the 4-connected neighbours of (x, y) inside an nx-by-ny grid."""
    for off_x, off_y in _NEIGHBOR_OFFSETS:
        x_next = x + off_x
        y_next = y + off_y
        if 0 <= x_next < nx and 0 <= y_next < ny:
            yield x_next, y_next


def bitmap2array(image):
    """Convert a bitmap into a writable uint8 array indexed as [x, y].

    Args:
      image: object exposing `data` (one byte per cell, row by row) and
        `size.x` / `size.y`.

    Returns:
      A copy of shape (size.x, size.y): the rows are reversed and the
      result is transposed.
    """
    array = np.frombuffer(image.data, dtype=np.uint8)
    array = np.reshape(array, (image.size.y, image.size.x))
    # np.frombuffer yields a read-only view; the copy makes it writable.
    array = copy.copy(array[::-1].transpose())
    return array


def save_image(image, figure_name):
    """Write the bitmap `image` to the file `figure_name` as an RGB picture."""
    array = np.frombuffer(image.data, dtype=np.uint8)
    array = np.reshape(array, (image.size.y, image.size.x))
    im = Image.fromarray(array, mode='L')
    im = im.convert('RGB')
    im.save(figure_name)


def compute_dist(x, y, array):
    """Breadth-first distance from cell (x, y) to every reachable cell.

    Cells whose value in `array` is 0 can be entered; moves are
    4-connected.

    Returns:
      An int16 array shaped like `array`: the number of steps from
      (x, y), or -1 for cells that cannot be reached.  The start cell is
      always 0.
    """
    # A plain deque is enough for a single-threaded BFS frontier;
    # queue.Queue adds locking that is not needed here.
    from collections import deque

    nx, ny = array.shape
    dist = -np.ones(array.shape, dtype=np.int16)
    dist[x, y] = 0
    frontier = deque([(x, y)])
    while frontier:
        x_now, y_now = frontier.popleft()
        for x_next, y_next in _neighbors4(x_now, y_now, nx, ny):
            if dist[x_next, y_next] == -1 and array[x_next, y_next] == 0:
                dist[x_next, y_next] = dist[x_now, y_now] + 1
                frontier.append((x_next, y_next))
    return dist


def compute_area_dist(areas, timestep, pos):
    """ distance from area.ideal_base_pos to pos

    Args:
      areas: iterable of hashable objects exposing `ideal_base_pos`.
      timestep: object exposing `game_info.start_raw.pathing_grid`.
      pos: (x, y) start position; both coordinates are truncated to int.

    Returns:
      dict mapping each area to the path distance computed by
      `compute_dist` (-1 when unreachable).
    """
    pathing_grid = timestep.game_info.start_raw.pathing_grid
    array = bitmap2array(pathing_grid)
    # Erase the base footprint so that a blocked base cell can be reached.
    for area in areas:
        base_x = int(area.ideal_base_pos[0])
        base_y = int(area.ideal_base_pos[1])
        if array[base_x, base_y] != 0:
            # Clip the 5x5 window at the lower border: negative indices
            # would wrap around to the opposite edge.  Slices clip the
            # upper border themselves, where per-cell indexing would raise
            # IndexError.
            array[max(base_x - 2, 0):base_x + 3,
                  max(base_y - 2, 0):base_y + 3] = 0
    # compute map distance from pos to every cell
    d = compute_dist(int(pos[0]), int(pos[1]), array)
    dist = {}
    for area in areas:
        pos_area = area.ideal_base_pos
        dist[area] = d[int(pos_area[0]), int(pos_area[1])]
    return dist


class Slope(object):
    """A connected group of slope cells together with their heights."""

    def __init__(self, mean_x, mean_y, size, min_height, max_height, pos, h):
        self.x = mean_x
        self.y = mean_y
        self.size = size
        self.min_h = min_height
        self.max_h = max_height
        self.pos = pos  # [(x1, y1), (x2, y2), ... , (xn, yn)]
        self.height = h  # [H1, H2, ... , Hn]


def get_slopes(timestep):
    """ get all the slopes in map

    A slope is a connected region of cells whose pathing and placement
    values are both 0 and whose terrain height is not constant.
    """
    pathing_grid = timestep.game_info.start_raw.pathing_grid
    placement_grid = timestep.game_info.start_raw.placement_grid
    terrain_height = timestep.game_info.start_raw.terrain_height
    pathing = bitmap2array(pathing_grid)
    placement = bitmap2array(placement_grid)
    height = bitmap2array(terrain_height)

    slopes = []
    for i in range(pathing_grid.size.x):
        for j in range(pathing_grid.size.y):
            if pathing[i, j] == 0 and placement[i, j] == 0:
                # extract_slope marks the visited cells in `pathing`, so
                # every region is extracted exactly once.
                slope_item = extract_slope(i, j, pathing, placement, height)
                if slope_item.min_h != slope_item.max_h:
                    slopes.append(slope_item)
    return slopes


def extract_slope(x, y, pathing, placement, height):
    """Flood-fill the slope region containing (x, y) and summarise it.

    Visited cells are overwritten with 255 in `pathing` (in place).

    Returns:
      A Slope with the mean cell position, the cell count, the minimum
      and maximum height, and the visited cells and heights in visiting
      order.
    """
    from collections import deque

    nx, ny = pathing.shape
    pathing[x, y] = 255
    pos = [(x, y)]
    heights = [height[x, y]]
    frontier = deque(pos)
    while frontier:
        x_now, y_now = frontier.popleft()
        for cell in _neighbors4(x_now, y_now, nx, ny):
            if pathing[cell] == 0 and placement[cell] == 0:
                pathing[cell] = 255
                frontier.append(cell)
                pos.append(cell)
                heights.append(height[cell])
    num = len(pos)
    sum_x = sum(cell_x for cell_x, _ in pos)
    sum_y = sum(cell_y for _, cell_y in pos)
    return Slope(sum_x / num, sum_y / num, num,
                 min(heights), max(heights), pos, heights)
from collections import deque


def unit_count(units, tech_tree, alliance=1):
    """Count units per unit type, including units still hatching in eggs.

    Args:
        units: iterable of unit observations (`unit_type`,
            `int_attr.alliance`, `orders`).
        tech_tree: object whose `m_unitTypeData` maps a unit type id to
            data carrying `isUnit` and `buildAbility`.
        alliance: only units with this `int_attr.alliance` are counted
            directly.

    Returns:
        dict mapping unit type id -> count.  Every id in `UNIT_TYPEID` is
        present (0 when unseen).
    """
    count = {unit_id.value: 0 for unit_id in UNIT_TYPEID}

    for u in units:
        if u.int_attr.alliance == alliance:
            count[u.unit_type] = count.get(u.unit_type, 0) + 1

    eggs = [u for u in units if u.unit_type == UNIT_TYPEID.ZERG_EGG.value]

    for unit_type, data in tech_tree.m_unitTypeData.items():
        if not data.isUnit:
            continue
        # An egg counts as the unit whose build ability is its first order.
        hatching = sum(
            1 for egg in eggs
            if len(egg.orders) > 0 and
            egg.orders[0].ability_id == data.buildAbility)
        # .get(): the tech tree may know ids that UNIT_TYPEID does not;
        # indexing directly raised KeyError for those.
        count[unit_type] = count.get(unit_type, 0) + hatching
    return count


def unique_unit_count(units, tech_tree, alliance=1):
    """Like `unit_count`, with alternate forms folded into the base unit.

    Burrowed / cocoon / egg / uprooted variants listed below are added to
    the count of their base unit type.  The variants' own entries are left
    untouched.
    """
    count = unit_count(units, tech_tree, alliance)
    T = UNIT_TYPEID
    unit_alias = (
        (T.ZERG_BANELING, (T.ZERG_BANELINGBURROWED, T.ZERG_BANELINGCOCOON)),
        (T.ZERG_BROODLORD, (T.ZERG_BROODLORDCOCOON,)),
        (T.ZERG_DRONE, (T.ZERG_DRONEBURROWED,)),
        (T.ZERG_HYDRALISK, (T.ZERG_HYDRALISKBURROWED,)),
        (T.ZERG_INFESTOR, (T.ZERG_INFESTORBURROWED,)),
        (T.ZERG_LURKERMP, (T.ZERG_LURKERMPBURROWED, T.ZERG_LURKERMPEGG)),
        (T.ZERG_OVERSEER, (T.ZERG_OVERLORDCOCOON,)),
        (T.ZERG_QUEEN, (T.ZERG_QUEENBURROWED,)),
        (T.ZERG_RAVAGER, (T.ZERG_RAVAGERCOCOON,)),
        (T.ZERG_ROACH, (T.ZERG_ROACHBURROWED,)),
        (T.ZERG_SPORECRAWLER, (T.ZERG_SPORECRAWLERUPROOTED,)),
        (T.ZERG_SWARMHOSTMP, (T.ZERG_SWARMHOSTBURROWEDMP,)),
        (T.ZERG_ZERGLING, (T.ZERG_ZERGLINGBURROWED,)),
    )
    for base, aliases in unit_alias:
        for alias in aliases:
            count[base.value] += count[alias.value]
    return count


class BuildOrderQueue(object):
    """FIFO of build items (unit / upgrade data looked up in a tech tree).

    Attributes:
        queue: the underlying `deque`; index 0 is the next item to build.
        TT: the tech tree used to resolve ids into build items.
    """

    def __init__(self, tech_tree):
        self.queue = deque()
        self.TT = tech_tree

    def _make_item(self, unit_id):
        """Resolve `unit_id` to its tech-tree data and tag it with the id.

        Raises:
            Exception: if `unit_id` is neither a UNIT_TYPEID nor an
                UPGRADE_ID member.
        """
        if isinstance(unit_id, UNIT_TYPEID):
            build_item = self.TT.getUnitData(unit_id.value)
        elif isinstance(unit_id, UPGRADE_ID):
            build_item = self.TT.getUpgradeData(unit_id.value)
        else:
            raise Exception('Unknown unit_id {}'.format(unit_id))
        build_item.unit_id = unit_id
        return build_item

    def set_build_order(self, unit_list):
        """Append every id of `unit_list`, in order, to the back."""
        for unit_id in unit_list:
            self.queue.append(self._make_item(unit_id))

    def size(self):
        """Return the number of queued items."""
        return len(self.queue)

    def is_empty(self):
        """Return True when nothing is queued."""
        return len(self.queue) == 0

    def current_item(self):
        """Return the front item without removing it, or None if empty."""
        return self.queue[0] if len(self.queue) > 0 else None

    def remove_current_item(self):
        """Drop the front item; a no-op on an empty queue."""
        if len(self.queue) > 0:
            self.queue.popleft()

    def queue_as_highest(self, unit_id):
        """Insert `unit_id` at the front (built next)."""
        self.queue.appendleft(self._make_item(unit_id))

    def queue_as_lowest(self, unit_id):
        """Insert `unit_id` at the back (built last)."""
        self.queue.append(self._make_item(unit_id))

    def queue(self, unit_id):
        """Insert `unit_id` at the back.

        NOTE: on instances this method is hidden by the `queue` deque
        attribute assigned in `__init__`, so `obj.queue(unit_id)` raises
        TypeError.  Use `queue_as_lowest` instead; this definition is kept
        only so `BuildOrderQueue.queue(obj, unit_id)` keeps working.
        """
        self.queue_as_lowest(unit_id)

    def clear_all(self):
        """Remove every queued item."""
        self.queue.clear()

    def reset(self):
        """Remove every queued item (same effect as `clear_all`)."""
        self.queue.clear()
# ---- tstarbot/scout/tasks/cruise_task.py ----
class ScoutCruiseTask(ScoutTask):
    """Scout task that loops over five way-points between home and a target."""

    def __init__(self, scout, home, target):
        super(ScoutCruiseTask, self).__init__(scout, home)
        self._target = target
        self._paths = []
        self._curr_pos = 0
        self._generate_path()
        self._status = md.ScoutTaskStatus.DOING
        self._attack_escape = None

    def type(self):
        return md.ScoutTaskType.CRUISE

    def _do_task_inner(self, view_enemys, dc):
        """Return the action for this step, or None once the scout is lost."""
        if self._check_scout_lost():
            self._status = md.ScoutTaskStatus.SCOUT_DESTROY
            return None

        # _check_attack() must only run when not already under attack: it
        # updates the stored health and may flip the status.
        if (self._status == md.ScoutTaskStatus.UNDER_ATTACK or
                self._check_attack()):
            return self._exec_under_attack(view_enemys)

        self._detect_enemy(view_enemys, dc)
        return self._exec_by_status()

    def post_process(self):
        """Release the target and the scout when the task is retired."""
        self._target.has_cruise = False
        self._scout.is_doing_task = False

    def _exec_by_status(self):
        if self._status == md.ScoutTaskStatus.DOING:
            return self._exec_cruise()
        if self._status == md.ScoutTaskStatus.DONE:
            return self._move_to_home()
        return self._noop()

    def _exec_under_attack(self, view_enemys):
        """Follow the escape path; mark the task DONE on its last way-point."""
        me = self._scout.unit()
        if self._attack_escape is None:
            # The escape path is generated once, on the first attacked step.
            self._attack_escape = ScoutAttackEscape()
            self._attack_escape.generate_path(view_enemys, me, self._home)

        me_pos = (me.float_attr.pos_x, me.float_attr.pos_y)
        if self._attack_escape.curr_arrived(me_pos):
            act = self._move_to_target(self._attack_escape.next_pos())
        else:
            act = self._move_to_target(self._attack_escape.curr_pos())

        if self._attack_escape.is_last_pos():
            self._status = md.ScoutTaskStatus.DONE
        return act

    def _exec_cruise(self):
        """Move towards the current way-point, advancing once it is reached."""
        if self._curr_pos < 0 or self._curr_pos >= len(self._paths):
            self._curr_pos = 0  # wrap around: the patrol is a loop
        pos = self._paths[self._curr_pos]
        if self._check_arrived_pos(pos):
            self._curr_pos += 1
        return self._move_to_target(pos)

    def _check_arrived_pos(self, pos):
        """Return True when the scout is within the arrival range of `pos`."""
        me = self._scout.unit()
        dist = md.calculate_distance(me.float_attr.pos_x,
                                     me.float_attr.pos_y,
                                     pos[0], pos[1])
        return dist < st.SCOUT_CRUISE_ARRIVAED_RANGE

    def _generate_path(self):
        """Fill `_paths` with the patrol loop between home and target.

        Way-points: 1/3 of the way, the midpoint shifted left by the cruise
        range, the midpoint, 2/3 of the way, the midpoint shifted right.
        """
        home_x, home_y = self._home[0], self._home[1]
        target_x, target_y = self._target.pos[0], self._target.pos[1]

        pos1 = ((home_x * 2) / 3 + target_x / 3,
                (home_y * 2) / 3 + target_y / 3)
        pos2 = ((home_x + target_x) / 2, (home_y + target_y) / 2)
        pos3 = (pos2[0] - st.SCOUT_CRUISE_RANGE, pos2[1])
        pos4 = (home_x / 3 + (target_x * 2) / 3,
                home_y / 3 + (target_y * 2) / 3)
        pos5 = (pos2[0] + st.SCOUT_CRUISE_RANGE, pos2[1])

        self._paths.extend([pos1, pos3, pos2, pos4, pos5])

    def _check_attack(self):
        """Switch to UNDER_ATTACK and return True if health just dropped."""
        if self._detect_attack():
            self._status = md.ScoutTaskStatus.UNDER_ATTACK
            return True
        return False

    def _detect_enemy(self, view_enemys, dc):
        """Raise a scout alarm for the first combat unit within cruise range."""
        spool = dc.dd.scout_pool
        armys = [enemy for enemy in view_enemys
                 if enemy.unit_type in md.COMBAT_UNITS]

        me = self._scout.unit()
        me_x, me_y = me.float_attr.pos_x, me.float_attr.pos_y

        scout_armys = []
        for unit in armys:
            dist = md.calculate_distance(me_x, me_y,
                                         unit.float_attr.pos_x,
                                         unit.float_attr.pos_y)
            if dist < st.SCOUT_CRUISE_RANGE:
                scout_armys.append(unit)
                break  # one sighting is enough to raise the alarm
        if len(scout_armys) > 0:
            alarm = sp.ScoutAlarm()
            # NOTE(review): attribute spelling kept as in the original;
            # confirm it matches what ScoutAlarm consumers read.
            alarm.enmey_armys = scout_armys
            if not spool.alarms.full():
                spool.alarms.put(alarm)


# ---- tstarbot/scout/oppo_monitor.py ----
OPENING_STAGE = 5000
BANELING_RUSH_THRESHOLD = 20
ROACH_RUSH_THRESHOLD = 4
ZERG_ROACH_RUSH_THRESHOLD = 15
ROACH_SUPPRESS_THRESHOLD = 10


class OppoMonitor(object):
    """Guesses the opponent's opening tactic from what the scouts have seen."""

    def __init__(self):
        # Each dict maps unit type -> number.
        self._scout_max = {}     # highest count seen so far
        self._scout_curr = {}    # count in the latest snapshot
        self._scout_change = {}  # increase of the maximum in the latest step
        self._step_count = 0

    def analysis(self, dc):
        """Update the opening-tactic guess stored in `dc.dd.oppo_pool`.

        While the game loop is below OPENING_STAGE the scout snapshots are
        analysed; once it reaches OPENING_STAGE an undecided tactic is
        recorded as UNKNOWN.  A tactic that is already set is never changed.
        """
        self._step_count = dc.sd.obs['game_loop'][0]
        spool = dc.dd.scout_pool
        scouts = spool.get_view_scouts()
        if dc.dd.oppo_pool.opening_tactics is not None:
            return
        if self._step_count < OPENING_STAGE:
            self.analysis_opening_tactis(scouts)
            self.judge_opening_tactis(dc)
        else:
            # ">=" rather than "==": if the observed game loop never equals
            # OPENING_STAGE exactly, the tactic would stay None forever.
            dc.dd.oppo_pool.opening_tactics = op.OppoOpeningTactics.UNKNOWN

    def judge_opening_tactis(self, dc):
        """Store the first matching tactic; leave it unset when none match."""
        if self.is_roach_rush():
            dc.dd.oppo_pool.opening_tactics = op.OppoOpeningTactics.ROACH_RUSH
        elif self.is_baneling_rush():
            dc.dd.oppo_pool.opening_tactics = \
                op.OppoOpeningTactics.BANELING_RUSH
        elif self.is_zerg_roach_rush():
            dc.dd.oppo_pool.opening_tactics = \
                op.OppoOpeningTactics.ZERGROACH_RUSH
        elif self.is_roach_suppress():
            dc.dd.oppo_pool.opening_tactics = \
                op.OppoOpeningTactics.ROACH_SUPPRESS

    def analysis_opening_tactis(self, scouts):
        """Recount the units in the scouts' snapshots and raise the maxima."""
        self._scout_curr = {}
        self._scout_change = {}
        for scout in scouts:
            for unit in scout.snapshot_armys:
                self._scout_curr[unit.unit_type] = \
                    self._scout_curr.get(unit.unit_type, 0) + 1

        for key, v_curr in self._scout_curr.items():
            if key not in self._scout_max:
                # First sighting of this unit type.
                self._scout_max[key] = v_curr
                self._scout_change[key] = v_curr
            elif v_curr > self._scout_max[key]:
                self._scout_change[key] = v_curr - self._scout_max[key]
                self._scout_max[key] = v_curr

    def _max_total(self, type_a, type_b):
        """Sum of the maxima of two unit types, or None if either is unseen."""
        if type_a not in self._scout_max or type_b not in self._scout_max:
            return None
        return self._scout_max[type_a] + self._scout_max[type_b]

    def is_baneling_rush(self):
        """True if banelings and zerglings were both seen, in enough numbers."""
        total = self._max_total(tp.UNIT_TYPEID.ZERG_BANELING.value,
                                tp.UNIT_TYPEID.ZERG_ZERGLING.value)
        return total is not None and total >= BANELING_RUSH_THRESHOLD

    def is_zerg_roach_rush(self):
        """True if roaches and zerglings were both seen, in enough numbers."""
        total = self._max_total(tp.UNIT_TYPEID.ZERG_ROACH.value,
                                tp.UNIT_TYPEID.ZERG_ZERGLING.value)
        return total is not None and total >= ZERG_ROACH_RUSH_THRESHOLD

    def is_roach_rush(self):
        """True if enough roaches were seen and at most two unit types overall."""
        roach_t = tp.UNIT_TYPEID.ZERG_ROACH.value
        if roach_t not in self._scout_max or len(self._scout_max) > 2:
            return False
        return self._scout_max[roach_t] >= ROACH_RUSH_THRESHOLD

    def is_roach_suppress(self):
        """True if roaches and ravagers were both seen, in enough numbers."""
        total = self._max_total(tp.UNIT_TYPEID.ZERG_ROACH.value,
                                tp.UNIT_TYPEID.ZERG_RAVAGER.value)
        return total is not None and total >= ROACH_SUPPRESS_THRESHOLD
# ---- tstarbot/combat/micro/micro_mgr.py ----
class MicroMgr(MicroBase):
    """ A zvz Zerg combat manager """

    def __init__(self, dc):
        super(MicroMgr, self).__init__()
        self.roach_mgr = RoachMgr()
        self.lurker_mgr = LurkerMgr()
        self.mutalisk_mgr = MutaliskMgr()
        self.ravager_mgr = RavagerMgr()
        self.viper_mgr = ViperMgr()
        self.corruptor_mgr = CorruptorMgr()
        self.infestor_mgr = InfestorMgr()
        self.queen_mgr = QueenMgr()

        self.default_micro_version = 1
        self.init_config(dc)

    def init_config(self, dc):
        """Read `default_micro_version` from `dc.config` when it is present."""
        if hasattr(dc, 'config'):
            if hasattr(dc.config, 'default_micro_version'):
                # Fixed: the value used to be stored under the misspelled
                # name `default_micro_verion`, so it never took effect.
                self.default_micro_version = int(
                    dc.config.default_micro_version)

    def _specialised_mgr(self, unit_type):
        """Return the per-unit-type micro manager, or None if there is none."""
        T = UNIT_TYPEID
        table = (
            ((T.ZERG_ROACH.value, T.ZERG_ROACHBURROWED.value),
             self.roach_mgr),
            ((T.ZERG_LURKERMP.value, T.ZERG_LURKERMPBURROWED.value),
             self.lurker_mgr),
            ((T.ZERG_MUTALISK.value,), self.mutalisk_mgr),
            ((T.ZERG_RAVAGER.value,), self.ravager_mgr),
            ((T.ZERG_VIPER.value,), self.viper_mgr),
            ((T.ZERG_CORRUPTOR.value,), self.corruptor_mgr),
            ((T.ZERG_INFESTOR.value,), self.infestor_mgr),
            ((T.ZERG_QUEEN.value,), self.queen_mgr),
        )
        for unit_types, mgr in table:
            if unit_type in unit_types:
                return mgr
        return None

    def exe(self, dc, u, pos, mode):
        """Return the micro action for unit `u` heading to `pos`.

        Unit types with a dedicated manager are delegated to it.
        Ultralisks always attack-move to `pos`.  Every other type uses the
        default behaviour chosen by `default_micro_version`.

        Raises:
            NotImplementedError: if `default_micro_version` is not 1 or 2.
        """
        unit_type = u.int_attr.unit_type
        mgr = self._specialised_mgr(unit_type)
        if mgr is not None:
            mgr.update(dc)
            return mgr.act(u, pos, mode)

        self.update(dc)
        if unit_type == UNIT_TYPEID.ZERG_ULTRALISK.value:
            return self.attack_pos(u, pos)
        if self.default_micro_version == 1:
            return self.default_act(u, pos, mode)
        if self.default_micro_version == 2:
            return self.default_act_v2(u, pos, mode)
        raise NotImplementedError

    def default_act(self, u, pos, mode):
        """Run from the closest enemy when `is_run_away` says so, else attack."""
        if len(self.enemy_combat_units) > 0:
            closest_enemy = self.find_closest_enemy(u, self.enemy_combat_units)
            if self.is_run_away(u, closest_enemy, self.self_combat_units):
                return self.run_away_from_closest_enemy(u, closest_enemy)
        return self.attack_pos(u, pos)

    def default_act_v2(self, u, pos, mode):
        """Default behaviour, version 2.

        Falls back to `default_act` for units without a known attack range
        or attack type.  When ready to attack, the unit targets the weakest
        enemy in its own range.  Otherwise it runs away if required, steps
        back from a shorter-ranged enemy that is inside its own range, or
        closes in on the weakest enemy within distance 10.
        """
        def POSX(unit):
            return unit.float_attr.pos_x

        def POSY(unit):
            return unit.float_attr.pos_y

        atk_range = self.get_atk_range(u.int_attr.unit_type)
        atk_type = self.get_atk_type(u.int_attr.unit_type)
        if not atk_range or not atk_type:
            return self.default_act(u, pos, mode)
        if len(self.enemy_combat_units) == 0:
            return self.attack_pos(u, pos)

        if self.ready_to_atk(u):
            weakest = self.find_weakest_nearby(u, self.enemy_combat_units,
                                               atk_range)
            if weakest:
                return self.attack_target(u, weakest)
            return self.attack_pos(u, pos)

        weakest = self.find_weakest_nearby(u, self.enemy_combat_units, 10)
        closest_enemy = self.find_closest_enemy(u, self.enemy_combat_units)
        if not weakest:
            return self.attack_pos(u, pos)
        enemy_range = self.get_atk_range(weakest.int_attr.unit_type)
        if self.is_run_away(u, closest_enemy, self.self_combat_units):
            return self.run_away_from_closest_enemy(u, closest_enemy)
        cur_dist = self.dist_between_units_with_radius(u, weakest)
        if enemy_range and atk_range >= enemy_range and cur_dist < atk_range:
            # We out-range the enemy and it is already inside our range:
            # back off while the weapon is not ready.
            return self.move_dir(u, (POSX(u) - POSX(weakest),
                                     POSY(u) - POSY(weakest)))
        return self.move_dir(u, (POSX(weakest) - POSX(u),
                                 POSY(weakest) - POSY(u)))


# ---- tstarbot/scout/tasks/scout_task.py ----
SCOUT_BASE_RANGE = 10
SCOUT_SAFE_RANGE = 12
SCOUT_VIEW_RANGE = 10
SCOUT_CRUISE_RANGE = 5
SCOUT_CRUISE_ARRIVAED_RANGE = 1
BUILD_PROGRESS_FINISH = 1.0

EXPLORE_V1 = 0
EXPLORE_V2 = 1
EXPLORE_V3 = 2


class ScoutTask(object):
    """Base class of scout tasks: holds the scout, its home and a status."""

    def __init__(self, scout, home):
        self._scout = scout
        self._home = home
        self._status = md.ScoutTaskStatus.INIT
        self._last_health = None

    def scout(self):
        return self._scout

    def type(self):
        raise NotImplementedError

    def status(self):
        return self._status

    def do_task(self, view_enemys, dc):
        """Run one step of the task; delegates to `_do_task_inner`."""
        return self._do_task_inner(view_enemys, dc)

    def post_process(self):
        raise NotImplementedError

    def _do_task_inner(self, view_enemys, dc):
        raise NotImplementedError

    def _move_to_target(self, pos):
        """Build a SMART command sending the scout to `pos` (x, y)."""
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = tp.ABILITY_ID.SMART.value
        action.action_raw.unit_command.target_world_space_pos.x = pos[0]
        action.action_raw.unit_command.target_world_space_pos.y = pos[1]
        action.action_raw.unit_command.unit_tags.append(
            self._scout.unit().tag)
        return action

    def _move_to_home(self):
        """Build a SMART command sending the scout back to its home."""
        return self._move_to_target(self._home)

    def _noop(self):
        """Build a command with the INVALID ability and no unit tags."""
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = \
            tp.ABILITY_ID.INVALID.value
        return action

    def _detect_attack(self):
        """Return True if health dropped since the previous call.

        The first call only records the health and returns False.
        """
        current_health = self._scout.unit().float_attr.health
        if self._last_health is None:
            self._last_health = current_health
            return False

        attack = self._last_health > current_health
        self._last_health = current_health
        return attack

    def _detect_recovery(self):
        """Return True when the scout is at full health."""
        unit = self._scout.unit()
        return unit.float_attr.health == unit.float_attr.health_max

    def _check_scout_lost(self):
        return self._scout.is_lost()

    def check_end(self, view_enemys, dc):
        """Return True if a queen or a base is seen near the target.

        Reads `self._target`, which this base class does not set; it is
        expected to be assigned by the subclass.
        """
        for enemy in view_enemys:
            is_queen = enemy.unit_type == md.UNIT_TYPEID.ZERG_QUEEN.value
            if not is_queen and enemy.unit_type not in md.BASE_UNITS:
                continue
            dist = md.calculate_distance(self._target.pos[0],
                                         self._target.pos[1],
                                         enemy.float_attr.pos_x,
                                         enemy.float_attr.pos_y)
            if dist < SCOUT_BASE_RANGE:
                return True
        return False


class ScoutAttackEscape(object):
    """Way-point path a scout follows to get away from attackers."""

    def __init__(self):
        self._paths = []
        self._curr = 0

    def curr_arrived(self, me_pos):
        """Return True when `me_pos` is within range of the current point."""
        curr_pos = self._paths[self._curr]
        dist = md.calculate_distance(curr_pos[0], curr_pos[1],
                                     me_pos[0], me_pos[1])
        return dist <= SCOUT_CRUISE_ARRIVAED_RANGE

    def curr_pos(self):
        return self._paths[self._curr]

    def next_pos(self):
        """Advance to the next way-point and return it.

        On the last way-point the index stays put, so an extra call returns
        the last point again instead of raising IndexError.
        """
        if self._curr < len(self._paths) - 1:
            self._curr += 1
        return self._paths[self._curr]

    def is_last_pos(self):
        return self._curr == len(self._paths) - 1

    def generate_path(self, view_enemys, me, home_pos):
        """Append an escape path leading away from nearby air units.

        The threat position is the centroid of the enemies whose type is in
        `md.COMBAT_AIR_UNITS` and that are closer than SCOUT_SAFE_RANGE.
        When there are none, a point offset by (+1, +1) from the scout is
        used instead.
        """
        me_x = me.float_attr.pos_x
        me_y = me.float_attr.pos_y
        airs = [enemy for enemy in view_enemys
                if enemy.unit_type in md.COMBAT_AIR_UNITS and
                md.calculate_distance(me_x, me_y,
                                      enemy.float_attr.pos_x,
                                      enemy.float_attr.pos_y)
                < SCOUT_SAFE_RANGE]

        if len(airs) > 0:
            # Start the sums at 0.0 so the mean is a float division.
            total_x = sum((unit.float_attr.pos_x for unit in airs), 0.0)
            total_y = sum((unit.float_attr.pos_y for unit in airs), 0.0)
            enemy_pos = (total_x / len(airs), total_y / len(airs))
        else:
            enemy_pos = (me_x + 1, me_y + 1)
        self._generate_path(home_pos, (me_x, me_y), enemy_pos)

    def _generate_path(self, home_pos, me_pos, enemy_pos):
        """Append three points: away from the enemy, level with home, home."""
        diff_x = me_pos[0] - enemy_pos[0]
        diff_y = me_pos[1] - enemy_pos[1]
        pos_1 = (me_pos[0] + diff_x, me_pos[1] + diff_y)
        pos_2 = (pos_1[0], home_pos[1])
        self._paths.extend([pos_1, pos_2, home_pos])
# ---- tstarbot/combat/micro/viper_micro.py ----
class ViperMgr(MicroBase):
    """ A zvz Zerg combat manager """

    def __init__(self):
        super(ViperMgr, self).__init__()
        self.viper_range = 20
        self.viper_harm_range = 3
        self.viper_consume_range = 5

    @staticmethod
    def _new_command(u, ability):
        """Start a unit command for `u` using the given ABILITY_ID member."""
        action = sc_pb.Action()
        command = action.action_raw.unit_command
        command.ability_id = ability.value
        command.unit_tags.append(u.tag)
        return action

    @staticmethod
    def blinding_cloud_attack_pos(u, pos):
        """Cast blinding cloud at `pos`, a dict with 'x' and 'y'."""
        action = ViperMgr._new_command(u, ABILITY_ID.EFFECT_BLINDINGCLOUD)
        target = action.action_raw.unit_command.target_world_space_pos
        target.x = pos['x']
        target.y = pos['y']
        return action

    @staticmethod
    def parasitic_bomb_attack_target(u, target):
        """Cast parasitic bomb on the unit `target`."""
        action = ViperMgr._new_command(u, ABILITY_ID.EFFECT_PARASITICBOMB)
        action.action_raw.unit_command.target_unit_tag = target.tag
        return action

    @staticmethod
    def consume_target(u, target):
        """Cast consume on the unit `target`."""
        action = ViperMgr._new_command(u, ABILITY_ID.EFFECT_VIPERCONSUME)
        action.action_raw.unit_command.target_unit_tag = target.tag
        return action

    def _densest_in_viper_range(self, u, candidates):
        """Pick the candidate in viper range with most candidates around it.

        Returns None when no candidate is within `viper_range` of `u`.
        """
        targets = self.find_units_wihtin_range(u, candidates,
                                               r=self.viper_range)
        if len(targets) == 0:
            return None
        target_density = [
            len(self.find_units_wihtin_range(e, targets,
                                             r=self.viper_harm_range))
            for e in targets]
        return targets[np.argmax(target_density)]

    def find_densest_enemy_pos_in_range(self, u):
        """Position of the densest enemy ground group in range, or None.

        Flying units and spine / spore crawlers are not considered.
        """
        crawlers = [UNIT_TYPEID.ZERG_SPINECRAWLER.value,
                    UNIT_TYPEID.ZERG_SPORECRAWLER.value]
        enemy_ground_units = [
            e for e in self.enemy_combat_units
            if e.int_attr.unit_type not in COMBAT_FLYING_UNITS and
            e.int_attr.unit_type not in crawlers]
        target = self._densest_in_viper_range(u, enemy_ground_units)
        if target is None:
            return None
        return {'x': target.float_attr.pos_x,
                'y': target.float_attr.pos_y}

    def find_densest_air_enemy_unit_in_range(self, u):
        """Enemy flying unit in range with most flyers around it, or None."""
        enemy_combat_flying_units = [
            e for e in self.enemy_combat_units
            if e.int_attr.unit_type in COMBAT_FLYING_UNITS]
        if len(enemy_combat_flying_units) == 0:
            return None
        return self._densest_in_viper_range(u, enemy_combat_flying_units)

    def _act_with_energy(self, u):
        """Follow the most exposed friendly ground unit and cast when in range."""
        closest_enemy = self.find_closest_enemy(u, self.enemy_combat_units)
        # follow the ground unit
        self_ground_units = [
            a for a in self.self_combat_units
            if a.int_attr.unit_type not in COMBAT_FLYING_UNITS]
        if len(self_ground_units) == 0:
            return self.hold_fire(u)
        leader = self.find_closest_units_in_battle(self_ground_units,
                                                   closest_enemy)
        action = self.move_pos(u, {'x': leader.float_attr.pos_x,
                                   'y': leader.float_attr.pos_y})
        if self.dist_between_units(u, closest_enemy) > self.viper_range:
            return action

        air_target = self.find_densest_air_enemy_unit_in_range(u)
        if air_target is not None:
            return self.parasitic_bomb_attack_target(u, air_target)
        ground_pos = self.find_densest_enemy_pos_in_range(u)
        if ground_pos is not None:
            return self.blinding_cloud_attack_pos(u, ground_pos)
        return action

    def _act_low_energy(self, u):
        """Go to the closest finished base with health above 500 and consume."""
        bases = self.dc.dd.base_pool.bases
        base_units = [bases[tag].unit for tag in bases
                      if bases[tag].unit.float_attr.health > 500
                      and bases[tag].unit.float_attr.build_progress == 1]
        if len(base_units) == 0:
            return self.hold_fire(u)
        closest_base = self.find_closest_enemy(u, base_units)
        if self.dist_between_units(u, closest_base) < self.viper_consume_range:
            return self.consume_target(u, closest_base)
        return self.move_pos(u, {'x': closest_base.float_attr.pos_x,
                                 'y': closest_base.float_attr.pos_y})

    def act(self, u, pos, mode):
        """Return the action for viper `u`.

        With no enemy combat units the viper moves to the centre of the own
        combat units.  Otherwise it fights when it is not currently
        consuming and has more than 100 energy, and goes to refill energy
        from a base in every other case.
        """
        if len(self.enemy_combat_units) == 0:
            centre = self.get_center_of_units(self.self_combat_units)
            return self.move_pos(u, centre)

        consuming = (len(u.orders) > 0 and
                     u.orders[0].ability_id ==
                     ABILITY_ID.EFFECT_VIPERCONSUME.value)
        if not consuming and u.float_attr.energy > 100:
            return self._act_with_energy(u)
        return self._act_low_energy(u)
class BaseScoutMgr(object):
    """Scout manager interface; every hook is a no-op at this level."""

    def __init__(self):
        pass

    def update(self, dc, am):
        pass

    def reset(self):
        pass


DEF_EXPLORE_VER = 0


class ZergScoutMgr(BaseScoutMgr):
    """Dispatches scouting tasks, runs them every frame and retires them.

    Each call to `update` first hands idle scouts new tasks (forced, cruise
    and explore), drops tasks that are done or whose scout was destroyed,
    then lets every remaining task emit at most one action.
    """

    def __init__(self, dc):
        super(ZergScoutMgr, self).__init__()
        self._tasks = []
        self._oppo_monitor = om.OppoMonitor()
        self._explore_ver = DEF_EXPLORE_VER
        self._forced_scout_count = 0
        self._assigned_forced_scout_count = 0

        # settings for the RL based explore task
        self._rl_support = False
        self._explore_task_model = None
        self._map_max_x = 0
        self._map_max_y = 0

        self._init_config(dc)

    def _init_config(self, dc):
        """Override the defaults with whatever options `dc.config` defines."""
        if not hasattr(dc, 'config'):
            return

        # (option name on dc.config, attribute of this manager it overrides)
        overrides = (
            ('scout_explore_version', '_explore_ver'),
            ('max_forced_scout_count', '_forced_scout_count'),
            ('scout_explore_task_model', '_explore_task_model'),
            ('scout_map_max_x', '_map_max_x'),
            ('scout_map_max_y', '_map_max_y'),
            ('explore_rl_support', '_rl_support'),
        )
        for option, attr in overrides:
            if hasattr(dc.config, option):
                setattr(self, attr, getattr(dc.config, option))

    def reset(self):
        """Forget all tasks and start over with a fresh opponent monitor."""
        self._tasks = []
        self._assigned_forced_scout_count = 0
        self._oppo_monitor = om.OppoMonitor()

    def update(self, dc, am):
        """Run one scouting frame and push the resulting actions to `am`."""
        super(ZergScoutMgr, self).update(dc, am)
        self._dispatch_task(dc)
        self._check_task(dc)

        # enemy units present in the current observation
        visible_enemies = [
            unit for unit in dc.sd.obs['units']
            if unit.int_attr.alliance == md.AllianceType.ENEMY.value
        ]

        pending = []
        for task in self._tasks:
            act = task.do_task(visible_enemies, dc)
            if act is not None:
                pending.append(act)

        self._oppo_monitor.analysis(dc)
        if pending:
            am.push_actions(pending)

    def _check_task(self, dc):
        """Retire finished tasks and keep the rest in their original order."""
        finished_states = (md.ScoutTaskStatus.DONE,
                           md.ScoutTaskStatus.SCOUT_DESTROY)
        finished, active = [], []
        for task in self._tasks:
            bucket = finished if task.status() in finished_states else active
            bucket.append(task)

        for task in finished:
            task.post_process()

            # a forced scout that completed its job leaves the scout pool
            if task.type() == md.ScoutTaskType.FORCED \
                    and task.status() == md.ScoutTaskStatus.DONE:
                dc.dd.scout_pool.remove_scout(task.scout().unit().int_attr.tag)

        self._tasks = active

    def _dispatch_task(self, dc):
        """Create at most one forced, one cruise and one explore task."""
        if self._forced_scout_count > self._assigned_forced_scout_count:
            if self._dispatch_forced_scout_task(dc):
                self._assigned_forced_scout_count += 1

        if self._explore_ver < st.EXPLORE_V3:
            self._dispatch_cruise_task(dc)

        self._dispatch_explore_task(dc)

    def _dispatch_explore_task(self, dc):
        """Pair a scout with an explore target and queue the matching task."""
        sp = dc.dd.scout_pool
        scout = sp.select_scout()
        if self._explore_ver >= st.EXPLORE_V3:
            pick_target = sp.find_enemy_subbase_target
        else:
            pick_target = sp.find_furthest_idle_target
        target = pick_target()
        if scout is None or target is None:
            # nothing to dispatch
            return

        if not self._rl_support:
            # rule based explore task
            task = ScoutExploreTask(scout, target, sp.home_pos,
                                    self._explore_ver)
        else:
            # rl-based explore task
            if self._explore_task_model is None:
                raise ValueError('no valid explore_task_model provided!')

            from tstarbot.scout.tasks.explor_task_rl import ScoutExploreTaskRL

            task = ScoutExploreTaskRL(scout, target, sp.home_pos,
                                      self._explore_task_model,
                                      self._map_max_x, self._map_max_y)
            self._rl_support = False  # currently only start one

        scout.is_doing_task = True
        target.has_scout = True
        self._tasks.append(task)

    def _dispatch_cruise_task(self, dc):
        """Pair a scout with a cruise target and queue a cruise task."""
        sp = dc.dd.scout_pool
        scout = sp.select_scout()
        target = sp.find_cruise_target()
        if scout is None or target is None:
            return

        scout.is_doing_task = True
        target.has_cruise = True
        self._tasks.append(ScoutCruiseTask(scout, sp.home_pos, target))

    def _dispatch_forced_scout_task(self, dc):
        """Queue a forced scout task; return True only if one was created."""
        sp = dc.dd.scout_pool

        target = sp.find_forced_scout_target()
        if target is None:
            return False

        scout = sp.select_drone_scout()
        if scout is None:
            return False

        scout.is_doing_task = True
        target.has_scout = True
        self._tasks.append(ScoutForcedTask(scout, target, sp.home_pos))
        return True
def dist_to_pos(unit, x, y):
    """Return the Euclidean distance from `unit` to the point (x, y)."""
    return ((unit.float_attr.pos_x - x)**2 + (unit.float_attr.pos_y - y)**2)**0.5


def collect_units(units, unit_type, owner=1):
    """Return the units of `unit_type` that belong to `owner`, in order."""
    return [u for u in units
            if u.unit_type == unit_type and u.int_attr.owner == owner]


class BaseBuildingMgr(object):
    """Building manager interface; every hook is a no-op at this level."""

    def __init__(self):
        pass

    def update(self, dc, am):
        pass

    def reset(self):
        pass


class ZergBuildingMgr(BaseBuildingMgr):
    """Turns queued build/expand commands into raw unit commands.

    Commands that cannot be carried out in the current frame (no drone or
    larva, no free geyser, no known placement, no expansion site) are
    skipped instead of raising.
    """

    def __init__(self):
        super(ZergBuildingMgr, self).__init__()
        self.vespen_status = False

    def reset(self):
        self.vespen_status = False

    def update(self, dc, am):
        """Serve every hatchery's command queue and push the actions to `am`."""
        super(ZergBuildingMgr, self).update(dc, am)
        self.obs = dc.sd.obs
        units = self.obs['units']
        self.hatcheries = collect_units(units, UNIT_TYPEID.ZERG_LAIR.value) + \
            collect_units(units, UNIT_TYPEID.ZERG_HATCHERY.value)
        drones = collect_units(units, UNIT_TYPEID.ZERG_DRONE.value)
        self.extractors = collect_units(units, UNIT_TYPEID.ZERG_EXTRACTOR.value)
        self.larvas = collect_units(units, UNIT_TYPEID.ZERG_LARVA.value)
        # neutral resources are collected with owner id 16
        vespens = collect_units(units, UNIT_TYPEID.NEUTRAL_VESPENEGEYSER.value, 16)
        minerals = collect_units(units, UNIT_TYPEID.NEUTRAL_MINERALFIELD.value, 16) + \
            collect_units(units, UNIT_TYPEID.NEUTRAL_MINERALFIELD750.value, 16)
        actions = []
        if len(self.hatcheries) > 0:
            # only resources within 15 of the first hatchery are considered
            main_base = self.hatcheries[0]
            self.vespens = [g for g in vespens if dist(g, main_base) < 15]
            self.minerals = [g for g in minerals if dist(g, main_base) < 15]
        # TODO: impl here
        for hatchery in self.hatcheries:
            self.hatchery = hatchery
            cmds = dc.dd.build_command_queue.get(self.hatchery.tag)
            for cmd in cmds:
                if cmd.cmd_type == 0:  # build
                    unit_type = cmd.param['unit_id']
                    unit_data = TT.getUnitData(unit_type)
                    if unit_data.isBuilding:
                        action = self.produce_building(drones, unit_type)
                    else:
                        action = self.produce_unit(self.larvas, unit_type)
                    actions.extend(action)
                elif cmd.cmd_type == 1:  # expand
                    pos = self.find_base_pos(self.hatcheries, dc)
                    action = self.expand(drones, pos)
                    if action is not None:
                        actions.append(action)
        am.push_actions(actions)

    def produce_unit(self, larvas, unit_type):
        """Return a one-element action list that morphs a random larva.

        Returns an empty list when there is no larva.
        """
        if len(larvas) == 0:
            return []
        larva = random.choice(larvas)
        unit_data = TT.getUnitData(unit_type)
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = unit_data.buildAbility
        action.action_raw.unit_command.unit_tags.append(larva.tag)
        return [action]

    def produce_building(self, drones, unit_type, pos_x=None, pos_y=None):
        """Return a one-element action list that starts building `unit_type`.

        A Lair is issued to the current hatchery itself, an extractor goes on
        a nearby geyser that has no extractor yet, and any other building is
        placed at (pos_x, pos_y) when both are given, else at the offset
        chosen by `build_place`. Returns an empty list when the order cannot
        be issued (no drone, no free geyser, no known placement).
        """
        is_lair = unit_type == UNIT_TYPEID.ZERG_LAIR.value
        if not is_lair and not drones:
            return []  # nobody available to build

        unit_data = TT.getUnitData(unit_type)
        action = sc_pb.Action()
        command = action.action_raw.unit_command
        command.ability_id = unit_data.buildAbility

        if is_lair:
            command.unit_tags.append(self.hatchery.tag)
            return [action]

        if unit_type == UNIT_TYPEID.ZERG_EXTRACTOR.value:
            geyser = self._free_geyser()
            if geyser is None:
                return []
            command.target_unit_tag = geyser.tag
        else:
            if pos_x is not None and pos_y is not None:
                pos = [pos_x, pos_y]
            else:
                pos = self.build_place(self.hatchery.float_attr.pos_x,
                                       self.hatchery.float_attr.pos_y,
                                       unit_type)
            if not pos:
                return []  # build_place has no offset for this building
            command.target_world_space_pos.x = pos[0]
            command.target_world_space_pos.y = pos[1]

        command.unit_tags.append(random.choice(drones).tag)
        return [action]

    def _free_geyser(self):
        """Return the first nearby geyser without an extractor, or None."""
        for geyser in self.vespens:
            if all(dist(extractor, geyser) >= 1
                   for extractor in self.extractors):
                return geyser
        return None

    def build_place(self, base_x, base_y, unit_type):
        """Return [x, y] for `unit_type` next to the base, or [] if unknown.

        The fixed offset is added when base_x < base_y and subtracted
        otherwise.
        """
        delta_pos = {UNIT_TYPEID.ZERG_SPAWNINGPOOL.value: [6, 0],
                     UNIT_TYPEID.ZERG_ROACHWARREN.value: [0, -6],
                     UNIT_TYPEID.ZERG_HYDRALISKDEN.value: [6, -3]}
        if unit_type not in delta_pos:
            return []
        dx, dy = delta_pos[unit_type]
        if base_x < base_y:
            return [base_x + dx, base_y + dy]
        return [base_x - dx, base_y - dy]

    def expand(self, drones, pos):
        """Return an action sending the drone nearest to `pos` to build a
        hatchery there, or None when there is no drone or no position."""
        if pos is None or not drones:
            return None
        builder = min(drones, key=lambda d: dist_to_pos(d, pos[0], pos[1]))
        action = sc_pb.Action()
        command = action.action_raw.unit_command
        command.ability_id = ABILITY_ID.BUILD_HATCHERY.value
        command.target_world_space_pos.x = pos[0]
        command.target_world_space_pos.y = pos[1]
        command.unit_tags.append(builder.tag)
        return action

    def find_base_pos(self, hatcheries, dc):
        """Return the ideal base position nearest to the first hatchery.

        Sites closer than 5 (or exactly 5) are ignored so the hatchery's own
        site is not returned; sites 10000 or further away are ignored too.
        Returns None when no site qualifies or `hatcheries` is empty.
        """
        if not hatcheries:
            return None
        areas = dc.dd.base_pool.resource_cluster
        d_min = 10000
        pos = None
        for area in areas:
            d = dist_to_pos(hatcheries[0],
                            area.ideal_base_pos[0], area.ideal_base_pos[1])
            if 5 < d < d_min:
                pos = area.ideal_base_pos
                d_min = d
        return pos
UNIT_TYPE_MARINE = 48
UNIT_TYPE_ROACH = 110

# raw ability ids used by the unit commands below
MOVE = 1
ATTACK = 23
ATTACK_TOWARDS = 24

ROACH_ATTACK_RANGE = 5.0


# A pool containing all units that you want to operate
class UnitPool(PoolBase):
    """Keeps the marines of owner 1 and the roaches of owner 2 that appear
    in the latest observation."""

    def __init__(self):
        self.marines = []
        self.roaches = []

    def update(self, obs):
        """Refresh both unit lists from `obs['units']`."""
        units = obs['units']
        self.collect_marine(units)
        self.collect_roach(units)

    def collect_marine(self, units):
        self.marines = [u for u in units
                        if u.unit_type == UNIT_TYPE_MARINE
                        and u.int_attr.owner == 1]

    def collect_roach(self, units):
        self.roaches = [u for u in units
                        if u.unit_type == UNIT_TYPE_ROACH
                        and u.int_attr.owner == 2]

    def get_marines(self):
        return self.marines

    def get_roaches(self):
        return self.roaches


class MicroManager(ManagerBase):
    """Rule based micro: every marine focuses the weakest roach, except that
    a badly hurt marine inside roach range backs off while some other marine
    is still nearly at full health."""

    def __init__(self, pool):
        self._pool = pool
        self._range_high = 5
        self._range_low = -5
        self.marines = None
        self.roaches = None

    def execute(self):
        """Read the current units from the pool and return their actions."""
        self.marines = self._pool.get_marines()
        self.roaches = self._pool.get_roaches()
        return self.operate()

    def operate(self):
        """Return one action per marine; empty when no roach is left."""
        if not self.roaches:
            # nothing to attack or to flee from
            return []
        actions = []
        for m in self.marines:
            nearest = self.find_closest_enemy(m, self.roaches)
            closest_enemy_dist = math.sqrt(self.cal_square_dist(m, nearest))
            if closest_enemy_dist < ROACH_ATTACK_RANGE \
                    and (m.float_attr.health / m.float_attr.health_max) < 0.3 \
                    and self.find_strongest_unit() > 0.9:
                action = self.run_away_from_closest_enemy(m)
            else:
                action = self.attack_weakest_enemy(m)
            actions.append(action)
        return actions

    def attack_closest_enemy(self, u):
        """Return an action making `u` attack the roach nearest to it."""
        target = self.find_closest_enemy(u, enemies=self.roaches)
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = ATTACK
        # target the enemy that was found, not the attacker itself
        action.action_raw.unit_command.target_unit_tag = target.tag
        action.action_raw.unit_command.unit_tags.append(u.tag)
        return action

    def attack_weakest_enemy(self, u):
        """Return an action making `u` attack the roach with the least health."""
        target = self.find_weakest_enemy(enemies=self.roaches)
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = ATTACK
        action.action_raw.unit_command.target_unit_tag = target.tag
        action.action_raw.unit_command.unit_tags.append(u.tag)
        return action

    def run_away_from_closest_enemy(self, u):
        """Return a move action stepping `u` directly away from the nearest
        roach, by 20% of the distance between them."""
        target = self.find_closest_enemy(u, enemies=self.roaches)
        action = sc_pb.Action()
        command = action.action_raw.unit_command
        command.ability_id = MOVE
        command.target_world_space_pos.x = u.float_attr.pos_x + \
            (u.float_attr.pos_x - target.float_attr.pos_x) * 0.2
        command.target_world_space_pos.y = u.float_attr.pos_y + \
            (u.float_attr.pos_y - target.float_attr.pos_y) * 0.2
        command.unit_tags.append(u.tag)
        return action

    def find_closest_enemy(self, u, enemies):
        """Return the enemy nearest to `u` (first one on ties), or None when
        `enemies` is empty."""
        if not enemies:
            return None
        return min(enemies, key=lambda e: self.cal_square_dist(u, e))

    def find_weakest_enemy(self, enemies):
        """Return the enemy with the lowest health (first one on ties), or
        None when `enemies` is empty."""
        if not enemies:
            return None
        return min(enemies, key=lambda e: e.float_attr.health)

    def find_strongest_unit(self):
        """Return the highest health/health_max ratio among the marines, or
        0.0 when there is no marine."""
        ratios = [m.float_attr.health / m.float_attr.health_max
                  for m in self.marines]
        if not ratios:
            return 0.0
        return max(ratios)

    @staticmethod
    def cal_square_dist(u1, u2):
        """Return the squared Euclidean distance between two units."""
        return pow(u1.float_attr.pos_x - u2.float_attr.pos_x, 2) + \
            pow(u1.float_attr.pos_y - u2.float_attr.pos_y, 2)


class MicroAgent:
    """Drives the unit pool and the micro manager against an environment."""

    def __init__(self, env):
        self._pools = []
        self._managers = []
        self._env = env
        self._executor = []

    def setup(self):
        """Create the unit pool, its manager and the action executor."""
        pool = UnitPool()
        task_manager = MicroManager(pool)
        self._pools.append(pool)
        self._managers.append(task_manager)
        self._executor = ActExecutor(self._env)

    def reset(self):
        """Reset the environment and return its first timesteps."""
        return self._env.reset()

    def run(self, n):
        return self._run_inner(n)

    def _run_inner(self, n):
        """Loop until the step counter exceeds `n`; the environment is reset
        whenever the executor reports the end of an episode."""
        try:
            step_num = 0
            timesteps = self.reset()
            while True:
                obs = timesteps[0].observation
                for pool in self._pools:
                    pool.update(obs)

                actions = []
                for manager in self._managers:
                    actions.extend(manager.execute())

                result = self._executor.exec_raw(actions)
                if result[1]:
                    # episode finished: start the next one; this frame is
                    # not counted as a step
                    timesteps = self.reset()
                    continue
                timesteps = result[0]

                if step_num > n:
                    break
                step_num += 1
        except KeyboardInterrupt:
            print("SC2Imp exception")
import importlib 6 | 7 | from pysc2 import maps 8 | from pysc2.env import sc2_env 9 | from pysc2 import run_configs 10 | from pysc2.env import environment 11 | from pysc2.lib import features 12 | from pysc2.lib import point 13 | from pysc2.lib import run_parallel 14 | from pysc2.tests import utils 15 | import copy 16 | 17 | from s2clientprotocol import common_pb2 as sc_common 18 | from s2clientprotocol import sc2api_pb2 as sc_pb 19 | from s2clientprotocol import debug_pb2 20 | 21 | from tstarbot.agents.zerg_agent import ZergAgent 22 | 23 | from absl import flags 24 | from absl import app 25 | 26 | races = { 27 | "R": sc_common.Random, 28 | "P": sc_common.Protoss, 29 | "T": sc_common.Terran, 30 | "Z": sc_common.Zerg, 31 | } 32 | 33 | FLAGS = flags.FLAGS 34 | flags.DEFINE_bool("realtime", True, "Whether to run in real time.") 35 | 36 | flags.DEFINE_integer("step_mul", 8, "Game steps per agent step.") 37 | flags.DEFINE_float("sleep_time", 0.2, "Sleep time between agent steps.") 38 | 39 | flags.DEFINE_string("agent1", "pysc2.agents.random_agent.RandomAgent", 40 | "Which agent to run") 41 | flags.DEFINE_string("agent1_config", "", 42 | "Agent's config in py file. Pass it as python module." 43 | "E.g., tstarbot.agents.dft_config") 44 | flags.DEFINE_string("agent2", None, 45 | "Which agent to run") 46 | flags.DEFINE_string("agent2_config", "", 47 | "Agent's config in py file. Pass it as python module." 
def test_multi_player(agents, disable_fog):
    """Host a two-player game on two SC2 processes and drive both players.

    Args:
      agents: exactly two entries; the first drives controller 0 and may be
        None, in which case no actions are sent for that player; the second
        drives controller 1.
      disable_fog: two booleans, one per controller, in the same order.

    Raises:
      ValueError: if `agents` does not hold exactly two entries.
    """
    players = 2
    if len(agents) != players:
        raise ValueError("expected %d agents, got %d" % (players, len(agents)))
    agent1, agent2 = agents
    run_config = run_configs.get()
    parallel = run_parallel.RunParallel()
    map_inst = maps.get(FLAGS.map)

    screen_size_px = point.Point(64, 64)
    minimap_size_px = point.Point(32, 32)
    interface = sc_pb.InterfaceOptions(
        raw=True, score=True)
    screen_size_px.assign_to(interface.feature_layer.resolution)
    minimap_size_px.assign_to(interface.feature_layer.minimap_resolution)

    # Reserve a whole bunch of ports for the weird multiplayer implementation.
    ports = [portpicker.pick_unused_port() for _ in range(1 + players * 2)]
    print("Valid Ports: %s" % ports)

    # Actually launch the game processes.
    print("start")
    sc2_procs = [run_config.start(extra_ports=ports) for _ in range(players)]
    controllers = [p.controller for p in sc2_procs]

    try:
        # Save the maps so they can access it.
        map_path = os.path.basename(map_inst.path)
        print("save_map")
        parallel.run((c.save_map, map_path, run_config.map_data(map_inst.path))
                     for c in controllers)

        # Create the create request.
        real_time = FLAGS.realtime
        create = sc_pb.RequestCreateGame(
            local_map=sc_pb.LocalMap(map_path=map_path), realtime=real_time)
        for _ in range(players):
            create.player_setup.add(type=sc_pb.Participant)

        # Create the join request.
        # NOTE(review): controller 0 joins with agent1_race and controller 1
        # with agent2_race, while main() passes its agents in the order
        # [agent2, agent1] - confirm the race/agent pairing is intended.
        join1 = sc_pb.RequestJoinGame(race=races[FLAGS.agent1_race],
                                      options=interface)
        join1.shared_port = ports.pop()
        join1.server_ports.game_port = ports.pop()
        join1.server_ports.base_port = ports.pop()
        join1.client_ports.add(game_port=ports.pop(), base_port=ports.pop())

        join2 = copy.copy(join1)
        join2.race = races[FLAGS.agent2_race]

        # This is where actually game plays
        # Create and Join
        print("create")
        controllers[0].create_game(create)
        print("join")
        parallel.run((c.join_game, join)
                     for c, join in zip(controllers, [join1, join2]))

        controllers[0]._client.send(debug=sc_pb.RequestDebug(
            debug=[debug_pb2.DebugCommand(game_state=1)]))
        if disable_fog[0]:
            controllers[0].disable_fog()
        if disable_fog[1]:
            controllers[1].disable_fog()

        print("run")
        game_info = controllers[0].game_info()
        extractors = features.Features(game_info)
        for game_loop in range(1, 100000):  # steps per episode
            # Step the game
            step_mul = FLAGS.step_mul
            if not real_time:
                parallel.run((c.step, step_mul) for c in controllers)
            else:
                time.sleep(FLAGS.sleep_time)

            # Observe
            obs = parallel.run(c.observe for c in controllers)
            if any(o.player_result for o in obs):
                break  # Episode over: stop stepping a finished game.

            agent_obs = [extractors.transform_obs(o.observation) for o in obs]
            game_info = parallel.run(c.game_info for c in controllers)
            timesteps = tuple(environment.TimeStep(step_type=0,
                                                   reward=0,
                                                   discount=0, observation=o,
                                                   game_info=info)
                              for o, info in zip(agent_obs, game_info))

            # Act
            if agent1 is not None:
                actions1 = agent1.step(timesteps[0])
            else:
                actions1 = []
            actions2 = agent2.step(timesteps[1])
            actions = [actions1, actions2]
            funcs_with_args = [(c.acts, a) for c, a in zip(controllers, actions)]
            parallel.run(funcs_with_args)

        # Done with the game.
        print("leave")
        parallel.run(c.leave for c in controllers)
    finally:
        print("quit")
        # Done, shut down. Don't depend on parallel since it might be broken.
        for c in controllers:
            c.quit()
        for p in sc2_procs:
            p.close()


def main(unused_argv):
    """Build the agents named by the flags and run a two-player game."""
    maps.get(FLAGS.map)  # Assert the map exists.

    agent_module, agent_name = FLAGS.agent1.rsplit(".", 1)
    agent_cls = getattr(importlib.import_module(agent_module), agent_name)
    agent1_kwargs = {}
    if FLAGS.agent1_config:
        agent1_kwargs['config_path'] = FLAGS.agent1_config
    agent1 = agent_cls(**agent1_kwargs)

    if FLAGS.agent2:
        agent_module, agent_name = FLAGS.agent2.rsplit(".", 1)
        agent_cls = getattr(importlib.import_module(agent_module), agent_name)
        agent2_kwargs = {}
        # agent2's config is gated on its own flag, not on agent1's
        if FLAGS.agent2_config:
            agent2_kwargs['config_path'] = FLAGS.agent2_config
        agent2 = agent_cls(**agent2_kwargs)
        test_multi_player([agent2, agent1],
                          [FLAGS.disable_fog_2, FLAGS.disable_fog_1])
    else:
        test_multi_player([None, agent1],
                          [FLAGS.disable_fog_2, FLAGS.disable_fog_1])
class ForcedScoutStep(Enum):
    STEP_INIT = 0
    STEP_MOVE_TO_BASE = 1
    STEP_CIRCLE_MINERAL = 2
    STEP_RETREAT = 3


class ScoutForcedTask(ScoutTask):
    """Sends a scout to a target base, walks it around the base's resources
    and pulls it back home as soon as an enemy threat is detected."""

    def __init__(self, scout, target, home):
        super(ScoutForcedTask, self).__init__(scout, home)
        self._target = target
        self._circle_path = []
        self._cur_circle_target = 0  # index of _circle_path
        self._cur_step = ForcedScoutStep.STEP_INIT

    def type(self):
        return md.ScoutTaskType.FORCED

    def post_process(self):
        """Release the target and the scout; a destroyed scout marks the
        target as holding an enemy base and army."""
        self._target.has_scout = False
        self._scout.is_doing_task = False

        if self._status == md.ScoutTaskStatus.SCOUT_DESTROY:
            self._target.has_enemy_base = True
            self._target.has_army = True

    def _do_task_inner(self, view_enemys, dc):
        """Advance the task by one frame; return an action or None."""
        if self._check_scout_lost():
            self._status = md.ScoutTaskStatus.SCOUT_DESTROY
            return None

        if self._detect_enemy(view_enemys, dc):
            self._status = md.ScoutTaskStatus.DONE
            self._cur_step = ForcedScoutStep.STEP_RETREAT
            return self._move_to_home()

        if self._cur_step == ForcedScoutStep.STEP_INIT:
            # step one, move to target base
            action = self._move_to_target(self._target.pos)
            self._cur_step = ForcedScoutStep.STEP_MOVE_TO_BASE
            return action
        elif self._cur_step == ForcedScoutStep.STEP_MOVE_TO_BASE:
            # step two, circle target base
            if not self._arrive_xy(self.scout().unit(),
                                   self._target.pos[0],
                                   self._target.pos[1], 10):
                return None
            self._generate_base_around_path()
            if not self._circle_path:
                # no resource positions to circle around: head home
                self._cur_step = ForcedScoutStep.STEP_RETREAT
                return self._move_to_home()
            self._cur_step = ForcedScoutStep.STEP_CIRCLE_MINERAL
            return self._move_to_target(self._circle_path[0])
        elif self._cur_step == ForcedScoutStep.STEP_CIRCLE_MINERAL:
            cur_target = self._circle_path[self._cur_circle_target]
            if not self._arrive_xy(self.scout().unit(),
                                   cur_target[0], cur_target[1], 1):
                return None
            # go to the next waypoint, wrapping to the first after the last
            self._cur_circle_target += 1
            if self._cur_circle_target >= len(self._circle_path):
                self._cur_circle_target = 0
            return self._move_to_target(
                self._circle_path[self._cur_circle_target])
        elif self._cur_step == ForcedScoutStep.STEP_RETREAT:
            # step three, retreat
            if self._arrive_xy(self.scout().unit(),
                               self._home[0], self._home[1], 10):
                self._status = md.ScoutTaskStatus.DONE

        return None

    def _generate_circle_path(self, m_pos):
        """Append `m_pos` to the path, ordered by distance from one endpoint
        of the farthest-apart pair of positions."""
        m_dim = len(m_pos)
        dis_mat = np.zeros((m_dim, m_dim))

        for i in range(m_dim):
            for j in range(m_dim):
                dis_mat[i][j] = self.distance(m_pos[i], m_pos[j])

        # row index of the largest pairwise distance
        max_i = 0
        max_dist = 0
        for i in range(m_dim):
            for j in range(m_dim):
                if dis_mat[i][j] > max_dist:
                    max_dist = dis_mat[i][j]
                    max_i = i

        order = sorted(range(m_dim), key=lambda idx: dis_mat[max_i][idx])
        for idx in order:
            self._circle_path.append(m_pos[idx])

    def _generate_base_around_path(self):
        """Append waypoints around the target base to the circle path.

        Every resource position yields two points on the line through the
        ideal base position: one mirrored to the far side of the base
        (factor 1.2) and one pushed out beyond the resource (factor 1.4).
        """
        pos_arr_1 = []
        pos_arr_2 = []
        unit_pos = self._sort_target_unit_by_pos()
        centor_pos = self._target.area.ideal_base_pos
        for pos in unit_pos:
            diff_x = centor_pos[0] - pos[0]
            diff_y = centor_pos[1] - pos[1]
            pos_arr_1.append((centor_pos[0] + 1.2 * diff_x,
                              centor_pos[1] + 1.2 * diff_y))
            pos_arr_2.append((centor_pos[0] - 1.4 * diff_x,
                              centor_pos[1] - 1.4 * diff_y))

        # NOTE(review): despite the names, the "first half" is the whole list
        # and the "second half" is always empty; the names suggest a split at
        # len // 2 - confirm which is intended before changing the path.
        pos_arr_2_first_half = pos_arr_2[:len(pos_arr_2)]
        pos_arr_2_sec_half = pos_arr_2[len(pos_arr_2):]

        for pos_0 in reversed(pos_arr_2_first_half):
            self._circle_path.append(pos_0)

        for pos_1 in pos_arr_1:
            self._circle_path.append(pos_1)

        for pos_2 in reversed(pos_arr_2_sec_half):
            self._circle_path.append(pos_2)

    def _sort_target_unit_by_pos(self):
        """Return the mineral and gas positions of the target area, sorted
        along the axis on which the minerals are spread the widest.

        The sort is stable. Without any mineral position there is no axis to
        choose, so the positions are returned in their given order.
        """
        minerals = self._target.area.m_pos
        total = minerals + self._target.area.g_pos
        if len(minerals) == 0:
            return total
        m_x = [item[0] for item in minerals]
        m_y = [item[1] for item in minerals]
        x_diff = abs(max(m_x) - min(m_x))
        y_diff = abs(max(m_y) - min(m_y))
        axis = 0 if x_diff > y_diff else 1
        return sorted(total, key=lambda pos: pos[axis])

    def _arrive_xy(self, u, target_x, target_y, error):
        """Return True when `u` is closer than `error` to the target point."""
        x = u.float_attr.pos_x - target_x
        y = u.float_attr.pos_y - target_y
        distance = (x * x + y * y) ** 0.5

        return distance < error

    def distance(self, pos1, pos2):
        """Return the Euclidean distance between two (x, y) positions."""
        x = pos1[0] - pos2[0]
        y = pos1[1] - pos2[1]

        return (x * x + y * y) ** 0.5

    def _detect_enemy(self, view_enemys, dc):
        """Return True when the scout should give up and go home.

        That is the case when an enemy combat unit is within
        `st.SCOUT_CRUISE_RANGE` of the scout, or when a finished spawning
        pool and at least one egg are both in view.
        """
        armys = []
        eggs = []
        spawning_on = False
        for enemy in view_enemys:
            if enemy.unit_type in md.COMBAT_UNITS:
                armys.append(enemy)
            elif enemy.unit_type == tp.UNIT_TYPEID.ZERG_EGG.value:
                eggs.append(enemy)
            elif enemy.unit_type == tp.UNIT_TYPEID.ZERG_SPAWNINGPOOL.value:
                if enemy.float_attr.build_progress >= st.BUILD_PROGRESS_FINISH:
                    spawning_on = True

        me = (self._scout.unit().float_attr.pos_x,
              self._scout.unit().float_attr.pos_y)

        scout_armys = []
        for unit in armys:
            dist = md.calculate_distance(me[0], me[1],
                                         unit.float_attr.pos_x,
                                         unit.float_attr.pos_y)
            if dist < st.SCOUT_CRUISE_RANGE:
                scout_armys.append(unit)
        if len(eggs) > 0 and spawning_on:
            # escape: combat units may be about to hatch
            return True

        return len(scout_armys) > 0
# race key accepted on the command line -> sc2_env race
races = {
    "R": sc2_env.Race.random,
    "P": sc2_env.Race.protoss,
    "T": sc2_env.Race.terran,
    "Z": sc2_env.Race.zerg,
}

# difficulty key accepted on the command line -> built-in AI level,
# from weakest to strongest
difficulties = {
    "1": sc2_env.Difficulty.very_easy,
    "2": sc2_env.Difficulty.easy,
    "3": sc2_env.Difficulty.medium,
    "4": sc2_env.Difficulty.medium_hard,
    "5": sc2_env.Difficulty.hard,
    "6": sc2_env.Difficulty.harder,  # was a duplicate of "5" (hard)
    "7": sc2_env.Difficulty.very_hard,
    "8": sc2_env.Difficulty.cheat_vision,
    "9": sc2_env.Difficulty.cheat_money,
    "A": sc2_env.Difficulty.cheat_insane,
}
def run_loop(agents, env, max_episodes=1):
    """A run loop to have agents and an environment interact.

    Plays whole episodes until `max_episodes` have finished (at least one
    episode is always played) or the user interrupts, printing the outcome
    and the running win rate after each episode and the overall speed at
    the end. A draw counts as half a win. The outcome is taken from the
    first player's final timestep.

    Args:
      agents: objects with `setup(obs_spec, act_spec)`, `reset()` and
        `step(timestep)`.
      env: environment with `action_spec()`, `observation_spec()`,
        `reset()` and `step(actions)`.
      max_episodes: number of episodes to play.
    """
    me_id = 0
    total_frames = 0
    n_episode = 0
    n_win = 0
    result_stat = [0] * 3  # n_draw, n_win, n_loss
    start_time = time.time()

    action_spec = env.action_spec()
    observation_spec = env.observation_spec()
    for agent, obs_spec, act_spec in zip(agents, observation_spec, action_spec):
        agent.setup(obs_spec, act_spec)

    try:
        while True:
            timesteps = env.reset()
            for a in agents:
                a.reset()

            # run this episode
            while True:
                total_frames += 1
                actions = [agent.step(timestep) for agent, timestep in
                           zip(agents, timesteps)]
                timesteps = env.step(actions)
                if timesteps[me_id].last():
                    break

            # update
            n_episode += 1
            outcome = timesteps[me_id].reward
            if outcome > 0:
                n_win += 1
                result_stat[1] += 1
            elif outcome == 0:
                n_win += 0.5
                result_stat[0] += 1
            else:
                result_stat[2] += 1

            # print info
            win_rate = n_win / n_episode
            print(
                'episode = {}, outcome = {}, n_win = {}, '
                'current winning rate = {}'.format(n_episode, outcome, n_win,
                                                   win_rate)
            )

            # done?
            if n_episode >= max_episodes:
                break
    except KeyboardInterrupt:
        pass
    finally:
        elapsed_time = time.time() - start_time
        # never divide by zero here: an error raised inside the loop must
        # not be replaced by a ZeroDivisionError from this summary
        fps = total_frames / max(elapsed_time, 1e-9)
        print("Took %.3f seconds for %s steps: %.3f fps" % (
            elapsed_time, total_frames, fps))
random_seed=rs, 160 | game_steps_per_episode=FLAGS.game_steps_per_episode, 161 | agent_interface_format=agent_interface_format, 162 | score_index=-1, # this indicates the outcome is reward 163 | disable_fog=FLAGS.disable_fog, 164 | visualize=visualize) as env: 165 | 166 | run_loop(agents, env, max_episodes=FLAGS.max_agent_episodes) 167 | if FLAGS.save_replay: 168 | env.save_replay("%s vs. %s" % (FLAGS.agent1, FLAGS.agent2)) 169 | 170 | 171 | def get_agent(agt_path, config_path=""): 172 | agent_module, name = agt_path.rsplit('.', 1) 173 | agt_cls = getattr(importlib.import_module(agent_module), name) 174 | agent_kwargs = {} 175 | if config_path: 176 | agent_kwargs['config_path'] = config_path 177 | agent = agt_cls(**agent_kwargs) 178 | return agent 179 | 180 | 181 | def main(unused_argv): 182 | """Run an agent.""" 183 | stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace 184 | stopwatch.sw.trace = FLAGS.trace 185 | 186 | maps.get(FLAGS.map) # Assert the map exists. 187 | players = [] 188 | agents = [] 189 | bot_difficulty = difficulties[FLAGS.difficulty] 190 | if FLAGS.agent1 == 'Bot': 191 | players.append(sc2_env.Bot(races['Z'], bot_difficulty)) 192 | else: 193 | players.append(sc2_env.Agent(races[FLAGS.agent1_race])) 194 | agents.append(get_agent(FLAGS.agent1, FLAGS.agent1_config)) 195 | if FLAGS.agent2 is None: 196 | pass 197 | elif FLAGS.agent2 == 'Bot': 198 | players.append(sc2_env.Bot(races['Z'], bot_difficulty)) 199 | else: 200 | players.append(sc2_env.Agent(races[FLAGS.agent2_race])) 201 | agents.append(get_agent(FLAGS.agent2, FLAGS.agent2_config)) 202 | 203 | run_thread(players, agents, FLAGS.map, FLAGS.render) 204 | 205 | if FLAGS.profile: 206 | print(stopwatch.sw) 207 | 208 | 209 | def entry_point(): # Needed so setup.py scripts work. 
210 | app.run(main) 211 | 212 | 213 | if __name__ == "__main__": 214 | app.run(main) 215 | -------------------------------------------------------------------------------- /tstarbot/data/pool/enemy_pool.py: -------------------------------------------------------------------------------- 1 | """EnemyPool Class.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import collections 6 | import operator 7 | 8 | from pysc2.lib.typeenums import UNIT_TYPEID 9 | import numpy as np 10 | 11 | from tstarbot.data.pool.pool_base import PoolBase 12 | from tstarbot.data.pool import macro_def as tm 13 | from tstarbot.data.pool.macro_def import BUILDING_UNITS 14 | 15 | 16 | class EnemyCluster(object): 17 | 18 | def __init__(self, units): 19 | self._units = units 20 | 21 | def __repr__(self): 22 | return ('EnemyCluster(CombatUnits(%d), WorkerUnits(%d))' % 23 | (self.num_combat_units, self.num_worker_units)) 24 | 25 | @property 26 | def num_units(self): 27 | return len(self._units) 28 | 29 | @property 30 | def num_worker_units(self): 31 | return len(self.worker_units) 32 | 33 | @property 34 | def num_combat_units(self): 35 | return len(self.combat_units) 36 | 37 | @property 38 | def units(self): 39 | return self._units 40 | 41 | @property 42 | def worker_units(self): 43 | return [u for u in self._units 44 | if u.int_attr.unit_type in tm.WORKER_UNITS] 45 | 46 | @property 47 | def combat_units(self): 48 | return [u for u in self._units 49 | if u.int_attr.unit_type in tm.COMBAT_UNITS] 50 | 51 | @property 52 | def centroid(self): 53 | x = sum(u.float_attr.pos_x for u in self._units) / len(self._units) 54 | y = sum(u.float_attr.pos_y for u in self._units) / len(self._units) 55 | return {'x': x, 'y': y} 56 | 57 | 58 | class EnemyPool(PoolBase): 59 | 60 | def __init__(self, dd): 61 | super(PoolBase, self).__init__() 62 | self._dd = dd 63 | self._enemy_units = list() 64 | self._enemy_clusters = list() 65 | 
self._self_bases = list() 66 | self._is_set_main_base = False 67 | self._main_base_pos = None 68 | 69 | def reset(self): 70 | self._enemy_units = list() 71 | self._enemy_clusters = list() 72 | self._self_bases = list() 73 | self._is_set_main_base = False 74 | self._main_base_pos = None 75 | 76 | def update(self, timestep): 77 | self._enemy_units = list() 78 | self._self_bases = list() 79 | units = timestep.observation['units'] 80 | if not self._is_set_main_base: 81 | for u in units: 82 | if u.int_attr.unit_type in [UNIT_TYPEID.ZERG_HATCHERY.value]: 83 | self._main_base_pos = {'x': u.float_attr.pos_x, 84 | 'y': u.float_attr.pos_y} 85 | self._is_set_main_base = True 86 | break 87 | 88 | for u in units: 89 | if self._is_enemy_unit(u): 90 | self._enemy_units.append(u) 91 | else: 92 | if (u.int_attr.unit_type == UNIT_TYPEID.ZERG_HATCHERY.value or 93 | u.int_attr.unit_type == UNIT_TYPEID.ZERG_LAIR.value or 94 | u.int_attr.unit_type == UNIT_TYPEID.ZERG_HIVE): 95 | self._self_bases.append(u) 96 | 97 | self._enemy_clusters = list() 98 | for units in self._agglomerative_cluster(self._enemy_units): 99 | self.enemy_clusters.append(EnemyCluster(units)) 100 | 101 | @property 102 | def units(self): 103 | return self._enemy_units 104 | 105 | @property 106 | def num_worker_units(self): 107 | return sum(cluster.num_worker_units for cluster in self._enemy_clusters) 108 | 109 | @property 110 | def num_combat_units(self): 111 | return sum(cluster.num_combat_units for cluster in self._enemy_clusters) 112 | 113 | @property 114 | def main_base_pos(self): 115 | return self._main_base_pos 116 | 117 | @property 118 | def enemy_clusters(self): 119 | return self._enemy_clusters 120 | 121 | @property 122 | def weakest_cluster(self): 123 | if len(self._enemy_clusters) == 0: 124 | return None 125 | return min(self._enemy_clusters, 126 | key=lambda c: c.num_combat_units if c.num_combat_units >= 3 127 | else float('inf')) 128 | 129 | @property 130 | def strongest_cluster(self): 131 | if 
len(self._enemy_clusters) == 0: 132 | return None 133 | return max(self._enemy_clusters, key=lambda c: c.num_combat_units) 134 | 135 | @property 136 | def closest_cluster(self): 137 | if len(self._enemy_clusters) == 0 or len(self._self_bases) == 0: 138 | return None 139 | 140 | c_targets = [c for c in self._enemy_clusters 141 | if (c.num_units > 1 or 142 | (c.num_units == 1 and 143 | c.units[0].int_attr.unit_type in BUILDING_UNITS))] 144 | if len(c_targets) == 0: 145 | return None 146 | target_c = min(c_targets, 147 | key=lambda c: self._distance(c.centroid, 148 | self._main_base_pos)) 149 | return target_c 150 | 151 | @property 152 | def priority_pos(self): 153 | sorted_x = sorted(self._dd.base_pool.enemy_home_dist.items(), 154 | key=operator.itemgetter(1), reverse=True) 155 | # sorted_x = sorted(self._dd.base_pool.home_dist.items(), 156 | # key=operator.itemgetter(1)) 157 | sorted_pos_list = [x[0].ideal_base_pos for x in sorted_x] 158 | for pos in sorted_pos_list: 159 | pos = {'x': pos[0], 160 | 'y': pos[1]} 161 | detected = [u for u in self.units 162 | if u.int_attr.unit_type in BUILDING_UNITS and 163 | self._distance({'x': u.float_attr.pos_x, 164 | 'y': u.float_attr.pos_y}, pos) < 10] 165 | if len(detected) > 0: 166 | return pos 167 | return None 168 | 169 | def _is_enemy_unit(self, u): 170 | if u.int_attr.alliance != tm.AllianceType.ENEMY.value: 171 | return False 172 | else: 173 | return True 174 | 175 | def _agglomerative_cluster(self, units, merge_distance=20, grid_size=8): 176 | 177 | def get_centroid(units): 178 | x = sum(u.float_attr.pos_x for u in units) / len(units) 179 | y = sum(u.float_attr.pos_y for u in units) / len(units) 180 | return (x, y) 181 | 182 | def agglomerative_step(cluster_map): 183 | merge_threshold = merge_distance ** 2 184 | min_distance, min_pair = float('inf'), None 185 | centroids = list(cluster_map.keys()) 186 | for i in range(len(centroids)): 187 | xi, yi = centroids[i] 188 | for j in range(i + 1, len(centroids)): 189 | xj, yj = 
centroids[j] 190 | distance = (xi - xj) ** 2 + (yi - yj) ** 2 191 | if distance < min_distance: 192 | min_distance = distance 193 | min_pair = (centroids[i], centroids[j]) 194 | if min_pair is not None and min_distance < merge_threshold: 195 | cluster = cluster_map[min_pair[0]] + cluster_map[min_pair[1]] 196 | cluster_map.pop(min_pair[0]) 197 | cluster_map.pop(min_pair[1]) 198 | cluster_map[get_centroid(cluster)] = cluster 199 | return True 200 | else: 201 | return False 202 | 203 | def initial_grid_cluster(units): 204 | grid_map = collections.defaultdict(list) 205 | for u in units: 206 | x_grid = u.float_attr.pos_x // grid_size 207 | y_grid = u.float_attr.pos_y // grid_size 208 | grid_map[(x_grid, y_grid)].append(u) 209 | cluster_map = collections.defaultdict(list) 210 | for cluster in grid_map.values(): 211 | cluster_map[get_centroid(cluster)] = cluster 212 | return cluster_map 213 | 214 | cluster_map = initial_grid_cluster(units) 215 | while agglomerative_step(cluster_map): pass 216 | return list(cluster_map.values()) 217 | 218 | def _distance(self, pos_a, pos_b): 219 | return ((pos_a['x'] - pos_b['x']) ** 2 + 220 | (pos_a['y'] - pos_b['y']) ** 2) ** 0.5 221 | 222 | def _cal_dist(self, u1, u2): 223 | pos_a = {'x': u1.float_attr.pos_x, 224 | 'y': u1.float_attr.pos_y} 225 | pos_b = {'x': u2.float_attr.pos_x, 226 | 'y': u2.float_attr.pos_y} 227 | return self._distance(pos_a, pos_b) 228 | -------------------------------------------------------------------------------- /tstarbot/production_strategy/prod_advarms.py: -------------------------------------------------------------------------------- 1 | """Production Manager""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from pysc2.lib.typeenums import UNIT_TYPEID 7 | from pysc2.lib.typeenums import UPGRADE_ID 8 | from tstarbot.production_strategy.base_zerg_production_mgr import ZergBaseProductionMgr 9 | from 
tstarbot.production_strategy.util import unique_unit_count 10 | 11 | 12 | class ZergProdAdvArms(ZergBaseProductionMgr): 13 | def __init__(self, dc): 14 | super(ZergProdAdvArms, self).__init__(dc) 15 | self.ultra_goal = self.get_ultra_goal() 16 | 17 | @staticmethod 18 | def get_ultra_goal(): 19 | return{UNIT_TYPEID.ZERG_ROACH: 13, 20 | UNIT_TYPEID.ZERG_HYDRALISK: 23, 21 | UNIT_TYPEID.ZERG_INFESTOR: 3, 22 | UNIT_TYPEID.ZERG_CORRUPTOR: 0, 23 | UNIT_TYPEID.ZERG_LURKERMP: 6, 24 | UNIT_TYPEID.ZERG_VIPER: 2, 25 | UNIT_TYPEID.ZERG_RAVAGER: 4, 26 | UNIT_TYPEID.ZERG_ULTRALISK: 4, 27 | UNIT_TYPEID.ZERG_MUTALISK: 0, 28 | UNIT_TYPEID.ZERG_BROODLORD: 0, 29 | UNIT_TYPEID.ZERG_QUEEN: 3, 30 | UNIT_TYPEID.ZERG_OVERSEER: 20, 31 | UNIT_TYPEID.ZERG_DRONE: 66} 32 | 33 | def get_opening_build_order(self): 34 | return [UNIT_TYPEID.ZERG_DRONE, 35 | UNIT_TYPEID.ZERG_DRONE, 36 | UNIT_TYPEID.ZERG_OVERLORD, 37 | UNIT_TYPEID.ZERG_DRONE, 38 | UNIT_TYPEID.ZERG_DRONE, 39 | UNIT_TYPEID.ZERG_DRONE, 40 | UNIT_TYPEID.ZERG_HATCHERY, 41 | UNIT_TYPEID.ZERG_DRONE, 42 | UNIT_TYPEID.ZERG_EXTRACTOR] + \ 43 | [UNIT_TYPEID.ZERG_DRONE] * 2 + \ 44 | [UNIT_TYPEID.ZERG_SPAWNINGPOOL, 45 | UNIT_TYPEID.ZERG_DRONE, 46 | UNIT_TYPEID.ZERG_DRONE, 47 | UNIT_TYPEID.ZERG_DRONE, 48 | UNIT_TYPEID.ZERG_DRONE, 49 | UNIT_TYPEID.ZERG_DRONE, 50 | UNIT_TYPEID.ZERG_ROACHWARREN, 51 | UNIT_TYPEID.ZERG_DRONE, 52 | UNIT_TYPEID.ZERG_DRONE, 53 | UNIT_TYPEID.ZERG_DRONE, 54 | UNIT_TYPEID.ZERG_DRONE, 55 | UNIT_TYPEID.ZERG_QUEEN, 56 | UNIT_TYPEID.ZERG_DRONE, 57 | UNIT_TYPEID.ZERG_ROACH, 58 | UNIT_TYPEID.ZERG_SPINECRAWLER] + \ 59 | [UNIT_TYPEID.ZERG_DRONE, 60 | UNIT_TYPEID.ZERG_ROACH] * 2 + \ 61 | [UNIT_TYPEID.ZERG_SPINECRAWLER, 62 | UNIT_TYPEID.ZERG_ROACH, 63 | UNIT_TYPEID.ZERG_ROACH, 64 | UNIT_TYPEID.ZERG_SPINECRAWLER] 65 | 66 | def get_goal(self, dc): 67 | if not self.has_building_built([UNIT_TYPEID.ZERG_LAIR.value, 68 | UNIT_TYPEID.ZERG_HIVE.value]): 69 | goal = [UNIT_TYPEID.ZERG_LAIR] + \ 70 | [UNIT_TYPEID.ZERG_DRONE] * 6 + \ 71 | 
[UNIT_TYPEID.ZERG_ROACH] * 5 + \ 72 | [UNIT_TYPEID.ZERG_DRONE, UNIT_TYPEID.ZERG_ROACH] * 5 + \ 73 | [UNIT_TYPEID.ZERG_EVOLUTIONCHAMBER] + \ 74 | [UNIT_TYPEID.ZERG_ROACH, 75 | UNIT_TYPEID.ZERG_DRONE] * 3 + \ 76 | [UNIT_TYPEID.ZERG_EVOLUTIONCHAMBER] + \ 77 | [UPGRADE_ID.BURROW, 78 | UPGRADE_ID.TUNNELINGCLAWS, 79 | UNIT_TYPEID.ZERG_HYDRALISKDEN] + \ 80 | [UNIT_TYPEID.ZERG_ROACH, 81 | UNIT_TYPEID.ZERG_DRONE] * 7 82 | else: 83 | num_worker_needed = 0 84 | num_worker = 0 85 | bases = dc.dd.base_pool.bases 86 | for base_tag in bases: 87 | base = bases[base_tag] 88 | num_worker += self.assigned_harvesters(base) 89 | num_worker_needed += self.ideal_harvesters(base) 90 | num_worker_needed -= num_worker 91 | game_loop = self.obs['game_loop'][0] 92 | 93 | count = unique_unit_count(self.obs['units'], self.TT) 94 | if game_loop < 6 * 60 * 16: # 8 min 95 | goal = [UNIT_TYPEID.ZERG_ROACH] * 2 + \ 96 | [UNIT_TYPEID.ZERG_HYDRALISK] * 2 + \ 97 | [UNIT_TYPEID.ZERG_RAVAGER] * 1 98 | elif game_loop < 12 * 60 * 16: # 12 min 99 | goal = [UNIT_TYPEID.ZERG_ROACH] * 1 + \ 100 | [UNIT_TYPEID.ZERG_HYDRALISK] * 2 + \ 101 | [UNIT_TYPEID.ZERG_RAVAGER] * 1 102 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_LURKERDENMP.value) 103 | and not self.has_unit(UNIT_TYPEID.ZERG_LURKERDENMP.value)): 104 | goal += [UNIT_TYPEID.ZERG_LURKERDENMP] 105 | 106 | if self.has_building_built([UNIT_TYPEID.ZERG_LURKERDENMP.value]): 107 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_LURKERMP] - count[ 108 | UNIT_TYPEID.ZERG_LURKERMP.value] 109 | if diff > 0: 110 | goal += [UNIT_TYPEID.ZERG_LURKERMP] * 2 111 | else: 112 | goal = [] 113 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_ROACH] - count[ 114 | UNIT_TYPEID.ZERG_ROACH.value] 115 | if diff > 0: 116 | goal += [UNIT_TYPEID.ZERG_ROACH] * min(3, diff) 117 | 118 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_RAVAGER] - count[ 119 | UNIT_TYPEID.ZERG_RAVAGER.value] 120 | if diff > 0: 121 | goal += [UNIT_TYPEID.ZERG_RAVAGER] * min(1, diff) 122 | 123 | diff = 
self.ultra_goal[UNIT_TYPEID.ZERG_HYDRALISK] - count[ 124 | UNIT_TYPEID.ZERG_HYDRALISK.value] 125 | if diff > 0: 126 | goal += [UNIT_TYPEID.ZERG_HYDRALISK] * min(3, diff) 127 | 128 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_SPIRE.value) 129 | and not self.has_unit(UNIT_TYPEID.ZERG_SPIRE.value)): 130 | goal += [UNIT_TYPEID.ZERG_SPIRE] 131 | 132 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_LURKERDENMP.value) 133 | and not self.has_unit(UNIT_TYPEID.ZERG_LURKERDENMP.value)): 134 | goal += [UNIT_TYPEID.ZERG_LURKERDENMP] 135 | 136 | if self.has_building_built([UNIT_TYPEID.ZERG_LURKERDENMP.value]): 137 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_LURKERMP] - count[ 138 | UNIT_TYPEID.ZERG_LURKERMP.value] 139 | if diff > 0: 140 | goal += [UNIT_TYPEID.ZERG_LURKERMP] * min(2, diff) 141 | 142 | if self.has_building_built([UNIT_TYPEID.ZERG_SPIRE.value]) and \ 143 | self.has_building_built([UNIT_TYPEID.ZERG_HIVE.value]): 144 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_VIPER] - count[ 145 | UNIT_TYPEID.ZERG_VIPER.value] 146 | if diff > 0: 147 | goal += [UNIT_TYPEID.ZERG_VIPER] * min(1, diff) 148 | 149 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_INFESTATIONPIT.value) 150 | and not self.has_unit(UNIT_TYPEID.ZERG_INFESTATIONPIT.value)): 151 | goal += [UNIT_TYPEID.ZERG_INFESTATIONPIT] 152 | 153 | if self.has_building_built([UNIT_TYPEID.ZERG_INFESTATIONPIT.value]): 154 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_INFESTOR] - count[ 155 | UNIT_TYPEID.ZERG_INFESTOR.value] 156 | if diff > 0: 157 | goal += [UNIT_TYPEID.ZERG_INFESTOR] * min(1, diff) 158 | 159 | if (self.has_building_built([UNIT_TYPEID.ZERG_INFESTATIONPIT.value]) 160 | and not self.unit_in_progress(UNIT_TYPEID.ZERG_HIVE.value) 161 | and not self.has_unit(UNIT_TYPEID.ZERG_HIVE.value)): 162 | goal += [UNIT_TYPEID.ZERG_HIVE] 163 | 164 | # ULTRALISK 165 | if (self.has_building_built([UNIT_TYPEID.ZERG_HIVE.value]) 166 | and not self.unit_in_progress(UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value) 167 | and not 
self.has_unit(UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value)): 168 | goal += [UNIT_TYPEID.ZERG_ULTRALISKCAVERN] 169 | 170 | if self.has_building_built([UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value]): 171 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_ULTRALISK] - count[ 172 | UNIT_TYPEID.ZERG_ULTRALISK.value] 173 | if diff > 0: 174 | goal += [UNIT_TYPEID.ZERG_ULTRALISK] * min(2, diff) 175 | if num_worker_needed > 0 and num_worker < 66: 176 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_DRONE] - count[UNIT_TYPEID.ZERG_DRONE.value] 177 | if diff > 0: 178 | goal = [UNIT_TYPEID.ZERG_DRONE] * min(5, diff) + goal 179 | return goal 180 | -------------------------------------------------------------------------------- /tstarbot/production_strategy/prod_defandadv.py: -------------------------------------------------------------------------------- 1 | """Production Manager""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from pysc2.lib.typeenums import UNIT_TYPEID 7 | from pysc2.lib.typeenums import UPGRADE_ID 8 | from tstarbot.production_strategy.base_zerg_production_mgr import ZergBaseProductionMgr 9 | from tstarbot.production_strategy.util import unique_unit_count 10 | 11 | 12 | class ZergProdDefAndAdv(ZergBaseProductionMgr): 13 | def __init__(self, dc): 14 | super(ZergProdDefAndAdv, self).__init__(dc) 15 | self.ultra_goal = self.get_ultra_goal() 16 | 17 | @staticmethod 18 | def get_ultra_goal(): 19 | return{UNIT_TYPEID.ZERG_ROACH: 13, 20 | UNIT_TYPEID.ZERG_HYDRALISK: 23, 21 | UNIT_TYPEID.ZERG_INFESTOR: 3, 22 | UNIT_TYPEID.ZERG_CORRUPTOR: 0, 23 | UNIT_TYPEID.ZERG_LURKERMP: 6, 24 | UNIT_TYPEID.ZERG_VIPER: 2, 25 | UNIT_TYPEID.ZERG_RAVAGER: 4, 26 | UNIT_TYPEID.ZERG_ULTRALISK: 4, 27 | UNIT_TYPEID.ZERG_MUTALISK: 0, 28 | UNIT_TYPEID.ZERG_BROODLORD: 0, 29 | UNIT_TYPEID.ZERG_QUEEN: 3, 30 | UNIT_TYPEID.ZERG_OVERSEER: 20, 31 | UNIT_TYPEID.ZERG_DRONE: 66} 32 | 33 | def get_opening_build_order(self): 34 | return 
[UNIT_TYPEID.ZERG_DRONE, 35 | UNIT_TYPEID.ZERG_DRONE, 36 | UNIT_TYPEID.ZERG_OVERLORD, 37 | UNIT_TYPEID.ZERG_DRONE, 38 | UNIT_TYPEID.ZERG_DRONE, 39 | UNIT_TYPEID.ZERG_DRONE, 40 | UNIT_TYPEID.ZERG_HATCHERY, 41 | UNIT_TYPEID.ZERG_DRONE, 42 | UNIT_TYPEID.ZERG_EXTRACTOR] + \ 43 | [UNIT_TYPEID.ZERG_DRONE] * 4 + \ 44 | [UNIT_TYPEID.ZERG_SPAWNINGPOOL, 45 | UNIT_TYPEID.ZERG_DRONE, 46 | UNIT_TYPEID.ZERG_DRONE, 47 | UNIT_TYPEID.ZERG_DRONE, 48 | UNIT_TYPEID.ZERG_ZERGLING, 49 | UNIT_TYPEID.ZERG_ROACHWARREN, 50 | UNIT_TYPEID.ZERG_DRONE, 51 | UNIT_TYPEID.ZERG_DRONE, 52 | UNIT_TYPEID.ZERG_DRONE, 53 | UNIT_TYPEID.ZERG_QUEEN, 54 | UNIT_TYPEID.ZERG_DRONE, 55 | UNIT_TYPEID.ZERG_ROACH] + \ 56 | [UNIT_TYPEID.ZERG_SPINECRAWLER] * 4 + \ 57 | [UNIT_TYPEID.ZERG_DRONE, 58 | UNIT_TYPEID.ZERG_ROACH] * 2 + \ 59 | [UNIT_TYPEID.ZERG_SPINECRAWLER, 60 | UNIT_TYPEID.ZERG_ROACH, 61 | UNIT_TYPEID.ZERG_ROACH, 62 | UNIT_TYPEID.ZERG_SPINECRAWLER] 63 | 64 | def get_goal(self, dc): 65 | if not self.has_building_built([UNIT_TYPEID.ZERG_LAIR.value, 66 | UNIT_TYPEID.ZERG_HIVE.value]): 67 | goal = [UNIT_TYPEID.ZERG_LAIR] + \ 68 | [UNIT_TYPEID.ZERG_DRONE] * 6 + \ 69 | [UNIT_TYPEID.ZERG_ROACH] * 5 + \ 70 | [UNIT_TYPEID.ZERG_SPIRE] * 1 + \ 71 | [UNIT_TYPEID.ZERG_DRONE, 72 | UNIT_TYPEID.ZERG_ROACH] * 5 + \ 73 | [UNIT_TYPEID.ZERG_MUTALISK] * 6 + \ 74 | [UNIT_TYPEID.ZERG_EVOLUTIONCHAMBER] + \ 75 | [UNIT_TYPEID.ZERG_ROACH, 76 | UNIT_TYPEID.ZERG_DRONE] * 3 + \ 77 | [UNIT_TYPEID.ZERG_EVOLUTIONCHAMBER, 78 | UPGRADE_ID.BURROW, 79 | UPGRADE_ID.TUNNELINGCLAWS, 80 | UNIT_TYPEID.ZERG_HYDRALISKDEN] + \ 81 | [UNIT_TYPEID.ZERG_ROACH, 82 | UNIT_TYPEID.ZERG_DRONE] * 7 83 | else: 84 | num_worker_needed = 0 85 | num_worker = 0 86 | bases = dc.dd.base_pool.bases 87 | for base_tag in bases: 88 | base = bases[base_tag] 89 | num_worker += self.assigned_harvesters(base) 90 | num_worker_needed += self.ideal_harvesters(base) 91 | num_worker_needed -= num_worker 92 | game_loop = self.obs['game_loop'][0] 93 | 94 | count = 
unique_unit_count(self.obs['units'], self.TT) 95 | if game_loop < 6 * 60 * 16: # 8 min 96 | goal = [UNIT_TYPEID.ZERG_ROACH] * 2 + \ 97 | [UNIT_TYPEID.ZERG_HYDRALISK] * 2 + \ 98 | [UNIT_TYPEID.ZERG_RAVAGER] * 1 99 | elif game_loop < 12 * 60 * 16: # 12 min 100 | goal = [UNIT_TYPEID.ZERG_ROACH] * 1 + \ 101 | [UNIT_TYPEID.ZERG_HYDRALISK] * 2 + \ 102 | [UNIT_TYPEID.ZERG_RAVAGER] * 1 103 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_LURKERDENMP.value) 104 | and not self.has_unit(UNIT_TYPEID.ZERG_LURKERDENMP.value)): 105 | goal += [UNIT_TYPEID.ZERG_LURKERDENMP] 106 | 107 | if self.has_building_built([UNIT_TYPEID.ZERG_LURKERDENMP.value]): 108 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_LURKERMP] - count[ 109 | UNIT_TYPEID.ZERG_LURKERMP.value] 110 | if diff > 0: 111 | goal += [UNIT_TYPEID.ZERG_LURKERMP] * 2 112 | else: 113 | goal = [] 114 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_ROACH] - count[ 115 | UNIT_TYPEID.ZERG_ROACH.value] 116 | if diff > 0: 117 | goal += [UNIT_TYPEID.ZERG_ROACH] * min(3, diff) 118 | 119 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_RAVAGER] - count[ 120 | UNIT_TYPEID.ZERG_RAVAGER.value] 121 | if diff > 0: 122 | goal += [UNIT_TYPEID.ZERG_RAVAGER] * min(1, diff) 123 | 124 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_HYDRALISK] - count[ 125 | UNIT_TYPEID.ZERG_HYDRALISK.value] 126 | if diff > 0: 127 | goal += [UNIT_TYPEID.ZERG_HYDRALISK] * min(3, diff) 128 | 129 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_SPIRE.value) 130 | and not self.has_unit(UNIT_TYPEID.ZERG_SPIRE.value)): 131 | goal += [UNIT_TYPEID.ZERG_SPIRE] 132 | 133 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_LURKERDENMP.value) 134 | and not self.has_unit(UNIT_TYPEID.ZERG_LURKERDENMP.value)): 135 | goal += [UNIT_TYPEID.ZERG_LURKERDENMP] 136 | 137 | if self.has_building_built([UNIT_TYPEID.ZERG_LURKERDENMP.value]): 138 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_LURKERMP] - count[ 139 | UNIT_TYPEID.ZERG_LURKERMP.value] 140 | if diff > 0: 141 | goal += [UNIT_TYPEID.ZERG_LURKERMP] * min(2, diff) 
142 | 143 | if self.has_building_built([UNIT_TYPEID.ZERG_SPIRE.value]) and \ 144 | self.has_building_built([UNIT_TYPEID.ZERG_HIVE.value]): 145 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_VIPER] - count[ 146 | UNIT_TYPEID.ZERG_VIPER.value] 147 | if diff > 0: 148 | goal += [UNIT_TYPEID.ZERG_VIPER] * min(1, diff) 149 | 150 | if (not self.unit_in_progress(UNIT_TYPEID.ZERG_INFESTATIONPIT.value) 151 | and not self.has_unit(UNIT_TYPEID.ZERG_INFESTATIONPIT.value)): 152 | goal += [UNIT_TYPEID.ZERG_INFESTATIONPIT] 153 | 154 | if self.has_building_built([UNIT_TYPEID.ZERG_INFESTATIONPIT.value]): 155 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_INFESTOR] - count[ 156 | UNIT_TYPEID.ZERG_INFESTOR.value] 157 | if diff > 0: 158 | goal += [UNIT_TYPEID.ZERG_INFESTOR] * min(1, diff) 159 | 160 | if (self.has_building_built([UNIT_TYPEID.ZERG_INFESTATIONPIT.value]) 161 | and not self.unit_in_progress(UNIT_TYPEID.ZERG_HIVE.value) 162 | and not self.has_unit(UNIT_TYPEID.ZERG_HIVE.value)): 163 | goal += [UNIT_TYPEID.ZERG_HIVE] 164 | 165 | # ULTRALISK 166 | if (self.has_building_built([UNIT_TYPEID.ZERG_HIVE.value]) 167 | and not self.unit_in_progress(UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value) 168 | and not self.has_unit(UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value)): 169 | goal += [UNIT_TYPEID.ZERG_ULTRALISKCAVERN] 170 | 171 | if self.has_building_built([UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value]): 172 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_ULTRALISK] - count[ 173 | UNIT_TYPEID.ZERG_ULTRALISK.value] 174 | if diff > 0: 175 | goal += [UNIT_TYPEID.ZERG_ULTRALISK] * min(2, diff) 176 | if num_worker_needed > 0 and num_worker < 66: 177 | diff = self.ultra_goal[UNIT_TYPEID.ZERG_DRONE] - count[UNIT_TYPEID.ZERG_DRONE.value] 178 | if diff > 0: 179 | goal = [UNIT_TYPEID.ZERG_DRONE] * min(5, diff) + goal 180 | return goal 181 | -------------------------------------------------------------------------------- /tstarbot/data/demo_dc.py: -------------------------------------------------------------------------------- 1 | 
import pysc2.lib.typeenums as tp 2 | 3 | from tstarbot.data.pool.pool_base import PoolBase 4 | 5 | 6 | class DancingDrones(PoolBase): 7 | """ Let drones dance around their base. 8 | adopted from Zheng Yang's code.""" 9 | 10 | def __init__(self): 11 | super(DancingDrones, self).__init__() 12 | self._drone_ids = [] 13 | self._hatcherys = [] 14 | 15 | def update(self, timestamp): 16 | units = timestamp.observation['units'] 17 | self._locate_hatcherys(units) 18 | self._update_drone(units) 19 | 20 | def _locate_hatcherys(self, units): 21 | tmp_hatcherys = [] 22 | for u in units: 23 | if u.unit_type == tp.UNIT_TYPEID.ZERG_HATCHERY.value: 24 | tmp_hatcherys.append( 25 | (u.float_attr.pos_x, u.float_attr.pos_y, u.float_attr.pos_z)) 26 | self._hatcherys = tmp_hatcherys 27 | 28 | def _update_drone(self, units): 29 | drone_ids = [] 30 | for u in units: 31 | if u.unit_type == tp.UNIT_TYPEID.ZERG_DRONE.value: 32 | drone_ids.append(u.tag) 33 | 34 | self._drone_ids = drone_ids 35 | 36 | def get_drones(self): 37 | return self._drone_ids 38 | 39 | def get_hatcherys(self): 40 | return self._hatcherys 41 | 42 | def key(self): 43 | return 'dancing_drones' 44 | 45 | 46 | class DefeatRoaches(PoolBase): 47 | """ for DefeatRoaches Minimap. 
48 | Adopted from lxhan's code 49 | """ 50 | 51 | def __init__(self): 52 | self.marines = [] # fro self 53 | self.roaches = [] # for enemy 54 | 55 | def update(self, timestep): 56 | units = timestep.observation['units'] 57 | self.collect_marine(units) 58 | self.collect_roach(units) 59 | 60 | def collect_marine(self, units): 61 | marines = [] 62 | for u in units: 63 | if u.unit_type == tp.UNIT_TYPEID.TERRAN_MARINE.value and u.int_attr.owner == 1: 64 | marines.append(u) 65 | # print("marine assigned_harvesters: {}".format(u.int_attr.assigned_harvesters)) 66 | self.marines = marines 67 | 68 | def collect_roach(self, units): 69 | roaches = [] 70 | for u in units: 71 | if u.unit_type == tp.UNIT_TYPEID.ZERG_ROACH.value and u.int_attr.owner == 2: 72 | roaches.append(u) 73 | # print("roach target: {}".format(u.int_attr.engaged_target_tag)) 74 | self.roaches = roaches 75 | 76 | def get_marines(self): 77 | return self.marines 78 | 79 | def get_roaches(self): 80 | return self.roaches 81 | 82 | def key(self): 83 | return 'defeat_roaches' 84 | 85 | 86 | class ZergLxHanDcMgr(PoolBase): 87 | """ full game with Simple64 by producing roaches + hydralisk 88 | (for testing combat module) """ 89 | 90 | def __init__(self): 91 | super(ZergLxHanDcMgr, self).__init__() 92 | self.reset() 93 | 94 | def reset(self): 95 | self.units = [] 96 | self.screen = [] 97 | self.player_info = [] 98 | 99 | self.drones = [] 100 | self.hatcheries = [] 101 | self.minerals = [] 102 | self.larvas = [] 103 | self.queen = [] 104 | self.spawningpool = [] 105 | self.extractors = [] 106 | self.vespens = [] 107 | self.roachwarren = [] 108 | self.roaches = [] 109 | self.hydraliskden = [] 110 | self.hydralisk = [] 111 | 112 | self.enemy_units = [] 113 | 114 | self.base_pos = [] 115 | self.mini_map = [] 116 | 117 | def update(self, timestep): 118 | units = timestep.observation['units'] 119 | screen = timestep.observation['screen'] 120 | player_info = timestep.observation['player'] 121 | mini_map = 
timestep.observation['minimap'] 122 | 123 | self.units = units 124 | self.screen = screen 125 | self.player_info = player_info 126 | self.mini_map = mini_map 127 | 128 | self.collect_drones(units) 129 | self.collect_hatcheries(units) 130 | self.collect_minerals(units) 131 | self.collect_larvas(units) 132 | self.collect_queen(units) 133 | self.collect_spawningpool(units) 134 | self.collect_extractor(units) 135 | self.collect_vespen(units) 136 | self.collect_roachwarren(units) 137 | self.collect_roaches(units) 138 | self.collect_hydraliskden(units) 139 | self.collect_hydralisk(units) 140 | 141 | self.collect_enemy_units(units) 142 | 143 | if len(self.hatcheries) != 0 and len(self.base_pos) == 0: 144 | self.base_pos = [self.hatcheries[0].float_attr.pos_x, 145 | self.hatcheries[0].float_attr.pos_y] 146 | print('base_pos: ', self.base_pos) 147 | 148 | def collect_drones(self, units): 149 | drones = [] 150 | for u in units: 151 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_DRONE.value and 152 | u.int_attr.owner == 1): 153 | drones.append(u) 154 | self.drones = drones 155 | 156 | def collect_hatcheries(self, units): 157 | hatcheries = [] 158 | for u in units: 159 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_HATCHERY.value and 160 | u.int_attr.owner == 1): 161 | hatcheries.append(u) 162 | self.hatcheries = hatcheries 163 | 164 | def collect_minerals(self, units): 165 | minerals = [] 166 | for u in units: 167 | if u.unit_type == tp.UNIT_TYPEID.NEUTRAL_MINERALFIELD.value: 168 | minerals.append(u) 169 | self.minerals = minerals 170 | 171 | def collect_larvas(self, units): 172 | larvas = [] 173 | for u in units: 174 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_LARVA.value and 175 | u.int_attr.owner == 1): 176 | larvas.append(u) 177 | self.larvas = larvas 178 | 179 | def collect_queen(self, units): 180 | queen = [] 181 | for u in units: 182 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_QUEEN.value and 183 | u.int_attr.owner == 1): 184 | queen.append(u) 185 | self.queen = queen 186 | 187 | def 
collect_spawningpool(self, units): 188 | spawningpool = [] 189 | for u in units: 190 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_SPAWNINGPOOL.value and 191 | u.int_attr.owner == 1): 192 | spawningpool.append(u) 193 | self.spawningpool = spawningpool 194 | 195 | def collect_extractor(self, units): 196 | extractors = [] 197 | for u in units: 198 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_EXTRACTOR.value and 199 | u.int_attr.owner == 1): 200 | extractors.append(u) 201 | self.extractors = extractors 202 | 203 | def collect_vespen(self, units): 204 | vespens = [] 205 | for u in units: 206 | if u.unit_type == tp.UNIT_TYPEID.NEUTRAL_VESPENEGEYSER.value: 207 | vespens.append(u) 208 | self.vespens = vespens 209 | 210 | def collect_roachwarren(self, units): 211 | roachwarren = [] 212 | for u in units: 213 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_ROACHWARREN.value and 214 | u.int_attr.owner == 1): 215 | roachwarren.append(u) 216 | self.roachwarren = roachwarren 217 | 218 | def collect_roaches(self, units): 219 | roaches = [] 220 | for u in units: 221 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_ROACH.value and 222 | u.int_attr.owner == 1): 223 | roaches.append(u) 224 | self.roaches = roaches 225 | 226 | def collect_hydraliskden(self, units): 227 | hydraliskden = [] 228 | for u in units: 229 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_HYDRALISKDEN.value and 230 | u.int_attr.owner == 1): 231 | hydraliskden.append(u) 232 | self.hydraliskden = hydraliskden 233 | 234 | def collect_hydralisk(self, units): 235 | hydralisk = [] 236 | for u in units: 237 | if (u.unit_type == tp.UNIT_TYPEID.ZERG_HYDRALISK.value and 238 | u.int_attr.owner == 1): 239 | hydralisk.append(u) 240 | self.hydralisk = hydralisk 241 | 242 | def collect_enemy_units(self, units): 243 | enemy_units = [] 244 | for u in units: 245 | if u.int_attr.owner == 2: 246 | enemy_units.append(u) 247 | self.enemy_units = enemy_units 248 | 249 | def get_drones(self): 250 | return self.drones 251 | 252 | def get_hatcheries(self): 253 | 
MOVE_RANGE = 1.0


@unique
class ScoutMove(Enum):
    """Discrete action ids exchanged between the policy and the scout task."""
    UPPER = 0
    LEFT = 1
    DOWN = 2
    RIGHT = 3
    UPPER_LEFT = 4
    LOWER_LEFT = 5
    LOWER_RIGHT = 6
    UPPER_RIGHT = 7
    NOOP = 8
    HOME = 9


class ScoutExploreTaskRL(ScoutTask):
    """Exploration scout task whose next move is chosen by a DQN policy.

    The policy is built and restored once (see ``load_model``) and shared by
    all instances through the class attribute ``act``.
    """

    # Policy callable shared by every task; None until load_model() has run.
    act = None

    def __init__(self, scout, target, home, model_dir, map_max_x, map_max_y):
        """Bind the task to a scout and a target and ensure the policy exists.

        Args:
            scout: scout wrapper exposing ``unit()``.
            target: scout target exposing ``pos``, ``has_scout``,
                ``has_enemy_base`` and ``has_army``.
            home: home position forwarded to ``ScoutTask``.
            model_dir: path handed to ``load_model``.
            map_max_x: map extent on x, used to normalise observations.
            map_max_y: map extent on y, used to normalise observations.
        """
        super(ScoutExploreTaskRL, self).__init__(scout, home)
        self._status = md.ScoutTaskStatus.DOING
        self._target = target
        # Reference map sizes (x * y) for the scout explore task:
        #   Simple64 = 88 * 96
        #   AbyssalReef = 200 * 176
        #   Acolyte = 168 * 200
        #   AscensiontoAiur = 176 * 152
        #   Frost = 184 * 184
        #   Interloper = 152 * 176
        #   MechDepot = 184 * 176
        #   Odyssey = 168 * 184
        self._map_max_x = map_max_x
        self._map_max_y = map_max_y
        self._reverse = self.judge_reverse(scout)
        self.load_model(model_dir)

    def type(self):
        """Return the task type tag of this task."""
        return md.ScoutTaskType.EXPORE

    def target(self):
        """Return the target this task was created for."""
        return self._target

    def post_process(self):
        """Release the scout/target pair and record what a lost scout implies.

        When the scout was destroyed inside the target's range before the
        target was classified, the target is flagged as holding both an enemy
        base and an army.
        """
        self._target.has_scout = False
        self._scout.is_doing_task = False
        if self._status != md.ScoutTaskStatus.SCOUT_DESTROY:
            return
        if self._check_in_base_range() and not self._judge_task_done():
            self._target.has_enemy_base = True
            self._target.has_army = True
        else:
            self._target.has_scout = False

    @staticmethod
    def load_model(model_path):
        """Build the act function and restore its weights, once per process.

        Does nothing when ``ScoutExploreTaskRL.act`` is already set. The
        TensorFlow session is entered here and is never exited.

        Args:
            model_path: checkpoint path passed to ``load_state``.
        """
        if ScoutExploreTaskRL.act is not None:
            return

        class FakeEnv(object):
            """Stand-in exposing only the observation and action spaces."""

            def __init__(self):
                self.observation_space = Box(np.zeros(6), np.ones(6))
                self.action_space = Discrete(8)

        env = FakeEnv()

        def make_obs_ph(name):
            return ObservationInput(env.observation_space, name=name)

        act_params = {
            'make_obs_ph': make_obs_ph,
            'q_func': deepq.models.mlp([64, 32]),
            'num_actions': env.action_space.n,
        }
        policy = deepq.build_act(**act_params)
        session = tf.Session()
        session.__enter__()
        print("load_model path=", model_path)
        load_state(model_path)
        ScoutExploreTaskRL.act = ActWrapper(policy, act_params)
        print("load_model ok")

    def _do_task_inner(self, view_enemys, dc):
        """Pick the next move and turn it into a move command.

        Returns:
            None when the scout is lost; otherwise whatever
            ``_move_to_target`` returns for the computed position.
        """
        if self._check_scout_lost():
            self._status = md.ScoutTaskStatus.SCOUT_DESTROY
            return None

        if not self.check_end(view_enemys, dc):
            observation = self._get_obs()
            action = ScoutExploreTaskRL.act(observation[None])[0]
        else:
            action = ScoutMove.HOME.value
            self._status = md.ScoutTaskStatus.DONE
        return self._move_to_target(self._calcuate_pos_by_action(action))

    def _get_obs(self):
        """Return the 6-element observation: scout, home and target (x, y).

        Every coordinate is divided by the matching map extent. The scout
        position is mirrored first when the task runs in reversed mode.
        """
        unit = self.scout().unit()
        x, y = unit.float_attr.pos_x, unit.float_attr.pos_y
        if self._reverse:
            x, y = self.pos_transfer(x, y)
        features = [
            float(x) / self._map_max_x,
            float(y) / self._map_max_y,
            float(self._home[0]) / self._map_max_x,
            float(self._home[1]) / self._map_max_y,
            float(self._target.pos[0]) / self._map_max_x,
            float(self._target.pos[1]) / self._map_max_y,
        ]
        return np.array(features)

    def _calcuate_pos_by_action(self, action):
        """Translate an action id into the position the scout should move to.

        Returns:
            An (x, y) tuple one ``MOVE_RANGE`` step away from the scout, the
            home position for HOME, or None for any other action id.
        """
        unit = self.scout().unit()
        if self._reverse:
            action = self.action_transfer(action)

        x = unit.float_attr.pos_x
        y = unit.float_attr.pos_y
        x_inc, x_dec = x + MOVE_RANGE, x - MOVE_RANGE
        y_inc, y_dec = y + MOVE_RANGE, y - MOVE_RANGE
        # NOTE: LEFT increases x and RIGHT decreases it; this table mirrors
        # the existing mapping one-to-one.
        destinations = (
            (ScoutMove.UPPER, (x, y_inc)),
            (ScoutMove.LEFT, (x_inc, y)),
            (ScoutMove.DOWN, (x, y_dec)),
            (ScoutMove.RIGHT, (x_dec, y)),
            (ScoutMove.UPPER_LEFT, (x_inc, y_inc)),
            (ScoutMove.LOWER_LEFT, (x_inc, y_dec)),
            (ScoutMove.LOWER_RIGHT, (x_dec, y_dec)),
            (ScoutMove.UPPER_RIGHT, (x_dec, y_inc)),
        )
        for move, destination in destinations:
            if action == move.value:
                return destination
        if action == ScoutMove.HOME.value:
            print('*** return to home ***, home=', self._home)
            return self._home
        return None

    def _check_in_base_range(self):
        """Return True when the scout is within cruise range of the target."""
        unit_attr = self._scout.unit().float_attr
        dist = md.calculate_distance(unit_attr.pos_x,
                                     unit_attr.pos_y,
                                     self._target.pos[0],
                                     self._target.pos[1])
        return bool(dist < st.SCOUT_CRUISE_ARRIVAED_RANGE)

    def _judge_task_done(self):
        """Return True when the target is flagged with a base or an army."""
        return bool(self._target.has_enemy_base or self._target.has_army)

    def judge_reverse(self, scout):
        """Return False when the scout's x is smaller than its y, else True."""
        return not (scout.unit().float_attr.pos_x <
                    scout.unit().float_attr.pos_y)

    def action_transfer(self, action):
        """Map an action id to its opposite direction.

        HOME is returned unchanged; any other unknown id yields None.
        """
        opposites = (
            (ScoutMove.UPPER, ScoutMove.DOWN),
            (ScoutMove.LEFT, ScoutMove.RIGHT),
            (ScoutMove.DOWN, ScoutMove.UPPER),
            (ScoutMove.RIGHT, ScoutMove.LEFT),
            (ScoutMove.UPPER_LEFT, ScoutMove.LOWER_RIGHT),
            (ScoutMove.LOWER_LEFT, ScoutMove.UPPER_RIGHT),
            (ScoutMove.LOWER_RIGHT, ScoutMove.UPPER_LEFT),
            (ScoutMove.UPPER_RIGHT, ScoutMove.LOWER_LEFT),
        )
        for move, mirrored in opposites:
            if action == move.value:
                return mirrored.value
        if action == ScoutMove.HOME.value:
            return action
        return None

    def pos_transfer(self, x, y):
        """Reflect (x, y) through the map centre and return the new tuple."""

        def reflect(value, centre):
            offset = abs(value - centre)
            if value > centre:
                return centre - offset
            return centre + offset

        cx = self._map_max_x / 2
        cy = self._map_max_y / 2
        return (reflect(x, cx), reflect(y, cy))
def collect_tags(units):
    """Return the ``tag`` of every unit in ``units``, in order."""
    return [u.tag for u in units]


def find_by_tag(units, tag):
    """Return the first unit whose ``tag`` equals ``tag``, or None."""
    for u in units:
        if u.tag == tag:
            return u
    return None


def get_target_tag(unit, idx=0):
    """Return the ``target_tag`` of the unit's ``idx``-th order.

    Returns None when the unit has no orders or when ``idx`` is outside the
    order list (previously an out-of-range ``idx`` raised IndexError).
    """
    orders = unit.orders
    if -len(orders) <= idx < len(orders):
        return orders[idx].target_tag
    return None


class BaseResourceMgr(object):
    """No-op base class defining the resource manager interface."""

    def __init__(self):
        pass

    def update(self, obs_mgr, act_mgr):
        """Per-step hook; the base implementation does nothing."""
        pass

    def reset(self):
        """Per-episode hook; the base implementation does nothing."""
        pass


class DancingDronesResourceMgr(BaseResourceMgr):
    """Moves every drone to a random point around the first hatchery."""

    def __init__(self):
        super(DancingDronesResourceMgr, self).__init__()
        # Inclusive bounds of the random offset added to the hatchery x/y.
        self._range_high = 5
        self._range_low = -5
        # Ability id written into every generated unit command.
        self._move_ability = 1

    def update(self, dc, am):
        """Push one random move command per drone to ``am``.

        Args:
            dc: data context exposing ``get_drones()`` and ``get_hatcherys()``.
            am: action manager exposing ``push_actions(actions)``.
        """
        super(DancingDronesResourceMgr, self).update(dc, am)

        drone_ids = dc.get_drones()
        pos = dc.get_hatcherys()

        print('pos=', pos)
        if len(pos) == 0:
            # No hatchery to dance around; indexing pos[0] would raise.
            return
        actions = self.move_drone_random_round_hatchery(drone_ids, pos[0])

        am.push_actions(actions)

    def move_drone_random_round_hatchery(self, drone_ids, pos):
        """Build one move action per drone towards a random point near ``pos``.

        Args:
            drone_ids: iterable of drone tags.
            pos: (x, y) position the random offsets are added to.

        Returns:
            List of ``sc_pb.Action`` objects, one per drone.
        """
        actions = []
        for drone in drone_ids:
            action = sc_pb.Action()
            action.action_raw.unit_command.ability_id = self._move_ability
            x = pos[0] + random.randint(self._range_low, self._range_high)
            y = pos[1] + random.randint(self._range_low, self._range_high)
            action.action_raw.unit_command.target_world_space_pos.x = x
            action.action_raw.unit_command.target_world_space_pos.y = y
            action.action_raw.unit_command.unit_tags.append(drone)
            actions.append(action)
        return actions


class ExtractorWorkerTagMgrOld(object):
    """ Maintain the Extractor-Worker bi-directional mapping/contract using tags """
    MAX_WORKERS_PER_EXTRACTOR = 3

    def __init__(self):
        # extractor tag -> list of worker tags assigned to that extractor
        self.extractor_to_workers = {}

    def reset(self):
        """Forget every extractor and every assignment."""
        self.extractor_to_workers = {}

    def update(self, all_extractor_tags, all_worker_tags):
        """Synchronise the mapping with the currently existing tags."""
        self._add_extractor_if_has_new(all_extractor_tags)
        self._remove_extractor_if_not_exist(all_extractor_tags)
        self._remove_workers_if_not_exist(all_worker_tags)

    def _add_extractor_if_has_new(self, all_extractor_tags):
        for e in all_extractor_tags:
            self.extractor_to_workers.setdefault(e, [])

    def _remove_extractor_if_not_exist(self, all_extractor_tags):
        alive = set(all_extractor_tags)
        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating it raises RuntimeError on Python 3.
        for e in list(self.extractor_to_workers):
            if e not in alive:
                del self.extractor_to_workers[e]

    def _remove_workers_if_not_exist(self, all_worker_tags):
        alive = set(all_worker_tags)
        for worker_tags in self.extractor_to_workers.values():
            # Filter in place (slice assignment keeps the same list object);
            # list.remove() inside a loop over the same list skips entries.
            worker_tags[:] = [w for w in worker_tags if w in alive]

    def act_worker_harvests_on_extractor(self, extractor_tag, worker_tag):
        """Record the assignment and return the harvest action in a list.

        The caller must pass an extractor tag already known to the mapping
        and make sure the extractor-worker pair is reasonable.
        """
        # REMEMBER to update the mapping!
        self.extractor_to_workers[extractor_tag].append(worker_tag)

        # make the real action
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = ABILITY_ID.HARVEST_GATHER_DRONE.value
        action.action_raw.unit_command.target_unit_tag = extractor_tag
        action.action_raw.unit_command.unit_tags.append(worker_tag)
        return [action]

    def contain_worker(self, worker_tag):
        """Return True when ``worker_tag`` is assigned to any extractor."""
        return any(worker_tag in workers
                   for workers in self.extractor_to_workers.values())


class ExtractorWorkerTagMgr(object):
    """ Maintain the Extractor-Worker bi-directional mapping/contract using tags """
    MAX_WORKERS_PER_EXTRACTOR = 3

    def __init__(self):
        # extractor tag -> number of workers whose first order targets it
        self.extractor_tag_to_workers_num = {}

    def reset(self):
        """Forget every extractor and every count."""
        self.extractor_tag_to_workers_num = {}

    def update(self, all_extractor, all_worker):
        """Synchronise the tracked extractors and recount their workers.

        Args:
            all_extractor: iterable of extractor units (each with ``tag``).
            all_worker: iterable of worker units (each with ``orders``).
        """
        self._add_extractor_if_has_new(all_extractor)
        self._remove_extractor_if_not_exist(all_extractor)
        self._update_num_workers(all_worker)

    def _add_extractor_if_has_new(self, all_extractor):
        for e in all_extractor:
            self.extractor_tag_to_workers_num.setdefault(e.tag, 0)

    def _remove_extractor_if_not_exist(self, all_extractor):
        alive = {e.tag for e in all_extractor}
        # Snapshot the keys: popping while iterating the dict raises
        # RuntimeError on Python 3.
        for e_tag in list(self.extractor_tag_to_workers_num):
            if e_tag not in alive:
                del self.extractor_tag_to_workers_num[e_tag]

    def _update_num_workers(self, all_workers):
        counts = self.extractor_tag_to_workers_num
        # Recount from zero on every update. Without the reset the same
        # worker was added again on every step and the count only grew.
        for e_tag in counts:
            counts[e_tag] = 0
        for w in all_workers:
            target_tag = get_target_tag(w)
            if target_tag and target_tag in counts:
                counts[target_tag] += 1

    def act_worker_harvests_on_extractor(self, extractor_tag, worker_tag):
        """Count the new worker and return the harvest action in a list.

        The caller must pass an extractor tag already being tracked and make
        sure the extractor-worker pair is reasonable.
        """
        # REMEMBER to update the mapping!
        self.extractor_tag_to_workers_num[extractor_tag] += 1

        # make the real action
        action = sc_pb.Action()
        action.action_raw.unit_command.ability_id = ABILITY_ID.HARVEST_GATHER_DRONE.value
        action.action_raw.unit_command.target_unit_tag = extractor_tag
        action.action_raw.unit_command.unit_tags.append(worker_tag)
        return [action]


class ZergResourceMgr(BaseResourceMgr):
    """Sends drones to finished extractors until each one is saturated."""

    def __init__(self):
        super(ZergResourceMgr, self).__init__()

        self.update_harvest_gas_freq = 6
        self.update_harvest_mineral_freq = 6

        self.step = 0
        self.ew_mgr = ExtractorWorkerTagMgr()
        # Debug aid: every worker tag seen since the last reset.
        self.tmp = set()

    def reset(self):
        """Clear the step counter, the extractor bookkeeping and ``tmp``."""
        self.step = 0
        self.ew_mgr.reset()
        self.tmp.clear()

    def update(self, dc, am):
        """Refresh bookkeeping from the observation and push harvest actions.

        Args:
            dc: data context; units are read from ``dc.sd.obs['units']``.
            am: action manager exposing ``push_actions(actions)``.
        """
        super(ZergResourceMgr, self).update(dc, am)

        units = dc.sd.obs['units']

        all_extractors = collect_units(units, UNIT_TYPEID.ZERG_EXTRACTOR.value)
        all_workers = collect_units(units, UNIT_TYPEID.ZERG_DRONE.value)
        self.ew_mgr.update(all_extractors, all_workers)

        self.tmp.update(w.tag for w in all_workers)

        print('len workers = ', len(all_workers))
        print('len tmp = ', len(self.tmp))

        actions = []
        actions += self._update_harvest_gas(all_extractors, all_workers)
        actions += self._update_harvest_mineral()

        am.push_actions(actions)
        self.step += 1

    def _update_harvest_gas(self, cur_all_extractors, cur_all_workers):
        """Return harvest actions, at most one new worker per extractor.

        A worker is eligible when its first order does not target a tracked
        extractor and it has not already been dispatched during this call.
        """
        actions = []
        gas_targets = self.ew_mgr.extractor_tag_to_workers_num
        # Worker tags already dispatched in this call, so that one worker is
        # not ordered to several extractors in the same step.
        dispatched = set()

        for e_tag, num_workers in list(gas_targets.items()):
            e = find_by_tag(cur_all_extractors, e_tag)
            if not e:  # not a valid extractor tag due to unknown reason...
                continue
            if e.float_attr.build_progress < 1.0:  # extractor not yet built
                continue

            n_remain = ExtractorWorkerTagMgr.MAX_WORKERS_PER_EXTRACTOR - num_workers
            if n_remain <= 0:  # full on this extractor
                continue

            for w in cur_all_workers:
                if w.tag in dispatched:
                    continue
                if get_target_tag(w) not in gas_targets:  # this worker is not harvesting gas
                    actions += self.ew_mgr.act_worker_harvests_on_extractor(e.tag, w.tag)
                    dispatched.add(w.tag)
                    break  # send only ONE worker to harvest the gas on this step

        return actions

    def _update_harvest_mineral(self):
        """Placeholder; no mineral actions are produced."""
        return []
class Scout(object):
    """Wrapper around a unit that is used (or can be used) as a scout."""

    def __init__(self, unit, team_id=0):
        # team_id is accepted for interface compatibility and is not stored.
        self._unit = unit
        self._lost = False  # whether this scout unit has been lost
        self.is_doing_task = False
        self.snapshot_armys = None

    def unit(self):
        """Return the most recently stored unit."""
        return self._unit

    def set_lost(self, lost):
        """Set the lost flag."""
        self._lost = lost

    def is_lost(self):
        """Return the lost flag."""
        return self._lost

    def is_health(self):
        """Return True when current health equals maximum health."""
        curr_health = self._unit.float_attr.health
        max_health = self._unit.float_attr.health_max
        return curr_health == max_health

    def update(self, u):
        """Replace the stored unit with ``u`` when both carry the same tag.

        Returns:
            True when the unit was replaced, False otherwise.
        """
        if u.int_attr.tag == self._unit.int_attr.tag:  # is the same unit
            self._unit = u
            return True

        return False

    def __str__(self):
        u = self._unit
        return "tag {}, type {}, alliance {}".format(u.int_attr.tag,
                                                     u.int_attr.unit_type,
                                                     u.int_attr.alliance)


class ScoutBaseTarget(object):
    """Mutable record describing one candidate base location to scout."""

    def __init__(self):
        self.area = None
        self.enemy_unit = None
        self.has_enemy_base = False
        self.is_main = False
        self.pos = None
        self.has_scout = False
        self.has_army = False
        self.has_cruise = False

    def __str__(self):
        return 'pos:{} base:{} main_base:{} scout:{} army:{}'.format(
            self.pos, self.has_enemy_base,
            self.is_main, self.has_scout, self.has_army)


class ScoutAlarm(object):
    """Container for the enemy armies attached to one alarm."""

    def __init__(self):
        self.enmey_armys = []


class ScoutPool(PoolBase):
    """Tracks scout units and the base locations that can be scouted."""

    def __init__(self, dd):
        # super() must be given this class. Passing PoolBase starts the MRO
        # lookup after PoolBase and therefore skips PoolBase.__init__.
        super(ScoutPool, self).__init__()
        self._scouts = {}  # unit_tag -> Scout
        self._scout_base_target = []  # list of ScoutBaseTarget
        self._dd = dd
        self._init = False
        self.home_pos = None
        self.alarms = queue.Queue(maxsize=MAX_ALARM_QUEUE)

    def reset(self):
        """Drop all scouts, targets and alarms and force re-initialisation."""
        self._scouts = {}
        self._scout_base_target = []
        self._init = False
        self.home_pos = None
        self.alarms = queue.Queue(maxsize=MAX_ALARM_QUEUE)

    def enemy_bases(self):
        """Return the targets flagged as holding an enemy base."""
        return [base for base in self._scout_base_target
                if base.has_enemy_base]

    def main_enemy_base(self):
        """Return the enemy base flagged as main, else the first enemy base.

        Returns None when no enemy base is known.
        """
        bases = self.enemy_bases()
        if not bases:
            return None

        for base in bases:
            if base.is_main:
                return base
        return bases[0]

    def has_enemy_main_base(self):
        """Return True when any target is flagged as the main base."""
        for base in self._scout_base_target:
            if base.is_main:
                return True
        return False

    def update(self, timestep):
        """Initialise once, then refresh scouts from the observation units."""
        if not self._init:
            self._init_home_pos()
            self._init_scout_base_target()
            self._init = True

        units = timestep.observation['units']
        self._update_all_scouts(units)

    def add_scout(self, u):
        """Register ``u`` as a scout, or refresh it, and clear its lost flag."""
        tag = u.int_attr.tag

        scout = self._scouts.get(tag)
        if scout is None:
            scout = Scout(u)
            self._scouts[tag] = scout
        else:
            scout.update(u)

        scout.set_lost(False)

    def remove_scout(self, tag):
        """Remove the scout with ``tag``; raises KeyError when unknown."""
        del self._scouts[tag]

    def list_scout(self):
        """Return the underlying units of all tracked scouts."""
        return [scout.unit() for scout in self._scouts.values()]

    def _update_all_scouts(self, units):
        """Refresh tracked scouts from ``units`` and drop the missing ones."""
        # Mark everything lost; add_scout() clears the flag for units seen.
        for scout in self._scouts.values():
            scout.set_lost(True)

        employed_types = (UNIT_TYPEID.ZERG_DRONE.value,
                          UNIT_TYPEID.ZERG_ZERGLING.value)
        for u in units:
            unit_type = u.int_attr.unit_type
            if unit_type == UNIT_TYPEID.ZERG_OVERLORD.value \
                    and u.int_attr.alliance == md.AllianceType.SELF.value:
                # Every own overlord is a scout.
                self.add_scout(u)
            elif unit_type in employed_types \
                    and u.int_attr.tag in self._scouts:
                # Drones/zerglings are only refreshed once already employed.
                self.add_scout(u)

        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        lost_tags = [tag for tag, scout in self._scouts.items()
                     if scout.is_lost()]
        for tag in lost_tags:
            del self._scouts[tag]

    def select_scout(self):
        """Return a full-health scout that has no task, or None."""
        for scout in self._scouts.values():
            if not scout.is_doing_task and scout.is_health():
                return scout
        return None

    def select_drone_scout(self):
        """Employ a worker as scout; return its Scout, or None if none free."""
        worker = self._dd.worker_pool.employ_worker(EmployStatus.EMPLOY_SCOUT)
        if worker is None:
            return None

        self.add_scout(worker.unit)

        return self._scouts[worker.unit.int_attr.tag]

    def select_zergling_scout(self):
        """Employ a zergling as scout; return its Scout, or None if none free."""
        zergling = self._dd.combat_pool.employ_combat_unit(
            CombatUnitStatus.SCOUT, UNIT_TYPEID.ZERG_ZERGLING.value)
        if zergling is None:
            return None

        self.add_scout(zergling)

        return self._scouts[zergling.int_attr.tag]

    def find_cruise_target(self):
        """Return the first enemy-base target without a cruise, or None."""
        for target in self._scout_base_target:
            if target.has_enemy_base and not target.has_cruise:
                return target
        return None

    def _unscouted_targets(self):
        """Return targets with neither a known enemy base nor a scout."""
        return [target for target in self._scout_base_target
                if not target.has_enemy_base and not target.has_scout]

    def _furthest_from_home(self, candidates):
        """Return the candidate furthest from ``home_pos``.

        Returns None when there are no candidates or none has a distance
        greater than zero.
        """
        furthest_dist = 0.0
        furthest_candidate = None
        for candidate in candidates:
            dist = md.calculate_distance(self.home_pos[0],
                                         self.home_pos[1],
                                         candidate.pos[0],
                                         candidate.pos[1])
            if furthest_dist < dist:
                furthest_dist = dist
                furthest_candidate = candidate
        return furthest_candidate

    def find_enemy_subbase_target(self):
        """Return the unscouted target closest to the enemy home, or None.

        Only distances strictly between 0 and 1000 are considered.
        """
        min_dist = 1000
        target = None
        for candidate in self._unscouted_targets():
            dist = self._dd.base_pool.enemy_home_dist[candidate.area]
            if min_dist > dist > 0:
                min_dist = dist
                target = candidate
        return target

    def find_furthest_idle_target(self):
        """Return the unscouted target furthest from home, or None."""
        return self._furthest_from_home(self._unscouted_targets())

    def find_forced_scout_target(self):
        """Return the furthest scout-free target with a base or army, or None."""
        candidates = [target for target in self._scout_base_target
                      if not target.has_scout
                      and (target.has_enemy_base or target.has_army)]
        return self._furthest_from_home(candidates)

    def _init_home_pos(self):
        """Store the position of the single starting base in ``home_pos``.

        Raises:
            Exception: when the base pool does not hold exactly one base.
        """
        bases = self._dd.base_pool.bases
        if len(bases) != 1:
            raise Exception('only one base in the game begin')
        for base in bases.values():
            self.home_pos = (base.unit.float_attr.pos_x,
                             base.unit.float_attr.pos_y)

    def _init_scout_base_target(self):
        """Create a target for every resource area far enough from home.

        Areas within ``MAX_AREA_DISTANCE`` of ``home_pos`` are skipped.

        Raises:
            Exception: when the base pool has no resource areas.
        """
        areas = self._dd.base_pool.resource_cluster
        if 0 == len(areas):
            raise Exception('resource areas is none')
        for area in areas:
            scout_target = ScoutBaseTarget()
            scout_target.area = area
            scout_target.pos = area.ideal_base_pos
            dist = md.calculate_distance(self.home_pos[0],
                                         self.home_pos[1],
                                         scout_target.pos[0],
                                         scout_target.pos[1])
            if dist > MAX_AREA_DISTANCE:
                self._scout_base_target.append(scout_target)

    def get_view_scouts(self):
        """Return scouts that are on a task and carry an army snapshot."""
        return [scout for scout in self._scouts.values()
                if scout.is_doing_task and scout.snapshot_armys is not None]