├── sc2learner ├── __init__.py ├── bin │ ├── __init__.py │ ├── play_vs_ppo_agent.py │ ├── evaluate.py │ ├── train_dqn.py │ ├── train_ppo.py │ └── train_ppo_selfplay.py ├── envs │ ├── __init__.py │ ├── actions │ │ ├── __init__.py │ │ ├── function.py │ │ ├── upgrade.py │ │ ├── produce.py │ │ ├── build.py │ │ ├── placer.py │ │ ├── resource.py │ │ ├── zerg_action_wrappers.py │ │ └── combat.py │ ├── common │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── data_context.py │ │ └── const.py │ ├── rewards │ │ ├── __init__.py │ │ └── reward_wrappers.py │ ├── spaces │ │ ├── __init__.py │ │ ├── pysc2_raw.py │ │ └── mask_discrete.py │ ├── observations │ │ ├── __init__.py │ │ ├── spatial_features.py │ │ ├── nonspatial_features.py │ │ └── zerg_observation_wrappers.py │ ├── lan_raw_env.py │ ├── selfplay_raw_env.py │ └── raw_env.py ├── utils │ ├── __init__.py │ └── utils.py └── agents │ ├── __init__.py │ ├── random_agent.py │ ├── keyboard_agent.py │ ├── utils_tf.py │ ├── dqn_networks.py │ ├── ppo_policies.py │ ├── replay_memory.py │ ├── dqn_agent.py │ └── ppo_agent.py ├── docs └── images │ └── overview.png ├── setup.py └── README.md /sc2learner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/envs/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/envs/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/envs/spaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2learner/envs/observations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/TStarBot1/HEAD/docs/images/overview.png -------------------------------------------------------------------------------- /sc2learner/envs/actions/function.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import namedtuple 6 | 7 | 8 | Function = namedtuple('Function', ['name', 'function', 'is_valid']) 9 | -------------------------------------------------------------------------------- /sc2learner/envs/spaces/pysc2_raw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gym 6 | 7 | 8 | class PySC2RawAction(gym.Space): 9 | pass 10 | 11 | 12 | class PySC2RawObservation(gym.Space): 13 | 14 | def __init__(self, observation_spec_fn): 15 | self._feature_layers = observation_spec_fn() 16 | 17 | @property 18 | def space_attr(self): 19 | return self._feature_layers 20 | -------------------------------------------------------------------------------- /sc2learner/envs/spaces/mask_discrete.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from gym.spaces.discrete import Discrete 7 | 8 | 9 | class MaskDiscrete(Discrete): 10 | 11 | def sample(self, availables): 12 | x = np.random.choice(availables).item() 13 | assert self.contains(x, availables) 14 | return x 15 | 16 | def contains(self, x, availables): 17 | return super(MaskDiscrete, self).contains(x) and x in availables 18 | 19 | def __repr__(self): 20 | return "MaskDiscrete(%d)" % self.n 21 | -------------------------------------------------------------------------------- /sc2learner/agents/random_agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete 8 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction 9 | 10 | 11 | class RandomAgent(object): 12 | '''Random agent.''' 13 | 14 | def __init__(self, action_space): 15 | self._action_space = action_space 16 | 17 | def act(self, observation, eps=0): 18 | if (isinstance(self._action_space, MaskDiscrete) or 19 | isinstance(self._action_space, PySC2RawAction)): 20 | action_mask = observation[-1] 21 | return self._action_space.sample(np.nonzero(action_mask)[0]) 22 | else: 23 | return self._action_space.sample() 24 | 25 | def reset(self): 26 | pass 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from setuptools import setup 6 | 7 | 8 | description = """Macro-action-based StarCraft-II learning environment.""" 9 | 10 | setup( 11 | name='sc2learner', 12 | version='0.1', 13 | description='Macro-action-based StarCraft-II learning environment.', 14 | long_description=description, 15 | author='Tencent AI Lab', 16 | author_email='xinghaisun@tencent.com', 17 | keywords='sc2learner StarCraft AI TStarBot', 18 | packages=[ 19 | 'sc2learner', 20 | 'sc2learner.agents', 21 | 'sc2learner.envs', 22 | 'sc2learner.utils', 23 | 'sc2learner.bin', 24 | ], 25 | install_requires=[ 26 | 'gym==0.10.5', 27 | 'torch==0.4.0', 28 | 'tensorflow>=1.4.1', 29 | 
'joblib', 30 | 'pyzmq' 31 | ] 32 | ) 33 | -------------------------------------------------------------------------------- /sc2learner/envs/common/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.lib.unit_controls import Unit 6 | 7 | 8 | def distance(a, b): 9 | 10 | def l2_dist(pos_a, pos_b): 11 | return ((pos_a[0] - pos_b[0]) ** 2 + (pos_a[1] - pos_b[1]) ** 2) ** 0.5 12 | 13 | if isinstance(a, Unit) and isinstance(b, Unit): 14 | return l2_dist((a.float_attr.pos_x, a.float_attr.pos_y), 15 | (b.float_attr.pos_x, b.float_attr.pos_y)) 16 | elif not isinstance(a, Unit) and isinstance(b, Unit): 17 | return l2_dist(a, (b.float_attr.pos_x, b.float_attr.pos_y)) 18 | elif isinstance(a, Unit) and not isinstance(b, Unit): 19 | return l2_dist((a.float_attr.pos_x, a.float_attr.pos_y), b) 20 | else: 21 | return l2_dist(a, b) 22 | 23 | 24 | def closest_unit(unit, target_units): 25 | assert len(target_units) > 0 26 | return min(target_units, key=lambda u: distance(unit, u)) 27 | 28 | 29 | def closest_units(unit, target_units, num): 30 | assert len(target_units) > 0 31 | return sorted(target_units, key=lambda u: distance(unit, u))[:num] 32 | 33 | 34 | def closest_distance(unit, target_units): 35 | return min(distance(unit, u) for u in target_units) \ 36 | if len(target_units) > 0 else float('inf') 37 | 38 | 39 | def units_nearby(unit_center, target_units, max_distance): 40 | return [u for u in target_units if distance(unit_center, u) <= max_distance] 41 | 42 | 43 | def strongest_health(units): 44 | return max(u.float_attr.health / u.float_attr.health_max for u in units) 45 | -------------------------------------------------------------------------------- /sc2learner/envs/lan_raw_env.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gym 6 | from pysc2.env import sc2_env 7 | from pysc2.env import lan_sc2_env 8 | 9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction 10 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation 11 | 12 | 13 | class LanSC2RawEnv(gym.Env): 14 | 15 | def __init__(self, 16 | host, 17 | config_port, 18 | agent_race, 19 | step_mul=8, 20 | resolution=32, 21 | visualize_feature_map=False): 22 | agent_interface_format=sc2_env.parse_agent_interface_format( 23 | feature_screen=resolution, feature_minimap=resolution) 24 | self._sc2_env = lan_sc2_env.LanSC2Env( 25 | host=host, 26 | config_port=config_port, 27 | race=sc2_env.Race[agent_race], 28 | step_mul=step_mul, 29 | agent_interface_format=agent_interface_format, 30 | visualize=visualize_feature_map) 31 | self.observation_space = PySC2RawObservation(self._sc2_env.observation_spec) 32 | self.action_space = PySC2RawAction() 33 | self._reseted = False 34 | 35 | def step(self, actions): 36 | assert self._reseted 37 | timestep = self._sc2_env.step([actions])[0] 38 | observation = timestep.observation 39 | reward = float(timestep.reward) 40 | done = timestep.last() 41 | if done: self._reseted = False 42 | info = {} 43 | return (observation, reward, done, info) 44 | 45 | def reset(self): 46 | timestep = self._sc2_env.reset()[0] 47 | observation = timestep.observation 48 | self._reseted = True 49 | return observation 50 | 51 | def close(self): 52 | self._sc2_env.close() 53 | 
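For orientation, a minimal sketch (not part of the repository) of the masked-action convention used by the pieces above: MaskDiscrete's sample and contains both take the list of currently available action IDs, and agents such as RandomAgent obtain that list from the availability mask carried in the last element of the observation via np.nonzero(action_mask)[0]. The mask below is a made-up value for illustration.

import numpy as np

from sc2learner.envs.spaces.mask_discrete import MaskDiscrete

space = MaskDiscrete(5)                    # suppose five macro actions in total
action_mask = np.array([1, 0, 1, 0, 1])    # hypothetical availability mask
availables = np.nonzero(action_mask)[0]    # available action IDs: 0, 2, 4
action = space.sample(availables)          # uniformly picks an available ID
assert space.contains(action, availables)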
--------------------------------------------------------------------------------
/sc2learner/agents/keyboard_agent.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | from __future__ import division
 3 | from __future__ import print_function
 4 | 
 5 | import time
 6 | import queue
 7 | import threading
 8 | 
 9 | from absl import logging
10 | import numpy as np
11 | 
12 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
13 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction
14 | 
15 | 
16 | def add_input(action_queue, n):
17 |   while True:
18 |     if action_queue.empty():
19 |       cmds = input("Input Action ID: ")
20 |       if not cmds.isdigit():
21 |         print("Input should be an integer. Skipped.")
22 |         continue
23 |       action = int(cmds)
24 |       if action >=0 and action < n: action_queue.put(action)
25 |       else: print("Invalid action. Skipped.")
26 | 
27 | 
28 | class KeyboardAgent(object):
29 |   """An agent that reads action IDs from keyboard input."""
30 |   def __init__(self, action_space):
31 |     super(KeyboardAgent, self).__init__()
32 |     logging.set_verbosity(logging.ERROR)
33 |     self._action_space = action_space
34 |     self._action_queue = queue.Queue()
35 |     self._cmd_thread = threading.Thread(
36 |         target=add_input, args=(self._action_queue, action_space.n))
37 |     self._cmd_thread.daemon = True
38 |     self._cmd_thread.start()
39 | 
40 |   def act(self, observation, eps=0):
41 |     time.sleep(0.1)
42 |     if not self._action_queue.empty():
43 |       action = self._action_queue.get()
44 |       if (isinstance(self._action_space, MaskDiscrete) or
45 |           isinstance(self._action_space, PySC2RawAction)):
46 |         action_mask = observation[-1]
47 |         if action_mask[action] == 0:
48 |           print("Action not available. Availables: %s" %
49 |                 np.nonzero(action_mask))
50 |           action = 0
51 |       return action
52 |     else:
53 |       return 0
54 | 
55 |   def reset(self):
56 |     pass
57 | 
--------------------------------------------------------------------------------
/sc2learner/utils/utils.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | from __future__ import division
 3 | from __future__ import print_function
 4 | 
 5 | from absl import flags
 6 | from datetime import datetime
 7 | 
 8 | 
 9 | def print_arguments(flags_FLAGS):
10 |   arg_name_list = dir(flags.FLAGS)
11 |   black_set = set(['alsologtostderr',
12 |                    'log_dir',
13 |                    'logtostderr',
14 |                    'showprefixforinfo',
15 |                    'stderrthreshold',
16 |                    'v',
17 |                    'verbosity',
18 |                    '?',
19 |                    'use_cprofile_for_profiling',
20 |                    'help',
21 |                    'helpfull',
22 |                    'helpshort',
23 |                    'helpxml',
24 |                    'profile_file',
25 |                    'run_with_profiling',
26 |                    'only_check_args',
27 |                    'pdb_post_mortem',
28 |                    'run_with_pdb'])
29 |   print("--------------------- Configuration Arguments --------------------")
30 |   for arg_name in arg_name_list:
31 |     if not arg_name.startswith('sc2_') and arg_name not in black_set:
32 |       print("%s: %s" % (arg_name, flags_FLAGS[arg_name].value))
33 |   print("-------------------------------------------------------------------")
34 | 
35 | 
36 | def tprint(x):
37 |   print("[%s] %s" % (str(datetime.now().strftime('%Y-%m-%d %H:%M:%S')), x))
38 | 
39 | 
40 | def print_actions(env):
41 |   print("----------------------------- Actions -----------------------------")
42 |   for action_id, action_name in enumerate(env.action_names):
43 |     print("Action ID: %d Action Name: %s" % (action_id, action_name))
44 |   print("-------------------------------------------------------------------")
45 | 
46 | 
47 | def 
print_action_distribution(env, action_counts): 48 | print("----------------------- Action Distribution -----------------------") 49 | for action_id, action_name in enumerate(env.action_names): 50 | print("Action ID: %d Count: %d Name: %s" % 51 | (action_id, action_counts[action_id], action_name)) 52 | print("-------------------------------------------------------------------") 53 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/upgrade.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | from s2clientprotocol import sc2api_pb2 as sc_pb 8 | from pysc2.lib.tech_tree import TechTree 9 | 10 | from sc2learner.envs.actions.function import Function 11 | 12 | 13 | class UpgradeActions(object): 14 | 15 | def __init__(self, game_version='4.1.2'): 16 | self._tech_tree = TechTree() 17 | self._tech_tree.update_version(game_version) 18 | 19 | def action(self, func_name, upgrade_id): 20 | return Function(name=func_name, 21 | function=self._upgrade_unit(upgrade_id), 22 | is_valid=self._is_valid_upgrade_unit(upgrade_id)) 23 | 24 | def _upgrade_unit(self, upgrade_id): 25 | 26 | def act(dc): 27 | tech = self._tech_tree.getUpgradeData(upgrade_id) 28 | if len(dc.idle_units_of_types(tech.whatBuilds)) == 0: return [] 29 | upgrader = random.choice(dc.idle_units_of_types(tech.whatBuilds)) 30 | action = sc_pb.Action() 31 | action.action_raw.unit_command.unit_tags.append(upgrader.tag) 32 | action.action_raw.unit_command.ability_id = tech.buildAbility 33 | return [action] 34 | 35 | return act 36 | 37 | def _is_valid_upgrade_unit(self, upgrade_id): 38 | 39 | def is_valid(dc): 40 | tech = self._tech_tree.getUpgradeData(upgrade_id) 41 | has_required_units = any([len(dc.mature_units_of_type(u)) > 0 42 | for u in tech.requiredUnits]) \ 43 | if len(tech.requiredUnits) > 0 else True 44 | has_required_upgrades = all([t in dc.upgraded_techs 45 | for t in tech.requiredUpgrades]) 46 | if (has_required_units and 47 | has_required_upgrades and 48 | upgrade_id not in dc.upgraded_techs and 49 | len(dc.units_with_task(tech.buildAbility)) == 0 and 50 | dc.mineral_count >= tech.mineralCost and 51 | dc.gas_count >= tech.gasCost and 52 | dc.supply_count >= tech.supplyCost and 53 | len(dc.idle_units_of_types(tech.whatBuilds)) > 0): 54 | return True 55 | else: 56 | return False 57 | 58 | return is_valid 59 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/produce.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | from s2clientprotocol import sc2api_pb2 as sc_pb 8 | from pysc2.lib.tech_tree import TechTree 9 | 10 | from sc2learner.envs.actions.function import Function 11 | from sc2learner.envs.common.const import MAXIMUM_NUM 12 | 13 | 14 | class ProduceActions(object): 15 | 16 | def __init__(self, game_version='4.1.2'): 17 | self._tech_tree = TechTree() 18 | self._tech_tree.update_version(game_version) 19 | 20 | def action(self, func_name, type_id): 21 | return Function(name=func_name, 22 | function=self._produce_unit(type_id), 23 | is_valid=self._is_valid_produce_unit(type_id)) 24 | 25 | def _produce_unit(self, type_id): 26 | 27 | def act(dc): 28 | tech = 
self._tech_tree.getUnitData(type_id) 29 | if len(dc.idle_units_of_types(tech.whatBuilds)) == 0: return [] 30 | producer = random.choice(dc.idle_units_of_types(tech.whatBuilds)) 31 | action = sc_pb.Action() 32 | action.action_raw.unit_command.unit_tags.append(producer.tag) 33 | action.action_raw.unit_command.ability_id = tech.buildAbility 34 | return [action] 35 | 36 | return act 37 | 38 | def _is_valid_produce_unit(self, type_id): 39 | 40 | def is_valid(dc): 41 | tech = self._tech_tree.getUnitData(type_id) 42 | has_required_units = any([len(dc.mature_units_of_type(u)) > 0 43 | for u in tech.requiredUnits]) \ 44 | if len(tech.requiredUnits) > 0 else True 45 | has_required_upgrades = all([t in dc.upgraded_techs 46 | for t in tech.requiredUpgrades]) 47 | current_num = len(dc.units_of_type(type_id)) + \ 48 | len(dc.units_with_task(tech.buildAbility)) 49 | overquota = current_num >= MAXIMUM_NUM[type_id] \ 50 | if type_id in MAXIMUM_NUM else False 51 | if (has_required_units and 52 | has_required_upgrades and 53 | not overquota and 54 | dc.mineral_count >= tech.mineralCost and 55 | dc.gas_count >= tech.gasCost and 56 | dc.supply_count >= tech.supplyCost and 57 | len(dc.idle_units_of_types(tech.whatBuilds)) > 0): 58 | return True 59 | else: 60 | return False 61 | 62 | return is_valid 63 | -------------------------------------------------------------------------------- /sc2learner/envs/observations/spatial_features.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | from sc2learner.envs.common.const import ALLY_TYPE 8 | from sc2learner.envs.common.const import MAP 9 | 10 | 11 | class UnitTypeCountMapFeature(object): 12 | 13 | def __init__(self, type_map, resolution): 14 | self._type_map = type_map 15 | self._resolution = resolution 16 | 17 | def features(self, observation, need_flip=False): 18 | self_units = [u for u in observation['units'] 19 | if u.int_attr.alliance == ALLY_TYPE.SELF.value] 20 | enemy_units = [u for u in observation['units'] 21 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value] 22 | self_features = self._generate_features(self_units) 23 | enemy_features = self._generate_features(enemy_units) 24 | features = np.concatenate((self_features, enemy_features)) 25 | if need_flip: features = np.flip(np.flip(features, axis=1), axis=2).copy() 26 | return features 27 | 28 | @property 29 | def num_channels(self): 30 | return (max(self._type_map.values()) + 1) * 2 31 | 32 | def _generate_features(self, units): 33 | num_channels = max(self._type_map.values()) + 1 34 | features = np.zeros((num_channels, self._resolution, self._resolution), 35 | dtype=np.float32) 36 | grid_width = (MAP.WIDTH - MAP.LEFT - MAP.RIGHT) / self._resolution 37 | grid_height = (MAP.HEIGHT - MAP.TOP - MAP.BOTTOM) / self._resolution 38 | for u in units: 39 | if u.unit_type in self._type_map: 40 | c = self._type_map[u.unit_type] 41 | x = (u.float_attr.pos_x - MAP.LEFT) // grid_width 42 | y = self._resolution - 1 - \ 43 | (u.float_attr.pos_y - MAP.BOTTOM) // grid_height 44 | features[c, int(y), int(x)] += 1.0 45 | return features / 5.0 46 | 47 | 48 | class AllianceCountMapFeature(object): 49 | 50 | def __init__(self, resolution): 51 | self._resolution = resolution 52 | 53 | def features(self, observation, need_flip=False): 54 | self_units = [u for u in observation['units'] 55 | if u.int_attr.alliance == ALLY_TYPE.SELF.value] 56 | 
enemy_units = [u for u in observation['units'] 57 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value] 58 | neutral_units = [u for u in observation['units'] 59 | if u.int_attr.alliance == ALLY_TYPE.NEUTRAL.value] 60 | self_features = self._generate_features(self_units) 61 | enemy_features = self._generate_features(enemy_units) 62 | neutral_features = self._generate_features(neutral_units) 63 | features = np.concatenate((self_features, enemy_features, neutral_features)) 64 | if need_flip: features = np.flip(np.flip(features, axis=1), axis=2).copy() 65 | return features 66 | 67 | @property 68 | def num_channels(self): 69 | return 3 70 | 71 | def _generate_features(self, units): 72 | features = np.zeros((1, self._resolution, self._resolution), 73 | dtype=np.float32) 74 | grid_width = (MAP.WIDTH - MAP.LEFT - MAP.RIGHT) / self._resolution 75 | grid_height = (MAP.HEIGHT - MAP.TOP - MAP.BOTTOM) / self._resolution 76 | for u in units: 77 | x = (u.float_attr.pos_x - MAP.LEFT) // grid_width 78 | y = self._resolution - 1 - \ 79 | (u.float_attr.pos_y - MAP.BOTTOM) // grid_height 80 | features[0, int(y), int(x)] += 1.0 81 | return features / 5.0 82 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/build.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from s2clientprotocol import sc2api_pb2 as sc_pb 6 | from pysc2.lib.tech_tree import TechTree 7 | from pysc2.lib.unit_controls import Unit 8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 9 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY 10 | 11 | from sc2learner.envs.actions.function import Function 12 | from sc2learner.envs.actions.placer import Placer 13 | import sc2learner.envs.common.utils as utils 14 | from sc2learner.envs.common.const import MAXIMUM_NUM 15 | 16 | 17 | class BuildActions(object): 18 | 19 | def __init__(self, game_version='4.1.2'): 20 | self._placer = Placer() 21 | self._tech_tree = TechTree() 22 | self._tech_tree.update_version(game_version) 23 | 24 | def action(self, func_name, type_id): 25 | return Function(name=func_name, 26 | function=self._build_unit(type_id), 27 | is_valid=self._is_valid_build_unit(type_id)) 28 | 29 | def _build_unit(self, type_id): 30 | 31 | def act(dc): 32 | tech = self._tech_tree.getUnitData(type_id) 33 | pos = self._placer.get_building_position(type_id, dc) 34 | if pos == None: return [] 35 | extractor_tags = set(u.tag for u in dc.units_of_type( 36 | UNIT_TYPE.ZERG_EXTRACTOR.value)) 37 | builders = dc.units_of_types(tech.whatBuilds) 38 | prefered_builders = [ 39 | u for u in builders 40 | if (u.unit_type != UNIT_TYPE.ZERG_DRONE.value or 41 | len(u.orders) == 0 or 42 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and 43 | u.orders[0].target_tag not in extractor_tags)) 44 | ] 45 | if len(prefered_builders) > 0: 46 | builder = utils.closest_unit(pos, prefered_builders) 47 | else: 48 | if len(builders) == 0: return [] 49 | builder = utils.closest_unit(pos, builders) 50 | action = sc_pb.Action() 51 | action.action_raw.unit_command.unit_tags.append(builder.tag) 52 | action.action_raw.unit_command.ability_id = tech.buildAbility 53 | if isinstance(pos, Unit): 54 | action.action_raw.unit_command.target_unit_tag = pos.tag 55 | else: 56 | action.action_raw.unit_command.target_world_space_pos.x = pos[0] 57 | action.action_raw.unit_command.target_world_space_pos.y = 
pos[1] 58 | return [action] 59 | 60 | return act 61 | 62 | def _is_valid_build_unit(self, type_id): 63 | 64 | def is_valid(dc): 65 | tech = self._tech_tree.getUnitData(type_id) 66 | has_required_units = any([len(dc.mature_units_of_type(u)) > 0 67 | for u in tech.requiredUnits]) \ 68 | if len(tech.requiredUnits) > 0 else True 69 | has_required_upgrades = all([t in dc.upgraded_techs 70 | for t in tech.requiredUpgrades]) 71 | current_num = len(dc.units_of_type(type_id)) + \ 72 | len(dc.units_with_task(tech.buildAbility)) 73 | overquota = current_num >= MAXIMUM_NUM[type_id] \ 74 | if type_id in MAXIMUM_NUM else False 75 | 76 | if (has_required_units and 77 | has_required_upgrades and 78 | not overquota and 79 | dc.mineral_count >= tech.mineralCost and 80 | dc.gas_count >= tech.gasCost and 81 | dc.supply_count >= tech.supplyCost and 82 | len(dc.units_of_types(tech.whatBuilds)) > 0 and 83 | len(dc.units_with_task(tech.buildAbility)) == 0 and 84 | self._placer.can_build(type_id, dc)): 85 | return True 86 | else: 87 | return False 88 | 89 | return is_valid 90 | -------------------------------------------------------------------------------- /sc2learner/agents/utils_tf.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | 9 | class Pd(object): 10 | 11 | def neglogp(self, x): 12 | raise NotImplementedError 13 | 14 | def entropy(self): 15 | raise NotImplementedError 16 | 17 | def sample(self): 18 | raise NotImplementedError 19 | 20 | @classmethod 21 | def fromlogits(cls, logits): 22 | return cls(logits) 23 | 24 | 25 | class CategoricalPd(Pd): 26 | 27 | def __init__(self, logits): 28 | self.logits = logits 29 | 30 | def neglogp(self, x): 31 | one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1]) 32 | return tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, 33 | labels=one_hot_actions) 34 | 35 | def entropy(self): 36 | a = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True) 37 | ea = tf.exp(a) 38 | z = tf.reduce_sum(ea, axis=-1, keep_dims=True) 39 | p = ea / z 40 | return tf.reduce_sum(p * (tf.log(z) - a), axis=-1) 41 | 42 | def sample(self): 43 | u = tf.random_uniform(tf.shape(self.logits)) 44 | return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1) 45 | 46 | 47 | def fc(x, scope, nh, init_scale=1.0, init_bias=0.0): 48 | with tf.variable_scope(scope): 49 | nin = x.get_shape()[1].value 50 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 51 | b = tf.get_variable("b", [nh], 52 | initializer=tf.constant_initializer(init_bias)) 53 | return tf.matmul(x, w) + b 54 | 55 | 56 | def lstm(xs, ms, s, scope, nh, init_scale=1.0): 57 | nbatch, nin = [v.value for v in xs[0].get_shape()] 58 | nsteps = len(xs) 59 | with tf.variable_scope(scope): 60 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale)) 61 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale)) 62 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0)) 63 | 64 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s) 65 | for idx, (x, m) in enumerate(zip(xs, ms)): 66 | c = c * (1 - m) 67 | h = h * (1 - m) 68 | z = tf.matmul(x, wx) + tf.matmul(h, wh) + b 69 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z) 70 | i = tf.nn.sigmoid(i) 71 | f = tf.nn.sigmoid(f) 72 | o = tf.nn.sigmoid(o) 73 | u = 
tf.tanh(u) 74 | c = f * c + i * u 75 | h = o * tf.tanh(c) 76 | xs[idx] = h 77 | s = tf.concat(axis=1, values=[c, h]) 78 | return xs, s 79 | 80 | 81 | def batch_to_seq(h, nbatch, nsteps, flat=False): 82 | if flat: h = tf.reshape(h, [nbatch, nsteps]) 83 | else: h = tf.reshape(h, [nbatch, nsteps, -1]) 84 | return [tf.squeeze(v, [1]) 85 | for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)] 86 | 87 | 88 | def seq_to_batch(h, flat = False): 89 | shape = h[0].get_shape().as_list() 90 | if not flat: 91 | assert len(shape) > 1 92 | nh = h[0].get_shape()[-1].value 93 | return tf.reshape(tf.concat(axis=1, values=h), [-1, nh]) 94 | else: 95 | return tf.reshape(tf.stack(values=h, axis=1), [-1]) 96 | 97 | 98 | def ortho_init(scale=1.0): 99 | 100 | def _ortho_init(shape, dtype, partition_info=None): 101 | shape = tuple(shape) 102 | if len(shape) == 2: flat_shape = shape 103 | elif len(shape) == 4: flat_shape = (np.prod(shape[:-1]), shape[-1]) #NHWC 104 | else: raise NotImplementedError 105 | a = np.random.normal(0.0, 1.0, flat_shape) 106 | u, _, v = np.linalg.svd(a, full_matrices=False) 107 | q = u if u.shape == flat_shape else v 108 | q = q.reshape(shape) 109 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 110 | 111 | return _ortho_init 112 | 113 | 114 | def explained_variance(ypred,y): 115 | assert y.ndim == 1 and ypred.ndim == 1 116 | var_y = np.var(y) 117 | return np.nan if var_y == 0 else 1 - np.var(y - ypred) / var_y 118 | -------------------------------------------------------------------------------- /sc2learner/envs/selfplay_raw_env.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gym 6 | from pysc2.env import sc2_env 7 | 8 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction 9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation 10 | from sc2learner.utils.utils import tprint 11 | 12 | 13 | DIFFICULTIES= { 14 | "1": sc2_env.Difficulty.very_easy, 15 | "2": sc2_env.Difficulty.easy, 16 | "3": sc2_env.Difficulty.medium, 17 | "4": sc2_env.Difficulty.medium_hard, 18 | "5": sc2_env.Difficulty.hard, 19 | "6": sc2_env.Difficulty.hard, 20 | "7": sc2_env.Difficulty.very_hard, 21 | "8": sc2_env.Difficulty.cheat_vision, 22 | "9": sc2_env.Difficulty.cheat_money, 23 | "A": sc2_env.Difficulty.cheat_insane, 24 | } 25 | 26 | 27 | class SC2SelfplayRawEnv(gym.Env): 28 | 29 | def __init__(self, 30 | map_name, 31 | step_mul=8, 32 | resolution=32, 33 | disable_fog=False, 34 | agent_race='random', 35 | opponent_race='random', 36 | game_steps_per_episode=None, 37 | tie_to_lose=False, 38 | score_index=None, 39 | random_seed=None): 40 | self._map_name = map_name 41 | self._step_mul = step_mul 42 | self._resolution = resolution 43 | self._disable_fog = disable_fog 44 | self._agent_race = agent_race 45 | self._opponent_race = opponent_race 46 | self._game_steps_per_episode = game_steps_per_episode 47 | self._tie_to_lose = tie_to_lose 48 | self._score_index = score_index 49 | self._random_seed = random_seed 50 | self._reseted = False 51 | self._first_create = True 52 | 53 | self._sc2_env = self._safe_create_env() 54 | self.observation_space = PySC2RawObservation(self._sc2_env.observation_spec) 55 | self.action_space = PySC2RawAction() 56 | 57 | def step(self, actions): 58 | assert self._reseted 59 | assert len(actions) == 2 60 | timesteps = self._sc2_env.step(actions) 61 | observation = 
[timesteps[0].observation, timesteps[1].observation] 62 | reward = float(timesteps[0].reward) 63 | done = timesteps[0].last() 64 | if done: 65 | self._reseted = False 66 | if self._tie_to_lose and reward == 0: 67 | reward = -1.0 68 | tprint("Episode Done. Outcome %f" % reward) 69 | info = {} 70 | return (observation, reward, done, info) 71 | 72 | def reset(self): 73 | timesteps = self._safe_reset() 74 | self._reseted = True 75 | return [timesteps[0].observation, timesteps[1].observation] 76 | 77 | def _reset(self): 78 | if not self._first_create: 79 | self._sc2_env.close() 80 | self._sc2_env = self._create_env() 81 | self._first_create = False 82 | return self._sc2_env.reset() 83 | 84 | def _safe_reset(self, max_retry=10): 85 | for _ in range(max_retry - 1): 86 | try: return self._reset() 87 | except: pass 88 | return self._reset() 89 | 90 | def close(self): 91 | self._sc2_env.close() 92 | 93 | def _create_env(self): 94 | self._random_seed = (self._random_seed + 1) & 0xFFFFFFFF 95 | players=[sc2_env.Agent(sc2_env.Race[self._agent_race]), 96 | sc2_env.Agent(sc2_env.Race[self._opponent_race])] 97 | agent_interface_format=sc2_env.parse_agent_interface_format( 98 | feature_screen=self._resolution, feature_minimap=self._resolution) 99 | return sc2_env.SC2Env( 100 | map_name=self._map_name, 101 | step_mul=self._step_mul, 102 | players=players, 103 | agent_interface_format=agent_interface_format, 104 | disable_fog=self._disable_fog, 105 | game_steps_per_episode=self._game_steps_per_episode, 106 | visualize=False, 107 | score_index=self._score_index, 108 | random_seed=self._random_seed) 109 | 110 | def _safe_create_env(self, max_retry=10): 111 | for _ in range(max_retry - 1): 112 | try: return self._create_env() 113 | except: pass 114 | return self._create_env() 115 | -------------------------------------------------------------------------------- /sc2learner/envs/raw_env.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gym 6 | from pysc2.env import sc2_env 7 | 8 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction 9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation 10 | from sc2learner.utils.utils import tprint 11 | 12 | 13 | DIFFICULTIES= { 14 | "1": sc2_env.Difficulty.very_easy, 15 | "2": sc2_env.Difficulty.easy, 16 | "3": sc2_env.Difficulty.medium, 17 | "4": sc2_env.Difficulty.medium_hard, 18 | "5": sc2_env.Difficulty.hard, 19 | "6": sc2_env.Difficulty.hard, 20 | "7": sc2_env.Difficulty.very_hard, 21 | "8": sc2_env.Difficulty.cheat_vision, 22 | "9": sc2_env.Difficulty.cheat_money, 23 | "A": sc2_env.Difficulty.cheat_insane, 24 | } 25 | 26 | 27 | class SC2RawEnv(gym.Env): 28 | 29 | def __init__(self, 30 | map_name, 31 | step_mul=8, 32 | resolution=32, 33 | disable_fog=False, 34 | agent_race='random', 35 | bot_race='random', 36 | difficulty='1', 37 | game_steps_per_episode=None, 38 | tie_to_lose=False, 39 | score_index=None, 40 | random_seed=None): 41 | self._map_name = map_name 42 | self._step_mul = step_mul 43 | self._resolution = resolution 44 | self._disable_fog = disable_fog 45 | self._agent_race = agent_race 46 | self._bot_race = bot_race 47 | self._difficulty = difficulty 48 | self._game_steps_per_episode = game_steps_per_episode 49 | self._tie_to_lose = tie_to_lose 50 | self._score_index = score_index 51 | self._random_seed = random_seed 52 | self._reseted = False 53 | self._first_create = True 
54 | 55 | self._sc2_env = self._safe_create_env() 56 | self.observation_space = PySC2RawObservation(self._sc2_env.observation_spec) 57 | self.action_space = PySC2RawAction() 58 | 59 | def step(self, actions): 60 | assert self._reseted 61 | timestep = self._sc2_env.step([actions])[0] 62 | observation = timestep.observation 63 | reward = float(timestep.reward) 64 | done = timestep.last() 65 | if done: 66 | self._reseted = False 67 | if self._tie_to_lose and reward == 0: 68 | reward = -1.0 69 | tprint("Episode Done. Difficulty: %s Outcome %f" % 70 | (self._difficulty, reward)) 71 | info = {} 72 | return (observation, reward, done, info) 73 | 74 | def reset(self): 75 | timesteps = self._safe_reset() 76 | self._reseted = True 77 | return timesteps[0].observation 78 | 79 | def _reset(self): 80 | if not self._first_create: 81 | self._sc2_env.close() 82 | self._sc2_env = self._create_env() 83 | self._first_create = False 84 | return self._sc2_env.reset() 85 | 86 | def _safe_reset(self, max_retry=10): 87 | for _ in range(max_retry - 1): 88 | try: return self._reset() 89 | except: pass 90 | return self._reset() 91 | 92 | def close(self): 93 | self._sc2_env.close() 94 | 95 | def _create_env(self): 96 | self._random_seed = (self._random_seed + 11) & 0xFFFFFFFF 97 | players=[sc2_env.Agent(sc2_env.Race[self._agent_race]), 98 | sc2_env.Bot(sc2_env.Race[self._bot_race], 99 | DIFFICULTIES[self._difficulty])] 100 | agent_interface_format=sc2_env.parse_agent_interface_format( 101 | feature_screen=self._resolution, feature_minimap=self._resolution) 102 | tprint("Creating game with seed %d." % self._random_seed) 103 | return sc2_env.SC2Env( 104 | map_name=self._map_name, 105 | step_mul=self._step_mul, 106 | players=players, 107 | agent_interface_format=agent_interface_format, 108 | disable_fog=self._disable_fog, 109 | game_steps_per_episode=self._game_steps_per_episode, 110 | visualize=False, 111 | score_index=self._score_index, 112 | random_seed=self._random_seed) 113 | 114 | def _safe_create_env(self, max_retry=10): 115 | for _ in range(max_retry - 1): 116 | try: return self._create_env() 117 | except: pass 118 | return self._create_env() 119 | -------------------------------------------------------------------------------- /sc2learner/agents/dqn_networks.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class DuelingQNet(nn.Module): 11 | 12 | def __init__(self, 13 | resolution, 14 | n_channels, 15 | n_dims, 16 | n_out, 17 | batchnorm=False): 18 | super(DuelingQNet, self).__init__() 19 | assert resolution == 16 20 | self.conv1 = nn.Conv2d(in_channels=n_channels, 21 | out_channels=32, 22 | kernel_size=5, 23 | stride=1, 24 | padding=2) 25 | self.conv2 = nn.Conv2d(in_channels=32, 26 | out_channels=32, 27 | kernel_size=3, 28 | stride=1, 29 | padding=1) 30 | self.conv3 = nn.Conv2d(in_channels=32, 31 | out_channels=16, 32 | kernel_size=3, 33 | stride=2, 34 | padding=1) 35 | if batchnorm: 36 | self.bn1 = nn.BatchNorm2d(32) 37 | self.bn2 = nn.BatchNorm2d(32) 38 | self.bn3 = nn.BatchNorm2d(16) 39 | 40 | self.value_sp_fc = nn.Linear(16 * 8 * 8, 256) 41 | self.value_nonsp_fc1 = nn.Linear(n_dims, 512) 42 | self.value_nonsp_fc2 = nn.Linear(512, 512) 43 | self.value_nonsp_fc3 = nn.Linear(512, 256) 44 | self.value_final_fc = nn.Linear(512, 1) 45 | 46 | self.adv_sp_fc = 
nn.Linear(16 * 8 * 8, 256) 47 | self.adv_nonsp_fc1 = nn.Linear(n_dims, 512) 48 | self.adv_nonsp_fc2 = nn.Linear(512, 512) 49 | self.adv_nonsp_fc3 = nn.Linear(512, 256) 50 | self.adv_final_fc = nn.Linear(512, n_out) 51 | self._batchnorm = batchnorm 52 | 53 | def forward(self, x): 54 | spatial, nonspatial = x 55 | if self._batchnorm: 56 | spatial = F.relu(self.bn1(self.conv1(spatial))) 57 | spatial = F.relu(self.bn2(self.conv2(spatial))) 58 | spatial = F.relu(self.bn3(self.conv3(spatial))) 59 | else: 60 | spatial = F.relu(self.conv1(spatial)) 61 | spatial = F.relu(self.conv2(spatial)) 62 | spatial = F.relu(self.conv3(spatial)) 63 | spatial = spatial.view(spatial.size(0), -1) 64 | 65 | value_sp_state = F.relu(self.value_sp_fc(spatial)) 66 | value_nonsp_state = F.relu(self.value_nonsp_fc1(nonspatial)) 67 | value_nonsp_state = F.relu(self.value_nonsp_fc2(value_nonsp_state)) 68 | value_nonsp_state = F.relu(self.value_nonsp_fc3(value_nonsp_state)) 69 | value_state = torch.cat((value_sp_state, value_nonsp_state), 1) 70 | value = self.value_final_fc(value_state) 71 | 72 | adv_sp_state = F.relu(self.adv_sp_fc(spatial)) 73 | adv_nonsp_state = F.relu(self.adv_nonsp_fc1(nonspatial)) 74 | adv_nonsp_state = F.relu(self.adv_nonsp_fc2(adv_nonsp_state)) 75 | adv_nonsp_state = F.relu(self.adv_nonsp_fc3(adv_nonsp_state)) 76 | adv_state = torch.cat((adv_sp_state, adv_nonsp_state), 1) 77 | adv = self.adv_final_fc(adv_state) 78 | adv_subtract = adv - adv.mean(dim=1, keepdim=True) 79 | return value + adv_subtract 80 | 81 | 82 | class NonspatialDuelingQNet(nn.Module): 83 | 84 | def __init__(self, n_dims, n_out): 85 | super(NonspatialDuelingQNet, self).__init__() 86 | self.value_fc1 = nn.Linear(n_dims, 512) 87 | self.value_fc2 = nn.Linear(512, 512) 88 | self.value_fc3 = nn.Linear(512, 256) 89 | self.value_fc4 = nn.Linear(256, 1) 90 | 91 | self.adv_fc1 = nn.Linear(n_dims, 512) 92 | self.adv_fc2 = nn.Linear(512, 512) 93 | self.adv_fc3 = nn.Linear(512, 256) 94 | self.adv_fc4 = nn.Linear(256, n_out) 95 | 96 | def forward(self, x): 97 | value = F.relu(self.value_fc1(x)) 98 | value = F.relu(self.value_fc2(value)) 99 | value = F.relu(self.value_fc3(value)) 100 | value = self.value_fc4(value) 101 | 102 | adv = F.relu(self.adv_fc1(x)) 103 | adv = F.relu(self.adv_fc2(adv)) 104 | adv = F.relu(self.adv_fc3(adv)) 105 | adv = self.adv_fc4(adv) 106 | 107 | adv_subtract = adv - adv.mean(dim=1, keepdim=True) 108 | return value + adv_subtract 109 | -------------------------------------------------------------------------------- /sc2learner/bin/play_vs_ppo_agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import traceback 7 | import multiprocessing 8 | 9 | from absl import app 10 | from absl import flags 11 | from absl import logging 12 | import tensorflow as tf 13 | 14 | from sc2learner.envs.lan_raw_env import LanSC2RawEnv 15 | from sc2learner.envs.observations.zerg_observation_wrappers \ 16 | import ZergObservationWrapper 17 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper 18 | from sc2learner.utils.utils import print_arguments 19 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy 20 | from sc2learner.agents.ppo_agent import PPOAgent 21 | 22 | 23 | FLAGS = flags.FLAGS 24 | flags.DEFINE_string("game_version", '4.6', "Game core version.") 25 | flags.DEFINE_string("model_path", None, "Filepath to load initial 
model.") 26 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.") 27 | flags.DEFINE_string("host", "127.0.0.1", "Game Host. Can be 127.0.0.1 or ::1") 28 | flags.DEFINE_integer( 29 | "config_port", 14380, 30 | "Where to set/find the config port. The host starts a tcp server to share " 31 | "the config with the client, and to proxy udp traffic if played over an " 32 | "ssh tunnel. This sets that port, and is also the start of the range of " 33 | "ports used for LAN play.") 34 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.") 35 | flags.DEFINE_boolean("use_region_features", False, "Use region features") 36 | flags.DEFINE_boolean("use_action_mask", True, "Use action mask or not.") 37 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Job type.") 38 | 39 | 40 | def print_actions(env): 41 | print("----------------------------- Actions -----------------------------") 42 | for action_id, action_name in enumerate(env.action_names): 43 | print("Action ID: %d Action Name: %s" % (action_id, action_name)) 44 | print("-------------------------------------------------------------------") 45 | 46 | 47 | def print_action_distribution(env, action_counts): 48 | print("----------------------- Action Distribution -----------------------") 49 | for action_id, action_name in enumerate(env.action_names): 50 | print("Action ID: %d Count: %d Name: %s" % 51 | (action_id, action_counts[action_id], action_name)) 52 | print("-------------------------------------------------------------------") 53 | 54 | 55 | def tf_config(ncpu=None): 56 | if ncpu is None: 57 | ncpu = multiprocessing.cpu_count() 58 | if sys.platform == 'darwin': ncpu //= 2 59 | config = tf.ConfigProto(allow_soft_placement=True, 60 | intra_op_parallelism_threads=ncpu, 61 | inter_op_parallelism_threads=ncpu) 62 | config.gpu_options.allow_growth = True 63 | tf.Session(config=config).__enter__() 64 | 65 | 66 | def start_lan_agent(): 67 | """Run the agent, connecting to a host started independently.""" 68 | tf_config() 69 | env = LanSC2RawEnv(host=FLAGS.host, 70 | config_port=FLAGS.config_port, 71 | agent_race='zerg', 72 | step_mul=FLAGS.step_mul, 73 | visualize_feature_map=False) 74 | env = ZergActionWrapper(env, 75 | game_version=FLAGS.game_version, 76 | mask=FLAGS.use_action_mask, 77 | use_all_combat_actions=FLAGS.use_all_combat_actions) 78 | env = ZergObservationWrapper( 79 | env, 80 | use_spatial_features=False, 81 | use_game_progress=(not FLAGS.policy == 'lstm'), 82 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8, 83 | use_regions=FLAGS.use_region_features) 84 | print_actions(env) 85 | policy = {'lstm': LstmPolicy, 86 | 'mlp': MlpPolicy}[FLAGS.policy] 87 | agent = PPOAgent(env=env, 88 | policy=policy, 89 | model_path=FLAGS.model_path) 90 | try: 91 | action_counts = [0] * env.action_space.n 92 | observation = env.reset() 93 | done, step_id = False, 0 94 | while not done: 95 | action = agent.act(observation) 96 | print("Step ID: %d Take Action: %d" % (step_id, action)) 97 | observation, reward, done, _ = env.step(action) 98 | action_counts[action] += 1 99 | step_id += 1 100 | print_action_distribution(env, action_counts) 101 | except KeyboardInterrupt: pass 102 | except: traceback.print_exc() 103 | env.close() 104 | 105 | 106 | def main(unused_argv): 107 | logging.set_verbosity(logging.ERROR) 108 | print_arguments(FLAGS) 109 | start_lan_agent() 110 | 111 | 112 | if __name__ == "__main__": 113 | app.run(main) 114 | -------------------------------------------------------------------------------- 
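Before the policy definitions that follow, a small NumPy illustration (not from the repository) of the action-masking trick they rely on together with CategoricalPd from utils_tf.py: subtracting 1e30 from the logits of unavailable actions (pi_logit -= (1 - MASK) * 1e30) drives their softmax probability to essentially zero, so masked actions are never sampled. The logits and mask below are made-up values.

import numpy as np

logits = np.array([1.0, 2.0, 0.5, 1.5])
mask = np.array([1.0, 0.0, 1.0, 0.0])         # 1 = available, 0 = unavailable
masked_logits = logits - (1.0 - mask) * 1e30
probs = np.exp(masked_logits - masked_logits.max())
probs /= probs.sum()
print(probs)  # all probability mass falls on actions 0 and 2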
/sc2learner/agents/ppo_policies.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete 9 | from sc2learner.agents.utils_tf import CategoricalPd 10 | from sc2learner.agents.utils_tf import fc, lstm, batch_to_seq, seq_to_batch 11 | 12 | 13 | class MlpPolicy(object): 14 | def __init__(self, sess, scope_name, ob_space, ac_space, nbatch, nsteps, 15 | reuse=False): 16 | if isinstance(ac_space, MaskDiscrete): 17 | ob_space, mask_space = ob_space.spaces 18 | 19 | X = tf.placeholder( 20 | shape=(nbatch,) + ob_space.shape, dtype=tf.float32, name="x_screen") 21 | if isinstance(ac_space, MaskDiscrete): 22 | MASK = tf.placeholder( 23 | shape=(nbatch,) + mask_space.shape, dtype=tf.float32, name="mask") 24 | 25 | with tf.variable_scope(scope_name, reuse=reuse): 26 | x = tf.layers.flatten(X) 27 | pi_h1 = tf.tanh(fc(x, 'pi_fc1', nh=512, init_scale=np.sqrt(2))) 28 | pi_h2 = tf.tanh(fc(pi_h1, 'pi_fc2', nh=512, init_scale=np.sqrt(2))) 29 | pi_h3 = tf.tanh(fc(pi_h2, 'pi_fc3', nh=512, init_scale=np.sqrt(2))) 30 | vf_h1 = tf.tanh(fc(x, 'vf_fc1', nh=512, init_scale=np.sqrt(2))) 31 | vf_h2 = tf.tanh(fc(vf_h1, 'vf_fc2', nh=512, init_scale=np.sqrt(2))) 32 | vf_h3 = tf.tanh(fc(vf_h2, 'vf_fc3', nh=512, init_scale=np.sqrt(2))) 33 | vf = fc(vf_h3, 'vf', 1)[:,0] 34 | pi_logit = fc(pi_h3, 'pi', ac_space.n, init_scale=0.01, init_bias=0.0) 35 | if isinstance(ac_space, MaskDiscrete): 36 | pi_logit -= (1 - MASK) * 1e30 37 | self.pd = CategoricalPd(pi_logit) 38 | 39 | action = self.pd.sample() 40 | neglogp = self.pd.neglogp(action) 41 | self.initial_state = None 42 | 43 | def step(ob, *_args, **_kwargs): 44 | if isinstance(ac_space, MaskDiscrete): 45 | a, v, nl = sess.run([action, vf, neglogp], {X:ob[0], MASK:ob[-1]}) 46 | else: 47 | a, v, nl = sess.run([action, vf, neglogp], {X:ob}) 48 | return a, v, self.initial_state, nl 49 | 50 | def value(ob, *_args, **_kwargs): 51 | if isinstance(ac_space, MaskDiscrete): 52 | return sess.run(vf, {X:ob[0], MASK:ob[-1]}) 53 | else: 54 | return sess.run(vf, {X:ob}) 55 | 56 | self.X = X 57 | if isinstance(ac_space, MaskDiscrete): 58 | self.MASK = MASK 59 | self.vf = vf 60 | self.step = step 61 | self.value = value 62 | 63 | 64 | class LstmPolicy(object): 65 | 66 | def __init__(self, sess, scope_name, ob_space, ac_space, nbatch, 67 | unroll_length, nlstm=512, reuse=False): 68 | nenv = nbatch // unroll_length 69 | if isinstance(ac_space, MaskDiscrete): 70 | ob_space, mask_space = ob_space.spaces 71 | 72 | DONE = tf.placeholder(tf.float32, [nbatch]) 73 | STATE = tf.placeholder(tf.float32, [nenv, nlstm * 2]) 74 | X = tf.placeholder( 75 | shape=(nbatch,) + ob_space.shape, dtype=tf.float32, name="x_screen") 76 | if isinstance(ac_space, MaskDiscrete): 77 | MASK = tf.placeholder( 78 | shape=(nbatch,) + mask_space.shape, dtype=tf.float32, name="mask") 79 | 80 | with tf.variable_scope(scope_name, reuse=reuse): 81 | x = tf.layers.flatten(X) 82 | fc1 = tf.nn.relu(fc(x, 'fc1', 512)) 83 | h = tf.nn.relu(fc(fc1, 'fc2', 512)) 84 | xs = batch_to_seq(h, nenv, unroll_length) 85 | ms = batch_to_seq(DONE, nenv, unroll_length) 86 | h5, snew = lstm(xs, ms, STATE, 'lstm1', nh=nlstm) 87 | h5 = seq_to_batch(h5) 88 | vf = fc(h5, 'v', 1) 89 | pi_logit = fc(h5, 'pi', ac_space.n, init_scale=1.0, init_bias=0.0) 90 | if isinstance(ac_space, 
MaskDiscrete):
 91 |         pi_logit -= (1 - MASK) * 1e30
 92 |       self.pd = CategoricalPd(pi_logit)
 93 | 
 94 |     action = self.pd.sample()
 95 |     neglogp = self.pd.neglogp(action)
 96 |     self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
 97 | 
 98 |     def step(ob, state, done):
 99 |       if isinstance(ac_space, MaskDiscrete):
100 |         return sess.run([action, vf, snew, neglogp],
101 |                         {X:ob[0], MASK:ob[-1], STATE:state, DONE:done})
102 |       else:
103 |         return sess.run([action, vf, snew, neglogp],
104 |                         {X:ob, STATE:state, DONE:done})
105 | 
106 |     def value(ob, state, done):
107 |       if isinstance(ac_space, MaskDiscrete):
108 |         return sess.run(vf, {X:ob[0], MASK:ob[-1], STATE:state, DONE:done})
109 |       else:
110 |         return sess.run(vf, {X:ob, STATE:state, DONE:done})
111 | 
112 |     self.X = X
113 |     if isinstance(ac_space, MaskDiscrete):
114 |       self.MASK = MASK
115 |     self.DONE = DONE
116 |     self.STATE = STATE
117 |     self.vf = vf
118 |     self.step = step
119 |     self.value = value
120 | 
--------------------------------------------------------------------------------
/sc2learner/agents/replay_memory.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | from __future__ import division
 3 | from __future__ import print_function
 4 | 
 5 | from collections import namedtuple
 6 | from collections import deque
 7 | from threading import Thread
 8 | import random
 9 | import time
10 | 
11 | import zmq
12 | 
13 | 
14 | Transition = namedtuple('Transition',
15 |                         ('observation', 'action', 'reward', 'next_observation',
16 |                          'done', 'mc_return'))
17 | 
18 | 
19 | class LocalReplayMemory(object):
20 |   def __init__(self, capacity):
21 |     self._memory = deque(maxlen=capacity)
22 |     self._total = 0
23 | 
24 |   def push(self, *args):
25 |     self._memory.append(Transition(*args))
26 |     self._total += 1
27 | 
28 |   def sample(self, batch_size):
29 |     return random.sample(self._memory, batch_size)
30 | 
31 |   @property
32 |   def total(self):
33 |     return self._total
34 | 
35 | 
36 | class RemoteReplayMemory(object):
37 |   def __init__(self,
38 |                is_server,
39 |                memory_size,
40 |                memory_warmup_size,
41 |                block_size=128,
42 |                send_freq=1.0,
43 |                num_pull_threads=4,
44 |                ports=("5700", "5701"),
45 |                server_ip="localhost"):
46 |     assert len(ports) == 2
47 |     assert memory_warmup_size <= memory_size
48 |     self._is_server = is_server
49 |     self._memory_warmup_size = memory_warmup_size
50 |     self._block_size = block_size
51 | 
52 |     if is_server:
53 |       self._num_received, self._num_used, self._total = 0, 0, 0
54 |       self._cache_blocks = deque(maxlen=memory_size // block_size)
55 |       self._zmq_context = zmq.Context()
56 | 
57 |       self._receiver_threads = [Thread(target=self._server_proxy_worker,
58 |                                        args=(self._zmq_context, ports,))]
59 |       self._receiver_threads += [Thread(target=self._server_receiver_worker,
60 |                                         args=(self._zmq_context, ports[1],))
61 |                                  for _ in range(num_pull_threads)]
62 |       for thread in self._receiver_threads: thread.start()
63 |     else:
64 |       self._memory = LocalReplayMemory(memory_size)
65 |       self._memory_total_last = 0
66 |       self._send_interval = int(block_size / send_freq)
67 | 
68 |       self._zmq_context = zmq.Context()
69 |       self._sender = self._zmq_context.socket(zmq.PUSH)
70 |       self._sender.connect("tcp://%s:%s" % (server_ip, ports[0]))
71 | 
72 |   def push(self, *args):
73 |     assert not self._is_server, "push() cannot be called when is_server=True."
74 | self._memory.push(*args) 75 | if (self._memory.total >= self._memory_warmup_size and 76 | self._memory.total >= self._block_size and 77 | self._memory.total % self._send_interval == 0): 78 | block = self._memory.sample(self._block_size) 79 | memory_total = self._memory.total 80 | memory_delta = memory_total - self._memory_total_last 81 | self._memory_total_last = memory_total 82 | self._sender.send_pyobj((block, memory_delta)) 83 | 84 | def sample(self, batch_size, reuse_ratio=1.0): 85 | assert self._is_server, "sample() cannot be called when is_server=False." 86 | while (self._num_used / reuse_ratio >= self._num_received or 87 | self._memory_warmup_size > len(self._cache_blocks) * self._block_size): 88 | time.sleep(0.001) 89 | batch = [random.choice(random.choice(self._cache_blocks)) 90 | for _ in range(batch_size)] 91 | self._num_used += batch_size 92 | return batch 93 | 94 | @property 95 | def total(self): 96 | if self._is_server: 97 | return self._total 98 | else: 99 | return self._memory.total 100 | 101 | def _server_receiver_worker(self, zmq_context, port): 102 | receiver = zmq_context.socket(zmq.PULL) 103 | receiver.connect("tcp://localhost:%s" % port) 104 | while True: 105 | block, delta = receiver.recv_pyobj() 106 | self._cache_blocks.append(block) 107 | self._total += delta 108 | self._num_received += len(block) 109 | 110 | def _server_proxy_worker(self, zmq_context, ports): 111 | assert len(ports) == 2 112 | frontend = zmq_context.socket(zmq.PULL) 113 | frontend.bind("tcp://*:%s" % ports[0]) 114 | backend = self._zmq_context.socket(zmq.PUSH) 115 | backend.bind("tcp://*:%s" % ports[1]) 116 | zmq.proxy(frontend, backend) 117 | 118 | 119 | if __name__ == '__main__': 120 | import sys 121 | import numpy as np 122 | 123 | job_name = sys.argv[1] 124 | if job_name == 'client': 125 | replay_memory = RemoteReplayMemory( 126 | is_server=False, 127 | memory_size=10000, 128 | memory_warmup_size=16) 129 | while True: 130 | obs, next_obs = np.array([1,2,3]), np.array([3,4,5]) 131 | action, reward, done, mc_return = 1, 0.5, False, 0.01 132 | replay_memory.push(obs, action, reward, next_obs, done, mc_return) 133 | else: 134 | replay_memory = RemoteReplayMemory( 135 | is_server=True, 136 | memory_size=10000, 137 | memory_warmup_size=10000) 138 | while True: 139 | print(replay_memory.sample(8)) 140 | -------------------------------------------------------------------------------- /sc2learner/envs/common/data_context.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import itertools 6 | 7 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 8 | 9 | from sc2learner.envs.common.const import ALLY_TYPE 10 | from sc2learner.envs.common.const import PLAYER_FEATURE 11 | from sc2learner.envs.common.const import COMBAT_TYPES 12 | import sc2learner.envs.common.utils as utils 13 | 14 | 15 | class DataContext(object): 16 | 17 | def __init__(self): 18 | self._units = [] 19 | self._player = None 20 | self._raw_data = None 21 | self._existed_tags = set() 22 | 23 | def update(self, observation): 24 | for u in self._units: 25 | self._existed_tags.add(u.tag) 26 | self._units = observation['units'] 27 | self._player = observation['player'] 28 | self._raw_data = observation['raw_data'] 29 | self._combat_units = self.units_of_types(COMBAT_TYPES) 30 | 31 | def reset(self, observation): 32 | self._existed_tags.clear() 33 | self.update(observation) 
34 | init_base = self.units_of_type(UNIT_TYPE.ZERG_HATCHERY.value)[0] 35 | self._init_base_pos = (init_base.float_attr.pos_x, 36 | init_base.float_attr.pos_y) 37 | 38 | def units_of_alliance(self, ally): 39 | return [u for u in self._units if u.int_attr.alliance == ally] 40 | 41 | def units_of_type(self, type_id, ally=ALLY_TYPE.SELF.value): 42 | return [u for u in self.units_of_alliance(ally) if u.unit_type == type_id] 43 | 44 | def mature_units_of_type(self, type_id, ally=ALLY_TYPE.SELF.value): 45 | return [u for u in self.units_of_type(type_id, ally) 46 | if u.float_attr.build_progress >= 1.0] 47 | 48 | def idle_units_of_type(self, type_id, ally=ALLY_TYPE.SELF.value): 49 | return [u for u in self.mature_units_of_type(type_id, ally) 50 | if len(u.orders) == 0] 51 | 52 | def units_of_types(self, type_list, ally=ALLY_TYPE.SELF.value): 53 | type_set = set(type_list) 54 | return [u for u in self.units_of_alliance(ally) if u.unit_type in type_set] 55 | 56 | def mature_units_of_types(self, type_list, ally=ALLY_TYPE.SELF.value): 57 | return [u for u in self.units_of_types(type_list, ally) 58 | if u.float_attr.build_progress >= 1.0] 59 | 60 | def idle_units_of_types(self, type_list, ally=ALLY_TYPE.SELF.value): 61 | return [u for u in self.mature_units_of_types(type_list, ally) 62 | if len(u.orders) == 0] 63 | 64 | def units_with_task(self, ability_id, ally=ALLY_TYPE.SELF.value): 65 | return [u for u in self.units_of_alliance(ally) 66 | if ability_id in set([order.ability_id for order in u.orders])] 67 | 68 | def is_new_unit(self, unit): 69 | return unit.tag not in self._existed_tags 70 | 71 | @property 72 | def units(self): 73 | return self._units 74 | 75 | @property 76 | def combat_units(self): 77 | return self._combat_units 78 | 79 | @property 80 | def minerals(self): 81 | return [u for u in self._units 82 | if (u.unit_type == UNIT_TYPE.NEUTRAL_MINERALFIELD.value or 83 | u.unit_type == UNIT_TYPE.NEUTRAL_MINERALFIELD750.value)] 84 | 85 | @property 86 | def unexploited_minerals(self): 87 | self_bases = self.units_of_types([UNIT_TYPE.ZERG_HATCHERY.value, 88 | UNIT_TYPE.ZERG_LAIR.value, 89 | UNIT_TYPE.ZERG_HIVE.value]) 90 | enemy_bases = self.units_of_types([UNIT_TYPE.ZERG_HATCHERY.value, 91 | UNIT_TYPE.ZERG_LAIR.value, 92 | UNIT_TYPE.ZERG_HIVE.value], 93 | ALLY_TYPE.ENEMY.value) 94 | return [u for u in self.minerals 95 | if utils.closest_distance(u, self_bases + enemy_bases) > 15] 96 | 97 | @property 98 | def gas(self): 99 | return [u for u in self._units 100 | if u.unit_type == UNIT_TYPE.NEUTRAL_VESPENEGEYSER.value] 101 | 102 | @property 103 | def exploitable_gas(self): 104 | extractors = self.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value) + \ 105 | self.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value, ALLY_TYPE.ENEMY) 106 | bases = self.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value, 107 | UNIT_TYPE.ZERG_LAIR.value, 108 | UNIT_TYPE.ZERG_HIVE.value]) 109 | return [u for u in self.gas if (utils.closest_distance(u, bases) < 10 and 110 | utils.closest_distance(u, extractors) > 3)] 111 | 112 | @property 113 | def mineral_count(self): 114 | return self._player[PLAYER_FEATURE.MINERALS.value] 115 | 116 | @property 117 | def gas_count(self): 118 | return self._player[PLAYER_FEATURE.VESPENE.value] 119 | 120 | @property 121 | def supply_count(self): 122 | return self._player[PLAYER_FEATURE.FOOD_CAP.value] - \ 123 | self._player[PLAYER_FEATURE.FOOD_USED.value] 124 | 125 | @property 126 | def upgraded_techs(self): 127 | return set(self._raw_data.player.upgrade_ids) 128 | 129 | @property 130 | def 
init_base_pos(self): 131 | return self._init_base_pos 132 | -------------------------------------------------------------------------------- /sc2learner/bin/evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import random 7 | 8 | from absl import app 9 | from absl import flags 10 | from absl import logging 11 | 12 | from sc2learner.envs.raw_env import SC2RawEnv 13 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper 14 | from sc2learner.envs.observations.zerg_observation_wrappers \ 15 | import ZergObservationWrapper 16 | from sc2learner.utils.utils import print_arguments 17 | from sc2learner.utils.utils import print_actions 18 | from sc2learner.utils.utils import print_action_distribution 19 | from sc2learner.agents.random_agent import RandomAgent 20 | from sc2learner.agents.keyboard_agent import KeyboardAgent 21 | 22 | 23 | FLAGS = flags.FLAGS 24 | flags.DEFINE_integer("num_episodes", 10, "Number of episodes to evaluate.") 25 | flags.DEFINE_enum("agent", 'ppo', ['ppo', 'dqn', 'random', 'keyboard'], 26 | "Agent name.") 27 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Job type.") 28 | flags.DEFINE_string("game_version", '4.6', "Game core version.") 29 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.") 30 | flags.DEFINE_enum("difficulty", '1', 31 | ['1', '2', '3', '4', '5', '6', '7', '8', '9', 'A'], 32 | "Bot's strength.") 33 | flags.DEFINE_string("model_path", None, "Filepath to load initial model.") 34 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.") 35 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.") 36 | flags.DEFINE_boolean("use_region_features", False, "Use region features") 37 | flags.DEFINE_boolean("use_action_mask", True, "Use action mask or not.") 38 | flags.FLAGS(sys.argv) 39 | 40 | 41 | def create_env(random_seed=None): 42 | env = SC2RawEnv(map_name='AbyssalReef', 43 | step_mul=FLAGS.step_mul, 44 | agent_race='zerg', 45 | bot_race='zerg', 46 | difficulty=FLAGS.difficulty, 47 | disable_fog=FLAGS.disable_fog, 48 | random_seed=random_seed) 49 | env = ZergActionWrapper(env, 50 | game_version=FLAGS.game_version, 51 | mask=FLAGS.use_action_mask, 52 | use_all_combat_actions=FLAGS.use_all_combat_actions) 53 | env = ZergObservationWrapper( 54 | env, 55 | use_spatial_features=False, 56 | use_game_progress=(not FLAGS.policy == 'lstm'), 57 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8, 58 | use_regions=FLAGS.use_region_features) 59 | print_actions(env) 60 | return env 61 | 62 | 63 | def create_dqn_agent(env): 64 | from sc2learner.agents.dqn_agent import DQNAgent 65 | from sc2learner.agents.dqn_networks import NonspatialDuelingQNet 66 | 67 | assert FLAGS.policy == 'mlp' 68 | assert not FLAGS.use_action_mask 69 | network = NonspatialDuelingQNet(n_dims=env.observation_space.shape[0], 70 | n_out=env.action_space.n) 71 | agent = DQNAgent(network, env.action_space, FLAGS.model_path) 72 | return agent 73 | 74 | 75 | def create_ppo_agent(env): 76 | import tensorflow as tf 77 | import multiprocessing 78 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy 79 | from sc2learner.agents.ppo_agent import PPOAgent 80 | 81 | ncpu = multiprocessing.cpu_count() 82 | if sys.platform == 'darwin': ncpu //= 2 83 | config = tf.ConfigProto(allow_soft_placement=True, 84 | intra_op_parallelism_threads=ncpu, 85 | 
inter_op_parallelism_threads=ncpu) 86 | config.gpu_options.allow_growth = True 87 | tf.Session(config=config).__enter__() 88 | 89 | policy = {'lstm': LstmPolicy, 'mlp': MlpPolicy}[FLAGS.policy] 90 | agent = PPOAgent(env=env, policy=policy, model_path=FLAGS.model_path) 91 | return agent 92 | 93 | 94 | def evaluate(): 95 | game_seed = random.randint(0, 2**32 - 1) 96 | print("Game Seed: %d" % game_seed) 97 | env = create_env(game_seed) 98 | 99 | if FLAGS.agent == 'ppo': 100 | agent = create_ppo_agent(env) 101 | elif FLAGS.agent == 'dqn': 102 | agent = create_dqn_agent(env) 103 | elif FLAGS.agent == 'random': 104 | agent = RandomAgent(action_space=env.action_space) 105 | elif FLAGS.agent == 'keyboard': 106 | agent = KeyboardAgent(action_space=env.action_space) 107 | else: 108 | raise NotImplementedError 109 | 110 | try: 111 | cum_return = 0.0 112 | action_counts = [0] * env.action_space.n 113 | for i in range(FLAGS.num_episodes): 114 | observation = env.reset() 115 | agent.reset() 116 | done, step_id = False, 0 117 | while not done: 118 | action = agent.act(observation) 119 | print("Step ID: %d Take Action: %d" % (step_id, action)) 120 | observation, reward, done, _ = env.step(action) 121 | action_counts[action] += 1 122 | cum_return += reward 123 | step_id += 1 124 | print_action_distribution(env, action_counts) 125 | print("Evaluated %d/%d Episodes Avg Return %f Avg Winning Rate %f" % ( 126 | i + 1, FLAGS.num_episodes, cum_return / (i + 1), 127 | ((cum_return / (i + 1)) + 1) / 2.0)) 128 | except KeyboardInterrupt: pass 129 | finally: env.close() 130 | 131 | 132 | def main(argv): 133 | logging.set_verbosity(logging.ERROR) 134 | print_arguments(FLAGS) 135 | evaluate() 136 | 137 | 138 | if __name__ == '__main__': 139 | app.run(main) 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SC2Learner (TStarBot1) - Macro-Action-Based StarCraft-II Reinforcement Learning Environment 2 | 3 | 4 |

5 | <img src="docs/images/overview.png"> 6 | 

7 | 8 | *[SC2Learner](https://github.com/Tencent/TStarBot1)* is a *macro-action*-based [StarCraft-II](https://en.wikipedia.org/wiki/StarCraft_II:_Wings_of_Liberty) reinforcement learning research platform. 9 | It exposes a re-designed StarCraft-II action space with more than one hundred discrete macro actions, built on the raw APIs exposed by DeepMind and Blizzard's [PySC2](https://github.com/deepmind/pysc2). 10 | The macro action space relieves the learning algorithms of the otherwise overwhelming burden of directly handling a massive number of atomic keyboard and mouse operations, making learning more tractable. 11 | The environments and wrappers strictly follow the interface of [OpenAI Gym](https://github.com/openai/gym), making them easy to adapt to many off-the-shelf reinforcement learning algorithms and implementations. 12 | 13 | [*TStarBot1*](https://arxiv.org/pdf/1809.07193.pdf), a reinforcement learning agent, is also released, together with two off-the-shelf reinforcement learning algorithms, *Dueling Double Deep Q Network* (DDQN) and *Proximal Policy Optimization* (PPO), as examples. 14 | **Distributed** versions of both algorithms are released, enabling learners to scale up rollout experience collection across thousands of CPU cores on a cluster of machines. 15 | *TStarBot1* is able to beat the **level-9** built-in AI (cheating resources) with a **97%** win-rate and the **level-10** built-in AI (cheating insane) with an **81%** win-rate. 16 | 17 | A whitepaper on *TStarBots* is available [here](https://arxiv.org/pdf/1809.07193.pdf). 18 | 19 | ## Table of Contents 20 | - [Installations](#installations) 21 | - [Getting Started](#getting-started) 22 | - [Run Random Agent](#run-random-agent) 23 | - [Train PPO Agent](#train-ppo-agent) 24 | - [Evaluate PPO Agent](#evaluate-ppo-agent) 25 | - [Play vs. PPO Agent](#play-vs-ppo-agent) 26 | - [Train via Self-play](#train-via-self-play) 27 | - [Environments and Wrappers](#environments-and-wrappers) 28 | - [Questions and Help](#questions-and-help) 29 | 30 | 31 | ## Installations 32 | 33 | ### Prerequisites 34 | - Python >= 3.5 required. 35 | - [PySC2 Extension](https://github.com/Tencent/PySC2TencentExtension) required. 36 | 37 | ### Setup 38 | Clone this repository and then install it with 39 | ```bash 40 | pip3 install -e sc2learner 41 | ``` 42 | 43 | ## Getting Started 44 | 45 | ### Run Random Agent 46 | Run a random agent playing against a built-in AI of difficulty level 1. 47 | ```bash 48 | python3 -m sc2learner.bin.evaluate --agent random --difficulty '1' 49 | ``` 50 | 51 | ### Train PPO Agent 52 | 53 | To train an agent with the PPO algorithm, the actor workers and the learner worker must be started separately. 54 | They can run either locally or across separate machines (e.g., actors usually run on a CPU cluster consisting of hundreds of machines with tens of thousands of CPU cores, while the learner runs on a GPU machine). 55 | With the designated ports and the learner's IP address, rollout trajectories and model parameters are exchanged between the actors and the learner. 56 | - Start 48 actor workers (run the same script on all actor machines) 57 | ```bash 58 | for i in $(seq 0 47); do 59 | CUDA_VISIBLE_DEVICES= python3 -m sc2learner.bin.train_ppo --job_name=actor --learner_ip localhost & 60 | done; 61 | ``` 62 | 63 | - Start a learner worker 64 | ```bash 65 | CUDA_VISIBLE_DEVICES=0 python3 -m sc2learner.bin.train_ppo --job_name learner 66 | ``` 67 | 68 | Similarly, the DQN algorithm can be trained with `sc2learner.bin.train_dqn`. 
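For reference, the actor-learner traffic is plain ZeroMQ. The snippet below only illustrates the PUSH/PULL pattern the training scripts are built on (compare the `--port_A`/`--port_B` flags in `sc2learner/bin/train_ppo.py` and the `RemoteReplayMemory` class); the two helper functions and the single data port shown here are hypothetical, not the project's actual `PPOActor`/`PPOLearner` implementation.

```python
import zmq

def run_learner(port="5701"):                 # 5701 is the default --port_B (data port)
    context = zmq.Context()
    receiver = context.socket(zmq.PULL)
    receiver.bind("tcp://*:%s" % port)        # the learner binds; actors connect to it
    while True:
        unroll = receiver.recv_pyobj()        # blocks until an actor pushes a rollout
        print("received an unroll of %d transitions" % len(unroll))

def run_actor(learner_ip="localhost", port="5701"):
    context = zmq.Context()
    sender = context.socket(zmq.PUSH)
    sender.connect("tcp://%s:%s" % (learner_ip, port))
    fake_unroll = [("obs", "action", "reward", "done")] * 128  # placeholder trajectory
    sender.send_pyobj(fake_unroll)            # pickled and pushed to the learner
```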
69 | 70 | ### Evaluate PPO Agent 71 | After training, the agent's in-game performance can be observed by letting it play against a built-in AI of a certain difficulty level. 72 | The win-rate is also estimated over multiple such games initialized with different game seeds. 73 | ```bash 74 | python3 -m sc2learner.bin.evaluate --agent ppo --difficulty 1 --model_path REPLACE_WITH_YOUR_OWN_MODEL_PATH 75 | ``` 76 | 77 | 78 | ### Play vs. PPO Agent 79 | You can also play against the learned agent yourself, by first starting a human player client and then the learned agent. 80 | The two can run either locally or remotely. 81 | When running across two machines, the `--remote` argument needs to be set on the human player side to create an SSH tunnel to the remote agent's machine, and SSH keys must be used for authentication. 82 | 83 | - Start a human player client 84 | ```bash 85 | CUDA_VISIBLE_DEVICES= python3 -m pysc2.bin.play_vs_agent --human --map AbyssalReef --user_race zerg 86 | ``` 87 | 88 | - Start a PPO agent 89 | ```bash 90 | python3 -m sc2learner.bin.play_vs_ppo_agent --model_path REPLACE_WITH_YOUR_OWN_MODEL_PATH 91 | ``` 92 | 93 | ### Train via Self-play 94 | 95 | Self-play training (playing against past versions of the agent) is also provided, making it possible to learn more diversified strategies. 96 | 97 | - Start Actors 98 | ```bash 99 | for i in $(seq 0 48); do 100 | CUDA_VISIBLE_DEVICES= python3 -m sc2learner.bin.train_ppo_selfplay --job_name=actor --learner_ip localhost & 101 | done; 102 | ``` 103 | 104 | - Start Learner 105 | ```bash 106 | CUDA_VISIBLE_DEVICES=0 python3 -m sc2learner.bin.train_ppo_selfplay --job_name learner 107 | ``` 108 | 109 | ## Environments and Wrappers 110 | 111 | The environments and wrappers strictly follow the interface of [OpenAI Gym](https://github.com/openai/gym). 112 | The macro action space is defined in [`ZergActionWrapper`](https://github.com/Tencent/TStarBot1/blob/dev-open/sc2learner/envs/actions/zerg_action_wrappers.py#L26) and the observation space is defined in [`ZergObservationWrapper`](https://github.com/Tencent/TStarBot1/blob/dev-open/sc2learner/envs/observations/zerg_observation_wrappers.py#L24), based on which users can easily make their own changes and restart the training to see what happens. A minimal usage sketch is given at the end of this README. 113 | 114 | ## Questions and Help 115 | You are welcome to submit questions and bug reports in [GitHub Issues](https://github.com/Tencent/TStarBot1/issues). 116 | You are also welcome to contribute to this project. 
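For reference, a minimal interaction loop with the wrapped environment looks like the sketch below. It mirrors `create_env` in `sc2learner/bin/evaluate.py` and uses the same constructor arguments as that script's default (MLP, masked) configuration; a StarCraft II installation and the PySC2 Extension are assumed to be in place.

```python
from sc2learner.envs.raw_env import SC2RawEnv
from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
from sc2learner.envs.observations.zerg_observation_wrappers import ZergObservationWrapper
from sc2learner.agents.random_agent import RandomAgent

# Build the raw environment and stack the action/observation wrappers,
# following sc2learner/bin/evaluate.py.
env = SC2RawEnv(map_name='AbyssalReef', step_mul=32, agent_race='zerg',
                bot_race='zerg', difficulty='1', disable_fog=False)
env = ZergActionWrapper(env, game_version='4.6', mask=True,
                        use_all_combat_actions=False)
env = ZergObservationWrapper(env, use_spatial_features=False,
                             use_game_progress=True, action_seq_len=8,
                             use_regions=False)

agent = RandomAgent(action_space=env.action_space)
observation = env.reset()
done = False
while not done:
    action = agent.act(observation)                  # mask-aware random action
    observation, reward, done, _ = env.step(action)  # standard Gym step
env.close()
```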
117 | -------------------------------------------------------------------------------- /sc2learner/envs/common/const.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from enum import Enum 6 | from enum import unique 7 | from collections import namedtuple 8 | 9 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 10 | 11 | 12 | @unique 13 | class ALLY_TYPE(Enum): 14 | SELF = 1 15 | ALLY = 2 16 | NEUTRAL = 3 17 | ENEMY = 4 18 | 19 | 20 | @unique 21 | class PLAYER_FEATURE(Enum): 22 | PLAYER_ID = 0 23 | MINERALS = 1 24 | VESPENE = 2 25 | FOOD_USED = 3 26 | FOOD_CAP = 4 27 | FOOD_ARMY = 5 28 | FOOD_WORKER = 6 29 | IDLE_WORKER_COUNT = 7 30 | ARMY_COUNT = 8 31 | WARP_GATE_COUNT = 9 32 | LARVA_COUNT = 10 33 | 34 | 35 | NEUTRAL_DESTRUCTABLEROCKEX14X4 = 638 36 | NEUTRAL_UNBUILDABLEROCKSDESTRUCIBLE = 472 37 | 38 | 39 | PLACE_COLLISION_BUILDINGS = { 40 | UNIT_TYPE.NEUTRAL_MINERALFIELD.value, 41 | UNIT_TYPE.NEUTRAL_MINERALFIELD750.value, 42 | UNIT_TYPE.NEUTRAL_VESPENEGEYSER.value, 43 | UNIT_TYPE.NEUTRAL_DESTRUCTIBLEROCK6X6.value, 44 | UNIT_TYPE.NEUTRAL_DESTRUCTIBLEROCKEX1DIAGONALHUGEBLUR.value, 45 | NEUTRAL_DESTRUCTABLEROCKEX14X4, 46 | NEUTRAL_UNBUILDABLEROCKSDESTRUCIBLE, 47 | UNIT_TYPE.ZERG_EXTRACTOR.value, 48 | UNIT_TYPE.ZERG_SPAWNINGPOOL.value, 49 | UNIT_TYPE.ZERG_ROACHWARREN.value, 50 | UNIT_TYPE.ZERG_HYDRALISKDEN.value, 51 | UNIT_TYPE.ZERG_HATCHERY.value, 52 | UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value, 53 | UNIT_TYPE.ZERG_BANELINGNEST.value, 54 | UNIT_TYPE.ZERG_INFESTATIONPIT.value, 55 | UNIT_TYPE.ZERG_SPIRE.value, 56 | UNIT_TYPE.ZERG_ULTRALISKCAVERN.value, 57 | UNIT_TYPE.ZERG_NYDUSNETWORK.value, 58 | UNIT_TYPE.ZERG_SPINECRAWLER.value, 59 | UNIT_TYPE.ZERG_SPORECRAWLER.value, 60 | UNIT_TYPE.ZERG_LURKERDENMP.value, 61 | UNIT_TYPE.ZERG_LAIR.value, 62 | UNIT_TYPE.ZERG_HIVE.value, 63 | UNIT_TYPE.ZERG_GREATERSPIRE.value 64 | } 65 | 66 | 67 | MAXIMUM_NUM = { 68 | UNIT_TYPE.ZERG_SPAWNINGPOOL.value: 1, 69 | UNIT_TYPE.ZERG_ROACHWARREN.value: 1, 70 | UNIT_TYPE.ZERG_HYDRALISKDEN.value: 1, 71 | UNIT_TYPE.ZERG_HATCHERY.value: 6, 72 | UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value: 2, 73 | UNIT_TYPE.ZERG_BANELINGNEST.value: 1, 74 | UNIT_TYPE.ZERG_INFESTATIONPIT.value: 1, 75 | UNIT_TYPE.ZERG_SPIRE.value: 1, 76 | UNIT_TYPE.ZERG_ULTRALISKCAVERN.value: 1, 77 | UNIT_TYPE.ZERG_NYDUSNETWORK.value: 1, 78 | UNIT_TYPE.ZERG_LURKERDENMP.value: 1, 79 | UNIT_TYPE.ZERG_LAIR.value: 1, 80 | UNIT_TYPE.ZERG_HIVE.value: 1, 81 | UNIT_TYPE.ZERG_GREATERSPIRE.value: 1 82 | } 83 | 84 | 85 | AttackAttr = namedtuple('AttackAttr', ('can_attack_ground', 'can_attack_air')) 86 | 87 | 88 | ATTACK_FORCE = { 89 | UNIT_TYPE.ZERG_LARVA.value: AttackAttr(False, False), 90 | UNIT_TYPE.ZERG_DRONE.value: AttackAttr(True, False), 91 | UNIT_TYPE.ZERG_ZERGLING.value: AttackAttr(True, False), 92 | UNIT_TYPE.ZERG_BANELING.value: AttackAttr(True, False), 93 | UNIT_TYPE.ZERG_ROACH.value: AttackAttr(True, False), 94 | UNIT_TYPE.ZERG_ROACHBURROWED.value: AttackAttr(True, False), 95 | UNIT_TYPE.ZERG_RAVAGER.value: AttackAttr(True, False), 96 | UNIT_TYPE.ZERG_HYDRALISK.value: AttackAttr(True, True), 97 | UNIT_TYPE.ZERG_LURKERMP.value: AttackAttr(True, False), 98 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value: AttackAttr(True, False), 99 | UNIT_TYPE.ZERG_VIPER.value: AttackAttr(False, False), 100 | UNIT_TYPE.ZERG_MUTALISK.value: AttackAttr(True, True), 101 | UNIT_TYPE.ZERG_CORRUPTOR.value: AttackAttr(False, True), 
102 | UNIT_TYPE.ZERG_BROODLORD.value: AttackAttr(True, False), 103 | UNIT_TYPE.ZERG_SWARMHOSTMP.value: AttackAttr(False, False), 104 | UNIT_TYPE.ZERG_LOCUSTMP.value: AttackAttr(True, False), 105 | UNIT_TYPE.ZERG_INFESTOR.value: AttackAttr(False, False), 106 | UNIT_TYPE.ZERG_ULTRALISK.value: AttackAttr(True, False), 107 | UNIT_TYPE.ZERG_BROODLING.value: AttackAttr(True, False), 108 | UNIT_TYPE.ZERG_OVERLORD.value: AttackAttr(False, False), 109 | UNIT_TYPE.ZERG_OVERSEER.value: AttackAttr(False, False), 110 | UNIT_TYPE.ZERG_QUEEN.value: AttackAttr(True, True), 111 | UNIT_TYPE.ZERG_CHANGELING.value: AttackAttr(False, False), 112 | UNIT_TYPE.ZERG_SPINECRAWLER.value: AttackAttr(True, False), 113 | UNIT_TYPE.ZERG_SPORECRAWLER.value: AttackAttr(False, True), 114 | UNIT_TYPE.ZERG_NYDUSCANAL.value: AttackAttr(False, False) 115 | } 116 | 117 | 118 | PRIORITIZED_ATTACK = set([ 119 | UNIT_TYPE.ZERG_ZERGLING.value, 120 | UNIT_TYPE.ZERG_BANELING.value, 121 | UNIT_TYPE.ZERG_ROACH.value, 122 | UNIT_TYPE.ZERG_ROACHBURROWED.value, 123 | UNIT_TYPE.ZERG_RAVAGER.value, 124 | UNIT_TYPE.ZERG_HYDRALISK.value, 125 | UNIT_TYPE.ZERG_LURKERMP.value, 126 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value, 127 | UNIT_TYPE.ZERG_VIPER.value, 128 | UNIT_TYPE.ZERG_MUTALISK.value, 129 | UNIT_TYPE.ZERG_CORRUPTOR.value, 130 | UNIT_TYPE.ZERG_BROODLORD.value, 131 | UNIT_TYPE.ZERG_SWARMHOSTMP.value, 132 | UNIT_TYPE.ZERG_LOCUSTMP.value, 133 | UNIT_TYPE.ZERG_INFESTOR.value, 134 | UNIT_TYPE.ZERG_ULTRALISK.value, 135 | UNIT_TYPE.ZERG_BROODLING.value, 136 | UNIT_TYPE.ZERG_QUEEN.value, 137 | UNIT_TYPE.ZERG_CHANGELING.value, 138 | UNIT_TYPE.ZERG_SPINECRAWLER.value, 139 | UNIT_TYPE.ZERG_SPORECRAWLER.value 140 | ]) 141 | 142 | 143 | COMBAT_TYPES = { 144 | UNIT_TYPE.ZERG_ZERGLING.value, 145 | UNIT_TYPE.ZERG_BANELING.value, 146 | UNIT_TYPE.ZERG_ROACH.value, 147 | UNIT_TYPE.ZERG_ROACHBURROWED.value, 148 | UNIT_TYPE.ZERG_RAVAGER.value, 149 | UNIT_TYPE.ZERG_HYDRALISK.value, 150 | UNIT_TYPE.ZERG_LURKERMP.value, 151 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value, 152 | UNIT_TYPE.ZERG_MUTALISK.value, 153 | UNIT_TYPE.ZERG_CORRUPTOR.value, 154 | UNIT_TYPE.ZERG_BROODLORD.value, 155 | UNIT_TYPE.ZERG_ULTRALISK.value 156 | #UNIT_TYPE.ZERG_LOCUSTMP.value, 157 | #UNIT_TYPE.ZERG_BROODLING.value 158 | } 159 | 160 | 161 | class MAP(object): 162 | WIDTH = 200.0 163 | HEIGHT = 176.0 164 | LEFT = 24.0 165 | RIGHT = 24.0 166 | TOP = 37.0 167 | BOTTOM = 4.0 168 | -------------------------------------------------------------------------------- /sc2learner/envs/rewards/reward_wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import gym 7 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 8 | 9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation 10 | from sc2learner.envs.common.const import ALLY_TYPE 11 | 12 | 13 | class RewardShapingWrapperV1(gym.Wrapper): 14 | 15 | def __init__(self, env): 16 | super(RewardShapingWrapperV1, self).__init__(env) 17 | assert isinstance(env.observation_space, PySC2RawObservation) 18 | self._combat_unit_types = set([UNIT_TYPE.ZERG_ZERGLING.value, 19 | UNIT_TYPE.ZERG_ROACH.value, 20 | UNIT_TYPE.ZERG_HYDRALISK.value]) 21 | self.reward_range = (-np.inf, np.inf) 22 | 23 | def step(self, action): 24 | observation, outcome, done, info = self.env.step(action) 25 | n_enemies, n_self_combats = self._get_unit_counts(observation) 26 | if 
n_self_combats - n_enemies > self._n_self_combats - self._n_enemies: 27 | reward = 1 28 | elif n_self_combats - n_enemies < self._n_self_combats - self._n_enemies: 29 | reward = -1 30 | else: 31 | reward = 0 32 | if not done: reward += outcome * 10 33 | else: reward = outcome * 10 34 | self._n_enemies = n_enemies 35 | self._n_self_combats = n_self_combats 36 | return observation, reward, done, info 37 | 38 | def reset(self, **kwargs): 39 | observation = self.env.reset() 40 | self._n_enemies, self._n_self_combats = self._get_unit_counts(observation) 41 | return observation 42 | 43 | @property 44 | def action_names(self): 45 | if not hasattr(self.env, 'action_names'): raise NotImplementedError 46 | return self.env.action_names 47 | 48 | @property 49 | def player_position(self): 50 | if not hasattr(self.env, 'player_position'): raise NotImplementedError 51 | return self.env.player_position 52 | 53 | def _get_unit_counts(self, observation): 54 | num_enemy_units, num_self_combat_units = 0, 0 55 | for u in observation['units']: 56 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value: 57 | num_enemy_units += 1 58 | elif u.int_attr.alliance == ALLY_TYPE.SELF.value: 59 | if u.unit_type in self._combat_unit_types: 60 | num_self_combat_units += 1 61 | return num_enemy_units, num_self_combat_units 62 | 63 | 64 | class RewardShapingWrapperV2(gym.Wrapper): 65 | 66 | def __init__(self, env): 67 | super(RewardShapingWrapperV2, self).__init__(env) 68 | assert isinstance(env.observation_space, PySC2RawObservation) 69 | self._combat_unit_types = set([UNIT_TYPE.ZERG_ZERGLING.value, 70 | UNIT_TYPE.ZERG_ROACH.value, 71 | UNIT_TYPE.ZERG_HYDRALISK.value, 72 | UNIT_TYPE.ZERG_RAVAGER.value, 73 | UNIT_TYPE.ZERG_BANELING.value, 74 | UNIT_TYPE.ZERG_BROODLING.value]) 75 | self.reward_range = (-np.inf, np.inf) 76 | 77 | def step(self, action): 78 | observation, reward, done, info = self.env.step(action) 79 | n_enemies, n_selves = self._get_unit_counts(observation) 80 | diff_selves = n_selves - self._n_selves 81 | diff_enemies = n_enemies - self._n_enemies 82 | if not done: reward += (diff_selves - diff_enemies) * 0.02 83 | self._n_enemies = n_enemies 84 | self._n_selves = n_selves 85 | return observation, reward, done, info 86 | 87 | def reset(self, **kwargs): 88 | observation = self.env.reset() 89 | self._n_enemies, self._n_selves = self._get_unit_counts(observation) 90 | return observation 91 | 92 | @property 93 | def action_names(self): 94 | if not hasattr(self.env, 'action_names'): raise NotImplementedError 95 | return self.env.action_names 96 | 97 | @property 98 | def player_position(self): 99 | if not hasattr(self.env, 'player_position'): raise NotImplementedError 100 | return self.env.player_position 101 | 102 | def _get_unit_counts(self, observation): 103 | num_enemy_units, num_self_units = 0, 0 104 | for u in observation['units']: 105 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value: 106 | if u.unit_type in self._combat_unit_types: 107 | num_enemy_units += 1 108 | elif u.int_attr.alliance == ALLY_TYPE.SELF.value: 109 | if u.unit_type in self._combat_unit_types: 110 | num_self_units += 1 111 | return num_enemy_units, num_self_units 112 | 113 | 114 | class KillingRewardWrapper(gym.Wrapper): 115 | 116 | def __init__(self, env): 117 | super(KillingRewardWrapper, self).__init__(env) 118 | assert isinstance(env.observation_space, PySC2RawObservation) 119 | self.reward_range = (-np.inf, np.inf) 120 | self._last_kill_value = 0 121 | 122 | def step(self, action): 123 | observation, reward, done, info = 
self.env.step(action) 124 | kill_value = observation.score_cumulative[5] + \ 125 | observation.score_cumulative[6] 126 | if not done: 127 | reward += (kill_value - self._last_kill_value) * 1e-5 128 | self._last_kill_value = kill_value 129 | return observation, reward, done, info 130 | 131 | def reset(self): 132 | observation = self.env.reset() 133 | kill_value = observation.score_cumulative[5] + \ 134 | observation.score_cumulative[6] 135 | self._last_kill_value = kill_value 136 | return observation 137 | 138 | @property 139 | def action_names(self): 140 | if not hasattr(self.env, 'action_names'): raise NotImplementedError 141 | return self.env.action_names 142 | 143 | @property 144 | def player_position(self): 145 | if not hasattr(self.env, 'player_position'): raise NotImplementedError 146 | return self.env.player_position 147 | observation = self.env.reset() 148 | return observation 149 | 150 | @property 151 | def action_names(self): 152 | if not hasattr(self.env, 'action_names'): raise NotImplementedError 153 | return self.env.action_names 154 | 155 | @property 156 | def player_position(self): 157 | if not hasattr(self.env, 'player_position'): raise NotImplementedError 158 | return self.env.player_position 159 | -------------------------------------------------------------------------------- /sc2learner/bin/train_dqn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import os 7 | import random 8 | import time 9 | 10 | import torch 11 | from absl import app 12 | from absl import flags 13 | from absl import logging 14 | 15 | from sc2learner.envs.raw_env import SC2RawEnv 16 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper 17 | from sc2learner.envs.observations.zerg_observation_wrappers \ 18 | import ZergObservationWrapper 19 | from sc2learner.agents.dqn_agent import DQNActor 20 | from sc2learner.agents.dqn_agent import DQNLearner 21 | from sc2learner.agents.dqn_networks import NonspatialDuelingQNet 22 | from sc2learner.utils.utils import print_arguments 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_enum("job_name", 'actor', ['actor', 'learner'], "Job type.") 27 | flags.DEFINE_string("learner_ip", "localhost", "Learner IP address.") 28 | flags.DEFINE_string("ports", "5700,5701,5702", 29 | "3 ports for distributed replay memory.") 30 | flags.DEFINE_integer("client_memory_size", 50000, 31 | "Total size of client memory.") 32 | flags.DEFINE_integer("client_memory_warmup_size", 2000, 33 | "Memory warmup size for client.") 34 | flags.DEFINE_integer("server_memory_size", 1000000, 35 | "Total size of server memory.") 36 | flags.DEFINE_integer("server_memory_warmup_size", 100000, 37 | "Memory warmup size for client.") 38 | flags.DEFINE_string("game_version", '4.6', "Game core version.") 39 | flags.DEFINE_float("discount", 0.995, "Discount factor.") 40 | flags.DEFINE_float("send_freq", 4.0, "Probability of a step being pushed.") 41 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.") 42 | flags.DEFINE_string("difficulties", '1,2,4,6,9,A', "Bot's strengths.") 43 | flags.DEFINE_float("eps_start", 1.0, "Max greedy epsilon for exploration.") 44 | flags.DEFINE_float("eps_end", 0.1, "Min greedy epsilon for exploration.") 45 | flags.DEFINE_integer("eps_decay_steps", 1000000, "Greedy epsilon decay step.") 46 | flags.DEFINE_integer("eps_decay_steps2", 10000000, "Greedy epsilon decay 
step.") 47 | flags.DEFINE_float("learning_rate", 1e-6, "Learning rate.") 48 | flags.DEFINE_float("adam_eps", 1e-7, "Adam optimizer's epsilon.") 49 | flags.DEFINE_float("gradient_clipping", 10.0, "Gradient clipping threshold.") 50 | flags.DEFINE_integer("batch_size", 256, "Batch size.") 51 | flags.DEFINE_float("mmc_beta", 0.9, "Discount.") 52 | flags.DEFINE_integer("target_update_interval", 10000, 53 | "Target net update interval.") 54 | flags.DEFINE_string("init_model_path", None, "Checkpoint to initialize model.") 55 | flags.DEFINE_string("checkpoint_dir", "./checkpoints", "Dir to save models to") 56 | flags.DEFINE_integer("checkpoint_interval", 500000, "Model saving frequency.") 57 | flags.DEFINE_integer("print_interval", 10000, "Print train cost frequency.") 58 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.") 59 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.") 60 | flags.DEFINE_boolean("use_region_features", True, "Use region features") 61 | flags.FLAGS(sys.argv) 62 | 63 | 64 | def create_env(difficulty, random_seed=None): 65 | env = SC2RawEnv(map_name='AbyssalReef', 66 | step_mul=FLAGS.step_mul, 67 | resolution=16, 68 | agent_race='zerg', 69 | bot_race='zerg', 70 | difficulty=difficulty, 71 | disable_fog=FLAGS.disable_fog, 72 | random_seed=random_seed) 73 | env = ZergActionWrapper(env, 74 | game_version=FLAGS.game_version, 75 | mask=False, 76 | use_all_combat_actions=FLAGS.use_all_combat_actions) 77 | env = ZergObservationWrapper(env, 78 | use_spatial_features=False, 79 | use_regions=FLAGS.use_region_features) 80 | return env 81 | 82 | 83 | def create_network(env): 84 | return NonspatialDuelingQNet(n_dims=env.observation_space.shape[0], 85 | n_out=env.action_space.n) 86 | 87 | 88 | def start_actor_job(): 89 | random.seed(time.time()) 90 | difficulty = random.choice(FLAGS.difficulties.split(',')) 91 | game_seed = random.randint(0, 2**32 - 1) 92 | print("Game Seed: %d Difficulty: %s" % (game_seed, difficulty)) 93 | env = create_env(difficulty, game_seed) 94 | network = create_network(env) 95 | actor = DQNActor(memory_size=FLAGS.client_memory_size, 96 | memory_warmup_size=FLAGS.client_memory_warmup_size, 97 | env=env, 98 | network=network, 99 | discount=FLAGS.discount, 100 | send_freq=FLAGS.send_freq, 101 | ports=FLAGS.ports.split(','), 102 | learner_ip=FLAGS.learner_ip) 103 | actor.run() 104 | env.close() 105 | 106 | 107 | def start_learner_job(): 108 | if not os.path.exists(FLAGS.checkpoint_dir): 109 | os.makedirs(FLAGS.checkpoint_dir) 110 | 111 | env = create_env('1', 0) 112 | network = create_network(env) 113 | learner = DQNLearner(network=network, 114 | action_space=env.action_space, 115 | memory_size=FLAGS.server_memory_size, 116 | memory_warmup_size=FLAGS.server_memory_warmup_size, 117 | discount=FLAGS.discount, 118 | eps_start=FLAGS.eps_start, 119 | eps_end=FLAGS.eps_end, 120 | eps_decay_steps=FLAGS.eps_decay_steps, 121 | eps_decay_steps2=FLAGS.eps_decay_steps2, 122 | batch_size=FLAGS.batch_size, 123 | mmc_beta=FLAGS.mmc_beta, 124 | gradient_clipping=FLAGS.gradient_clipping, 125 | adam_eps=FLAGS.adam_eps, 126 | learning_rate=FLAGS.learning_rate, 127 | target_update_interval=FLAGS.target_update_interval, 128 | checkpoint_dir=FLAGS.checkpoint_dir, 129 | checkpoint_interval=FLAGS.checkpoint_interval, 130 | print_interval=FLAGS.print_interval, 131 | ports=FLAGS.ports.split(','), 132 | init_model_path=FLAGS.init_model_path) 133 | learner.run() 134 | env.close() 135 | 136 | 137 | def main(argv): 138 | 
logging.set_verbosity(logging.ERROR) 139 | print_arguments(FLAGS) 140 | if FLAGS.job_name == 'actor': start_actor_job() 141 | else: start_learner_job() 142 | 143 | 144 | if __name__ == '__main__': 145 | app.run(main) 146 | -------------------------------------------------------------------------------- /sc2learner/bin/train_ppo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | from threading import Thread 7 | import os 8 | import multiprocessing 9 | import random 10 | import time 11 | 12 | from absl import app 13 | from absl import flags 14 | from absl import logging 15 | import tensorflow as tf 16 | 17 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy 18 | from sc2learner.agents.ppo_agent import PPOActor, PPOLearner 19 | from sc2learner.envs.raw_env import SC2RawEnv 20 | from sc2learner.envs.rewards.reward_wrappers import KillingRewardWrapper 21 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper 22 | from sc2learner.envs.observations.zerg_observation_wrappers \ 23 | import ZergObservationWrapper 24 | from sc2learner.utils.utils import print_arguments 25 | 26 | 27 | FLAGS = flags.FLAGS 28 | flags.DEFINE_enum("job_name", 'actor', ['actor', 'learner'], "Job type.") 29 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Job type.") 30 | flags.DEFINE_integer("unroll_length", 128, "Length of rollout steps.") 31 | flags.DEFINE_string("learner_ip", "localhost", "Learner IP address.") 32 | flags.DEFINE_string("port_A", "5700", "Port for transporting model.") 33 | flags.DEFINE_string("port_B", "5701", "Port for transporting data.") 34 | flags.DEFINE_string("game_version", '4.6', "Game core version.") 35 | flags.DEFINE_float("discount_gamma", 0.998, "Discount factor.") 36 | flags.DEFINE_float("lambda_return", 0.95, "Lambda return factor.") 37 | flags.DEFINE_float("clip_range", 0.1, "Clip range for PPO.") 38 | flags.DEFINE_float("ent_coef", 0.01, "Coefficient for the entropy term.") 39 | flags.DEFINE_float("vf_coef", 0.5, "Coefficient for the value loss.") 40 | flags.DEFINE_float("learn_act_speed_ratio", 0, "Maximum learner/actor ratio.") 41 | flags.DEFINE_integer("batch_size", 32, "Batch size.") 42 | flags.DEFINE_integer("game_steps_per_episode", 43200, "Maximum steps per episode.") 43 | flags.DEFINE_integer("learner_queue_size", 1024, "Size of learner's unroll queue.") 44 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.") 45 | flags.DEFINE_string("difficulties", '1,2,4,6,9,A', "Bot's strengths.") 46 | flags.DEFINE_float("learning_rate", 1e-5, "Learning rate.") 47 | flags.DEFINE_string("init_model_path", None, "Initial model path.") 48 | flags.DEFINE_string("save_dir", "./checkpoints/", "Dir to save models to") 49 | flags.DEFINE_integer("save_interval", 50000, "Model saving frequency.") 50 | flags.DEFINE_integer("print_interval", 1000, "Print train cost frequency.") 51 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.") 52 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.") 53 | flags.DEFINE_boolean("use_region_features", False, "Use region features") 54 | flags.DEFINE_boolean("use_action_mask", True, "Use region-wise combat.") 55 | flags.DEFINE_boolean("use_reward_shaping", False, "Use reward shaping.") 56 | flags.FLAGS(sys.argv) 57 | 58 | 59 | def tf_config(ncpu=None): 60 | if ncpu is None: 61 | ncpu = 
multiprocessing.cpu_count() 62 | if sys.platform == 'darwin': ncpu //= 2 63 | config = tf.ConfigProto(allow_soft_placement=True, 64 | intra_op_parallelism_threads=ncpu, 65 | inter_op_parallelism_threads=ncpu) 66 | config.gpu_options.allow_growth = True 67 | tf.Session(config=config).__enter__() 68 | 69 | 70 | def create_env(difficulty, random_seed=None): 71 | env = SC2RawEnv(map_name='AbyssalReef', 72 | step_mul=FLAGS.step_mul, 73 | resolution=16, 74 | agent_race='zerg', 75 | bot_race='zerg', 76 | difficulty=difficulty, 77 | disable_fog=FLAGS.disable_fog, 78 | tie_to_lose=False, 79 | game_steps_per_episode=FLAGS.game_steps_per_episode, 80 | random_seed=random_seed) 81 | if FLAGS.use_reward_shaping: env = KillingRewardWrapper(env) 82 | env = ZergActionWrapper(env, 83 | game_version=FLAGS.game_version, 84 | mask=FLAGS.use_action_mask, 85 | use_all_combat_actions=FLAGS.use_all_combat_actions) 86 | env = ZergObservationWrapper(env, 87 | use_spatial_features=False, 88 | use_game_progress=(not FLAGS.policy == 'lstm'), 89 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8, 90 | use_regions=FLAGS.use_region_features) 91 | print(env.observation_space, env.action_space) 92 | return env 93 | 94 | 95 | def start_actor(): 96 | tf_config(ncpu=2) 97 | random.seed(time.time()) 98 | difficulty = random.choice(FLAGS.difficulties.split(',')) 99 | game_seed = random.randint(0, 2**32 - 1) 100 | print("Game Seed: %d Difficulty: %s" % (game_seed, difficulty)) 101 | env = create_env(difficulty, game_seed) 102 | policy = {'lstm': LstmPolicy, 103 | 'mlp': MlpPolicy}[FLAGS.policy] 104 | actor = PPOActor(env=env, 105 | policy=policy, 106 | unroll_length=FLAGS.unroll_length, 107 | gamma=FLAGS.discount_gamma, 108 | lam=FLAGS.lambda_return, 109 | learner_ip=FLAGS.learner_ip, 110 | port_A=FLAGS.port_A, 111 | port_B=FLAGS.port_B) 112 | actor.run() 113 | env.close() 114 | 115 | 116 | def start_learner(): 117 | tf_config() 118 | env = create_env('1', 0) 119 | policy = {'lstm': LstmPolicy, 120 | 'mlp': MlpPolicy}[FLAGS.policy] 121 | learner = PPOLearner(env=env, 122 | policy=policy, 123 | unroll_length=FLAGS.unroll_length, 124 | lr=FLAGS.learning_rate, 125 | clip_range=FLAGS.clip_range, 126 | batch_size=FLAGS.batch_size, 127 | ent_coef=FLAGS.ent_coef, 128 | vf_coef=FLAGS.vf_coef, 129 | max_grad_norm=0.5, 130 | queue_size=FLAGS.learner_queue_size, 131 | print_interval=FLAGS.print_interval, 132 | save_interval=FLAGS.save_interval, 133 | learn_act_speed_ratio=FLAGS.learn_act_speed_ratio, 134 | save_dir=FLAGS.save_dir, 135 | init_model_path=FLAGS.init_model_path, 136 | port_A=FLAGS.port_A, 137 | port_B=FLAGS.port_B) 138 | learner.run() 139 | env.close() 140 | 141 | 142 | def main(argv): 143 | logging.set_verbosity(logging.ERROR) 144 | print_arguments(FLAGS) 145 | if FLAGS.job_name == 'actor': start_actor() 146 | else: start_learner() 147 | 148 | 149 | if __name__ == '__main__': 150 | app.run(main) 151 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/placer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import math 7 | 8 | import numpy as np 9 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 10 | 11 | import sc2learner.envs.common.utils as utils 12 | from sc2learner.envs.common.const import PLACE_COLLISION_BUILDINGS 13 | 14 | 15 | class Placer(object): 16 | 17 | def 
get_building_position(self, type_id, dc): 18 | if type_id == UNIT_TYPE.ZERG_HATCHERY.value: 19 | return self._next_base_place(dc) 20 | elif type_id == UNIT_TYPE.ZERG_EXTRACTOR.value: 21 | gas = dc.exploitable_gas 22 | return random.choice(gas) if len(gas) > 0 else None 23 | else: 24 | place = self._constructable_place(1.5, dc) 25 | return random.choice(place) if len(place) > 0 else None 26 | 27 | def can_build(self, type_id, dc): 28 | if type_id == UNIT_TYPE.ZERG_HATCHERY.value: 29 | return self._next_base_place(dc) is not None 30 | elif type_id == UNIT_TYPE.ZERG_EXTRACTOR.value: 31 | return len(dc.exploitable_gas) > 0 32 | else: 33 | place = self._constructable_place(1.5, dc) 34 | return len(place) > 0 35 | 36 | def _constructable_place(self, margin, dc): 37 | place = [] 38 | bases = dc.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value, 39 | UNIT_TYPE.ZERG_LAIR.value, 40 | UNIT_TYPE.ZERG_HIVE.value]) 41 | for base in bases: 42 | search_region = (base.float_attr.pos_x - 10.5, 43 | base.float_attr.pos_y - 10.5, 44 | 10.5 * 2, 45 | 10.5 * 2) 46 | place.extend(self._search_place(search_region, dc, margin=margin, 47 | remove_corner=True, expand_mineral=True)) 48 | return place 49 | 50 | def _next_base_place(self, dc): 51 | unexploited_minerals = dc.unexploited_minerals 52 | if len(unexploited_minerals) == 0: return None 53 | mineral_to_exploit = utils.closest_unit(dc.init_base_pos, 54 | unexploited_minerals) 55 | resources_nearby = utils.units_nearby(mineral_to_exploit, 56 | dc.minerals + dc.gas, 57 | max_distance=14) 58 | x_list = [u.float_attr.pos_x for u in resources_nearby] 59 | y_list = [u.float_attr.pos_y for u in resources_nearby] 60 | x_mean = sum(x_list) / len(x_list) 61 | y_mean = sum(y_list) / len(y_list) 62 | left = int(math.floor(min(x_list))) 63 | right = int(math.ceil(max(x_list))) 64 | bottom = int(math.floor(min(y_list))) 65 | top = int(math.ceil(max(y_list))) 66 | width = right - left + 1 67 | height = top - bottom + 1 68 | x_offset, y_offset = 0, 0 69 | if height - width >= 5: 70 | left_mid = (left, (bottom + top) / 2) 71 | right_mid = (right, (bottom + top) / 2) 72 | if utils.closest_distance(left_mid, resources_nearby) > \ 73 | utils.closest_distance(right_mid, resources_nearby): 74 | x_offset = width - height + 1 75 | width = height - 1 76 | elif height - width <= -5: 77 | top_mid = ((left + right) / 2, top) 78 | bottom_mid = ((left + right) / 2, bottom) 79 | if utils.closest_distance(top_mid, resources_nearby) < \ 80 | utils.closest_distance(bottom_mid, resources_nearby): 81 | y_offset = height - width + 1 82 | height = width - 1 83 | region = [left + x_offset, bottom + y_offset, width, height] 84 | place = self._search_place(region, dc, margin=5.5, shrink_mineral=True) 85 | return utils.closest_unit((x_mean, y_mean), place) \ 86 | if len(place) > 0 else None 87 | 88 | def _search_place(self, search_region, dc, margin=0, remove_corner=False, 89 | expand_mineral=False, shrink_mineral=False): 90 | bottomleft = tuple(map(int, search_region[:2])) 91 | size = tuple(map(int, search_region[2:])) 92 | grids = np.zeros(size).astype(np.int) 93 | if remove_corner: 94 | cx, cy = size[0] / 2.0, size[1] / 2.0 95 | r = max(size[0] / 2.0, size[1] / 2.0) 96 | r_sqrt = (r - 0.5) ** 2 97 | for x in range(size[0]): 98 | x_sqrt = (x + 0.5 - cx) ** 2 99 | for y in range(size[1]): 100 | y_sqrt = (y + 0.5 - cy) ** 2 101 | if x_sqrt + y_sqrt > r_sqrt: 102 | grids[x, y] = 1 103 | filter_range = (bottomleft[0] - 10, bottomleft[0] + size[0] + 10, 104 | bottomleft[1] - 10, bottomleft[1] + size[1] 
+ 10) 105 | sitted_units = [u for u in dc.units 106 | if (u.unit_type in PLACE_COLLISION_BUILDINGS and 107 | u.float_attr.pos_x >= filter_range[0] and 108 | u.float_attr.pos_x <= filter_range[1] and 109 | u.float_attr.pos_y >= filter_range[2] and 110 | u.float_attr.pos_y <= filter_range[3])] 111 | sitted_units = [u for u in dc.units 112 | if u.unit_type in PLACE_COLLISION_BUILDINGS] 113 | margin_int = math.floor(margin) 114 | for u in sitted_units: 115 | if u.float_attr.radius <= 1.0: 116 | r_x = r_y = 1.0 117 | if u.float_attr.pos_x % 1 != 0: r_x -= 0.5 118 | if u.float_attr.pos_y % 1 != 0: r_y -= 0.5 119 | else: 120 | r_x = r_y = float(int(u.float_attr.radius)) 121 | if u.float_attr.pos_x % 1 != 0: r_x += 0.5 122 | if u.float_attr.pos_y % 1 != 0: r_y += 0.5 123 | if (shrink_mineral and 124 | u.unit_type in {UNIT_TYPE.NEUTRAL_MINERALFIELD.value, 125 | UNIT_TYPE.NEUTRAL_MINERALFIELD750.value}): 126 | if r_x == 1.5 and r_y == 1: r_x = 0.5 127 | elif r_x == 1 and r_y == 1.5: r_y = 0.5 128 | if (expand_mineral and 129 | u.unit_type in {UNIT_TYPE.NEUTRAL_MINERALFIELD.value, 130 | UNIT_TYPE.NEUTRAL_MINERALFIELD750.value}): 131 | r_x += 1 132 | r_y += 1 133 | r_x += margin_int 134 | r_y += margin_int 135 | xl = int(u.float_attr.pos_x - r_x - bottomleft[0]) 136 | xr = int(u.float_attr.pos_x + r_x - bottomleft[0]) 137 | yu = int(u.float_attr.pos_y - r_y - bottomleft[1]) 138 | yd = int(u.float_attr.pos_y + r_y - bottomleft[1]) 139 | grids[max(xl, 0) : max(min(xr, size[0]), 0), 140 | max(yu, 0) : max(min(yd, size[1]), 0)] = 1 141 | 142 | slopes = [(76.5, 90.5), (73.5, 86.5), (123.5, 52.5), (126.5, 56.5), 143 | (131.5, 36.5), (68.5, 106.5)] 144 | holes = [(124.5, 34.5), (76.5, 108.5), (154.5, 60.5), (45.5, 82.5)] 145 | r = 2.5 146 | for x, y in slopes + holes: 147 | xl = int(x - r - bottomleft[0]) 148 | xr = int(x + r - bottomleft[0]) 149 | yu = int(y - r - bottomleft[1]) 150 | yd = int(y + r - bottomleft[1]) 151 | grids[max(xl, 0) : max(min(xr, size[0]), 0), 152 | max(yu, 0) : max(min(yd, size[1]), 0)] = 1 153 | x, y = np.nonzero(1 - grids) 154 | #if remove_corner == True: 155 | #np.set_printoptions(threshold=np.nan, linewidth=300) 156 | #print(grids) 157 | return list(zip(x + bottomleft[0] + 0.5, y + bottomleft[1] + 0.5)) 158 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/resource.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | from s2clientprotocol import sc2api_pb2 as sc_pb 8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 9 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY 10 | 11 | from sc2learner.envs.actions.function import Function 12 | import sc2learner.envs.common.utils as utils 13 | 14 | 15 | class ResourceActions(object): 16 | 17 | @property 18 | def action_queens_inject_larva(self): 19 | return Function(name="queens_inject_larva", 20 | function=self._all_idle_queens_inject_larva, 21 | is_valid=self._is_valid_all_idle_queens_inject_larva) 22 | 23 | @property 24 | def action_idle_workers_gather_minerals(self): 25 | return Function(name="idle_workers_gather_minerals", 26 | function=self._all_idle_workers_gather_minerals, 27 | is_valid=self._is_valid_all_idle_workers_gather_minerals) 28 | 29 | @property 30 | def action_assign_workers_gather_gas(self): 31 | return Function(name="assign_workers_gather_gas", 32 | 
function=self._assign_workers_gather_gas, 33 | is_valid=self._is_valid_assign_workers_gather_gas) 34 | 35 | @property 36 | def action_assign_workers_gather_minerals(self): 37 | return Function(name="assign_workers_gather_minerals", 38 | function=self._assign_workers_gather_minerals, 39 | is_valid=self._is_valid_assign_workers_gather_minerals) 40 | 41 | def _all_idle_queens_inject_larva(self, dc): 42 | injectable_queens = [ 43 | # TODO: -->idle_units_of_type 44 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_QUEEN.value) 45 | if u.float_attr.energy >= 25 46 | ] 47 | bases = dc.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value, 48 | UNIT_TYPE.ZERG_LAIR.value, 49 | UNIT_TYPE.ZERG_HIVE.value]) 50 | actions = [] 51 | for queen in injectable_queens: 52 | action = sc_pb.Action() 53 | action.action_raw.unit_command.unit_tags.append(queen.tag) 54 | action.action_raw.unit_command.ability_id = \ 55 | ABILITY.EFFECT_INJECTLARVA.value 56 | base = utils.closest_unit(queen, bases) 57 | action.action_raw.unit_command.target_unit_tag = base.tag 58 | actions.append(action) 59 | return actions 60 | 61 | def _is_valid_all_idle_queens_inject_larva(self, dc): 62 | injectable_queens = [ 63 | # TODO: -->idle_units_of_type 64 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_QUEEN.value) 65 | if u.float_attr.energy >= 25 66 | ] 67 | bases = dc.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value, 68 | UNIT_TYPE.ZERG_LAIR.value, 69 | UNIT_TYPE.ZERG_HIVE.value]) 70 | if len(bases) > 0 and len(injectable_queens) > 0: return True 71 | else: return False 72 | 73 | def _all_idle_workers_gather_minerals(self, dc): 74 | idle_workers = dc.idle_units_of_type(UNIT_TYPE.ZERG_DRONE.value) 75 | actions = [] 76 | for worker in idle_workers: 77 | mineral = utils.closest_unit(worker, dc.minerals) 78 | action = sc_pb.Action() 79 | action.action_raw.unit_command.unit_tags.append(worker.tag) 80 | action.action_raw.unit_command.ability_id = \ 81 | ABILITY.HARVEST_GATHER_DRONE.value 82 | action.action_raw.unit_command.target_unit_tag = mineral.tag 83 | actions.append(action) 84 | return actions 85 | 86 | def _is_valid_all_idle_workers_gather_minerals(self, dc): 87 | if (len(dc.idle_units_of_type(UNIT_TYPE.ZERG_DRONE.value)) > 0 and 88 | len(dc.minerals) > 0): 89 | return True 90 | else: 91 | return False 92 | 93 | def _assign_workers_gather_gas(self, dc): 94 | idle_extractors = [ 95 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value) 96 | if u.int_attr.ideal_harvesters - u.int_attr.assigned_harvesters > 0 97 | ] 98 | if len(idle_extractors) == 0: return [] 99 | extractor = random.choice(idle_extractors) 100 | num_workers_need = extractor.int_attr.ideal_harvesters - \ 101 | extractor.int_attr.assigned_harvesters 102 | extractor_tags = set(u.tag for u in dc.units_of_type( 103 | UNIT_TYPE.ZERG_EXTRACTOR.value)) 104 | workers = [ 105 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value) 106 | if (len(u.orders) == 0 or 107 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and 108 | u.orders[0].target_tag not in extractor_tags)) 109 | ] 110 | if len(workers) == 0: return [] 111 | assigned_workers = utils.closest_units(extractor, workers, num_workers_need) 112 | action = sc_pb.Action() 113 | action.action_raw.unit_command.unit_tags.extend( 114 | [u.tag for u in assigned_workers]) 115 | action.action_raw.unit_command.ability_id = \ 116 | ABILITY.HARVEST_GATHER_DRONE.value 117 | action.action_raw.unit_command.target_unit_tag = extractor.tag 118 | return [action] 119 | 120 | def _is_valid_assign_workers_gather_gas(self, 
dc): 121 | idle_extractors = [ 122 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value) 123 | if u.int_attr.ideal_harvesters - u.int_attr.assigned_harvesters > 0 124 | ] 125 | extractor_tags = set(u.tag for u in dc.units_of_type( 126 | UNIT_TYPE.ZERG_EXTRACTOR.value)) 127 | workers = [ 128 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value) 129 | if (len(u.orders) == 0 or 130 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and 131 | u.orders[0].target_tag not in extractor_tags)) 132 | ] 133 | if len(idle_extractors) > 0 and len(workers) > 0: return True 134 | else: return False 135 | 136 | def _assign_workers_gather_minerals(self, dc): 137 | extractor_tags = set(u.tag for u in dc.units_of_type( 138 | UNIT_TYPE.ZERG_EXTRACTOR.value)) 139 | workers = [ 140 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value) 141 | if (len(u.orders) == 0 or 142 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and 143 | u.orders[0].target_tag in extractor_tags)) 144 | ] 145 | actions = [] 146 | for worker in random.sample(workers, min(3, len(workers))): 147 | target_mineral = utils.closest_unit(worker, dc.minerals) 148 | action = sc_pb.Action() 149 | action.action_raw.unit_command.unit_tags.append(worker.tag) 150 | action.action_raw.unit_command.ability_id = \ 151 | ABILITY.HARVEST_GATHER_DRONE.value 152 | action.action_raw.unit_command.target_unit_tag = target_mineral.tag 153 | actions.append(action) 154 | return actions 155 | 156 | def _is_valid_assign_workers_gather_minerals(self, dc): 157 | extractor_tags = set(u.tag for u in dc.units_of_type( 158 | UNIT_TYPE.ZERG_EXTRACTOR.value)) 159 | workers = [ 160 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value) 161 | if (len(u.orders) == 0 or 162 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and 163 | u.orders[0].target_tag in extractor_tags)) 164 | ] 165 | return len(workers) > 0 166 | -------------------------------------------------------------------------------- /sc2learner/envs/observations/nonspatial_features.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 7 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY 8 | 9 | from sc2learner.envs.common.const import ALLY_TYPE 10 | from sc2learner.envs.common.const import COMBAT_TYPES 11 | 12 | 13 | class PlayerFeature(object): 14 | 15 | def features(self, observation): 16 | player_features = observation["player"][1:-1].astype(np.float32) 17 | food_unused = player_features[3] - player_features[2] 18 | player_features[-1] = food_unused if food_unused >= 0 else 0 19 | scale = np.array([2000, 2000, 20, 20, 20, 20, 20, 20, 20], np.float32) 20 | scaled_features = (player_features / scale).astype(np.float32) 21 | log_features = np.log10(player_features + 1).astype(np.float32) 22 | 23 | bins_food_unused = np.zeros(10, dtype=np.float32) 24 | bin_id = int((max(food_unused, 0) - 1) // 3 + 1) if food_unused <= 27 else 9 25 | bins_food_unused[bin_id] = 1 26 | return np.concatenate((scaled_features, log_features, bins_food_unused)) 27 | 28 | @property 29 | def num_dims(self): 30 | return 9 * 2 + 10 31 | 32 | 33 | class ScoreFeature(object): 34 | 35 | def features(self, observation): 36 | score_features = observation.score_cumulative[3:].astype(np.float32) 37 | score_features /= 3000.0 38 | log_features = 
np.log10(score_features + 1).astype(np.float32) 39 | return np.concatenate((score_features, log_features)) 40 | 41 | @property 42 | def num_dims(self): 43 | return 10 * 2 44 | 45 | 46 | class UnitTypeCountFeature(object): 47 | 48 | def __init__(self, type_list, use_regions=False): 49 | self._type_list = type_list 50 | if use_regions: 51 | self._regions = [(0, 0, 200, 176), 52 | (0, 88, 80, 176), 53 | (80, 88, 120, 176), 54 | (120, 88, 200, 176), 55 | (0, 55, 80, 88), 56 | (80, 55, 120, 88), 57 | (120, 55, 200, 88), 58 | (0, 0, 80, 55), 59 | (80, 0, 120, 55), 60 | (120, 0, 200, 55)] 61 | else: 62 | self._regions = [(0, 0, 200, 176)] 63 | self._regions_flipped = [self._regions[0]] + [ 64 | self._regions[10 - i] for i in range(1, len(self._regions))] 65 | 66 | def features(self, observation, need_flip=False): 67 | feature_list = [] 68 | for region in (self._regions if not need_flip else self._regions_flipped): 69 | units_in_region = [u for u in observation['units'] 70 | if self._is_in_region(u, region)] 71 | feature_list.append(self._generate_features(units_in_region)) 72 | return np.concatenate(feature_list) 73 | 74 | @property 75 | def num_dims(self): 76 | return len(self._type_list) * len(self._regions) * 2 * 2 77 | 78 | def _generate_features(self, units): 79 | self_units = [u for u in units 80 | if u.int_attr.alliance == ALLY_TYPE.SELF.value] 81 | enemy_units = [u for u in units 82 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value] 83 | self_features = self._get_counts(self_units) 84 | enemy_features = self._get_counts(enemy_units) 85 | features = np.concatenate((self_features, enemy_features)) 86 | 87 | scaled_features = features / 20 88 | log_features = np.log10(features + 1) 89 | 90 | return np.concatenate((scaled_features, log_features)) 91 | 92 | def _get_counts(self, units): 93 | count = {t: 0 for t in self._type_list} 94 | for u in units: 95 | if u.unit_type in count: 96 | count[u.unit_type] += 1 97 | return np.array([count[t] for t in self._type_list], dtype=np.float32) 98 | 99 | def _is_in_region(self, unit, region): 100 | return (unit.float_attr.pos_x >= region[0] and 101 | unit.float_attr.pos_x < region[2] and 102 | unit.float_attr.pos_y >= region[1] and 103 | unit.float_attr.pos_y < region[3]) 104 | 105 | 106 | class UnitStatCountFeature(object): 107 | 108 | def __init__(self, use_regions=False): 109 | if use_regions: 110 | self._regions = [(0, 0, 200, 176), 111 | (0, 88, 80, 176), 112 | (80, 88, 120, 176), 113 | (120, 88, 200, 176), 114 | (0, 55, 80, 88), 115 | (80, 55, 120, 88), 116 | (120, 55, 200, 88), 117 | (0, 0, 80, 55), 118 | (80, 0, 120, 55), 119 | (120, 0, 200, 55)] 120 | else: 121 | self._regions = [(0, 0, 200, 176)] 122 | self._regions_flipped = [self._regions[0]] + [ 123 | self._regions[10 - i] for i in range(1, len(self._regions))] 124 | 125 | def features(self, observation, need_flip=False): 126 | feature_list = [] 127 | for region in (self._regions if not need_flip else self._regions_flipped): 128 | units_in_region = [u for u in observation['units'] 129 | if self._is_in_region(u, region)] 130 | feature_list.append(self._generate_features(units_in_region)) 131 | return np.concatenate(feature_list) 132 | 133 | @property 134 | def num_dims(self): 135 | return len(self._regions) * 2 * 4 * 2 136 | 137 | def _generate_features(self, units): 138 | self_units = [u for u in units 139 | if u.int_attr.alliance == ALLY_TYPE.SELF.value] 140 | enemy_units = [u for u in units 141 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value] 142 | self_combats = [u for u in self_units 
if u.unit_type in COMBAT_TYPES] 143 | enemy_combats = [u for u in enemy_units if u.unit_type in COMBAT_TYPES] 144 | self_air_units = [u for u in self_units if u.bool_attr.is_flying] 145 | enemy_air_units = [u for u in enemy_units if u.bool_attr.is_flying] 146 | self_ground_units = [u for u in self_units if not u.bool_attr.is_flying] 147 | enemy_ground_units = [u for u in enemy_units if not u.bool_attr.is_flying] 148 | 149 | features = np.array([len(self_units), 150 | len(self_combats), 151 | len(self_ground_units), 152 | len(self_air_units), 153 | len(enemy_units), 154 | len(enemy_combats), 155 | len(enemy_ground_units), 156 | len(enemy_air_units)], dtype=np.float32) 157 | 158 | scaled_features = features / 20 159 | log_features = np.log10(features + 1) 160 | return np.concatenate((scaled_features, log_features)) 161 | 162 | def _is_in_region(self, unit, region): 163 | return (unit.float_attr.pos_x >= region[0] and 164 | unit.float_attr.pos_x < region[2] and 165 | unit.float_attr.pos_y >= region[1] and 166 | unit.float_attr.pos_y < region[3]) 167 | 168 | 169 | class GameProgressFeature(object): 170 | 171 | def features(self, observation): 172 | game_loop = observation["game_loop"][0] 173 | features_60 = self._onehot(game_loop, 60) 174 | features_20 = self._onehot(game_loop, 20) 175 | features_8 = self._onehot(game_loop, 8) 176 | features_4 = self._onehot(game_loop, 4) 177 | 178 | return np.concatenate([features_60, 179 | features_20, 180 | features_8, 181 | features_4]) 182 | 183 | def _onehot(self, value, n_bins): 184 | bin_width = 24000 // n_bins 185 | features = np.zeros(n_bins, dtype=np.float32) 186 | idx = int(value // bin_width) 187 | idx = n_bins - 1 if idx >= n_bins else idx 188 | features[idx] = 1.0 189 | return features 190 | 191 | @property 192 | def num_dims(self): 193 | return 60 + 20 + 8 + 4 194 | 195 | 196 | class ActionSeqFeature(object): 197 | 198 | def __init__(self, n_dims_action_space, seq_len): 199 | self._action_seq = [-1] * seq_len 200 | self._n_dims_action_space = n_dims_action_space 201 | 202 | def reset(self): 203 | self._action_seq = [-1] * len(self._action_seq) 204 | 205 | def push_action(self, action): 206 | self._action_seq.pop(0) 207 | self._action_seq.append(action) 208 | 209 | def features(self): 210 | features = np.zeros(self._n_dims_action_space * len(self._action_seq), 211 | dtype=np.float32) 212 | for i, action in enumerate(self._action_seq): 213 | assert action < self._n_dims_action_space 214 | if action >= 0: 215 | features[i * self._n_dims_action_space + action] = 1.0 216 | return features 217 | 218 | @property 219 | def num_dims(self): 220 | return self._n_dims_action_space * len(self._action_seq) 221 | 222 | 223 | class WorkerFeature(object): 224 | 225 | def features(self, dc): 226 | extractor_tags = set(u.tag for u in dc.units_of_type( 227 | UNIT_TYPE.ZERG_EXTRACTOR.value)) 228 | workers = dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value) 229 | harvest_workers = [ 230 | u for u in workers 231 | if (len(u.orders) > 0 and 232 | u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value) 233 | ] 234 | gas_workers = [u for u in harvest_workers 235 | if u.orders[0].target_tag in extractor_tags] 236 | mineral_workers = [u for u in harvest_workers 237 | if u.orders[0].target_tag not in extractor_tags] 238 | return np.array([len(gas_workers), 239 | len(mineral_workers), 240 | len(workers) - len(gas_workers) - len(mineral_workers)], 241 | dtype=np.float32) / 20.0 242 | 243 | @property 244 | def num_dims(self): 245 | return 3 246 | 
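# Illustrative check (not part of the original module): GameProgressFeature's
# _onehot above puts game_loop into fixed-width bins over a 24000-game-loop
# horizon and clamps everything beyond the horizon into the last bin.
def _demo_game_progress_onehot():
  import numpy as np
  n_bins, horizon = 8, 24000
  bin_width = horizon // n_bins                    # 3000 game loops per bin
  for game_loop, expected_bin in [(0, 0), (6000, 2), (30000, n_bins - 1)]:
    idx = min(int(game_loop // bin_width), n_bins - 1)
    features = np.zeros(n_bins, dtype=np.float32)
    features[idx] = 1.0
    assert features[expected_bin] == 1.0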
-------------------------------------------------------------------------------- /sc2learner/bin/train_ppo_selfplay.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | from threading import Thread 7 | import os 8 | import multiprocessing 9 | import random 10 | import time 11 | 12 | from absl import app 13 | from absl import flags 14 | from absl import logging 15 | import tensorflow as tf 16 | 17 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy 18 | from sc2learner.agents.ppo_agent import PPOActor, PPOLearner, PPOSelfplayActor 19 | from sc2learner.envs.raw_env import SC2RawEnv 20 | from sc2learner.envs.selfplay_raw_env import SC2SelfplayRawEnv 21 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper 22 | from sc2learner.envs.actions.zerg_action_wrappers import ZergPlayerActionWrapper 23 | from sc2learner.envs.observations.zerg_observation_wrappers \ 24 | import ZergObservationWrapper 25 | from sc2learner.envs.observations.zerg_observation_wrappers \ 26 | import ZergPlayerObservationWrapper 27 | from sc2learner.utils.utils import print_arguments 28 | 29 | 30 | FLAGS = flags.FLAGS 31 | flags.DEFINE_enum("job_name", 'actor', ['actor', 'learner', 'eval', 'eval_model'], 32 | "Job type.") 33 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Job type.") 34 | flags.DEFINE_integer("unroll_length", 128, "Length of rollout steps.") 35 | flags.DEFINE_integer("model_cache_size", 300, "Opponent model cache size.") 36 | flags.DEFINE_float("model_cache_prob", 0.05, "Opponent model cache probability.") 37 | flags.DEFINE_string("learner_ip", "localhost", "Learner IP address.") 38 | flags.DEFINE_string("port_A", "5700", "Port for transporting model.") 39 | flags.DEFINE_string("port_B", "5701", "Port for transporting data.") 40 | flags.DEFINE_string("game_version", '4.6', "Game core version.") 41 | flags.DEFINE_float("discount_gamma", 0.998, "Discount factor.") 42 | flags.DEFINE_float("lambda_return", 0.95, "Lambda return factor.") 43 | flags.DEFINE_float("clip_range", 0.1, "Clip range for PPO.") 44 | flags.DEFINE_float("ent_coef", 0.01, "Coefficient for the entropy term.") 45 | flags.DEFINE_float("vf_coef", 0.5, "Coefficient for the value loss.") 46 | flags.DEFINE_float("learn_act_speed_ratio", 0, "Maximum learner/actor ratio.") 47 | flags.DEFINE_integer("game_steps_per_episode", 43200, "Maximum steps per episode.") 48 | flags.DEFINE_integer("batch_size", 32, "Batch size.") 49 | flags.DEFINE_integer("learner_queue_size", 1024, "Size of learner's unroll queue.") 50 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.") 51 | flags.DEFINE_string("difficulties", '1,2,4,6,9,A', "Bot's strengths.") 52 | flags.DEFINE_float("learning_rate", 1e-5, "Learning rate.") 53 | flags.DEFINE_string("init_model_path", None, "Initial model path.") 54 | flags.DEFINE_string("init_oppo_pool_filelist", None, "Initial opponent model path.") 55 | flags.DEFINE_string("save_dir", "./checkpoints/", "Dir to save models to") 56 | flags.DEFINE_integer("save_interval", 50000, "Model saving frequency.") 57 | flags.DEFINE_integer("print_interval", 1000, "Print train cost frequency.") 58 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.") 59 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.") 60 | flags.DEFINE_boolean("use_region_features", False, "Use region 
features") 61 | flags.DEFINE_boolean("use_action_mask", True, "Use region-wise combat.") 62 | flags.FLAGS(sys.argv) 63 | 64 | 65 | def tf_config(ncpu=None): 66 | if ncpu is None: 67 | ncpu = multiprocessing.cpu_count() 68 | if sys.platform == 'darwin': ncpu //= 2 69 | config = tf.ConfigProto(allow_soft_placement=True, 70 | intra_op_parallelism_threads=ncpu, 71 | inter_op_parallelism_threads=ncpu) 72 | config.gpu_options.allow_growth = True 73 | tf.Session(config=config).__enter__() 74 | 75 | 76 | def create_env(difficulty, random_seed=None): 77 | env = SC2RawEnv(map_name='AbyssalReef', 78 | step_mul=FLAGS.step_mul, 79 | resolution=16, 80 | agent_race='zerg', 81 | bot_race='zerg', 82 | difficulty=difficulty, 83 | disable_fog=FLAGS.disable_fog, 84 | tie_to_lose=False, 85 | game_steps_per_episode=FLAGS.game_steps_per_episode, 86 | random_seed=random_seed) 87 | env = ZergActionWrapper(env, 88 | game_version=FLAGS.game_version, 89 | mask=FLAGS.use_action_mask, 90 | use_all_combat_actions=FLAGS.use_all_combat_actions) 91 | env = ZergObservationWrapper(env, 92 | use_spatial_features=False, 93 | use_game_progress=(not FLAGS.policy == 'lstm'), 94 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8, 95 | use_regions=FLAGS.use_region_features) 96 | print(env.observation_space, env.action_space) 97 | return env 98 | 99 | 100 | def create_selfplay_env(random_seed=None): 101 | env = SC2SelfplayRawEnv(map_name='AbyssalReef', 102 | step_mul=FLAGS.step_mul, 103 | resolution=16, 104 | agent_race='zerg', 105 | opponent_race='zerg', 106 | tie_to_lose=False, 107 | disable_fog=FLAGS.disable_fog, 108 | game_steps_per_episode=FLAGS.game_steps_per_episode, 109 | random_seed=random_seed) 110 | env = ZergPlayerActionWrapper( 111 | player=0, 112 | env=env, 113 | game_version=FLAGS.game_version, 114 | mask=FLAGS.use_action_mask, 115 | use_all_combat_actions=FLAGS.use_all_combat_actions) 116 | env = ZergPlayerObservationWrapper( 117 | player=0, 118 | env=env, 119 | use_spatial_features=False, 120 | use_game_progress=(not FLAGS.policy == 'lstm'), 121 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8, 122 | use_regions=FLAGS.use_region_features) 123 | 124 | env = ZergPlayerActionWrapper( 125 | player=1, 126 | env=env, 127 | game_version=FLAGS.game_version, 128 | mask=FLAGS.use_action_mask, 129 | use_all_combat_actions=FLAGS.use_all_combat_actions) 130 | env = ZergPlayerObservationWrapper( 131 | player=1, 132 | env=env, 133 | use_spatial_features=False, 134 | use_game_progress=(not FLAGS.policy == 'lstm'), 135 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8, 136 | use_regions=FLAGS.use_region_features) 137 | print(env.observation_space, env.action_space) 138 | return env 139 | 140 | 141 | def start_actor(): 142 | tf_config(ncpu=2) 143 | random.seed(time.time()) 144 | game_seed = random.randint(0, 2**32 - 1) 145 | print("Game Seed: %d" % game_seed) 146 | env = create_selfplay_env(game_seed) 147 | policy = {'lstm': LstmPolicy, 148 | 'mlp': MlpPolicy}[FLAGS.policy] 149 | actor = PPOSelfplayActor( 150 | env=env, 151 | policy=policy, 152 | unroll_length=FLAGS.unroll_length, 153 | gamma=FLAGS.discount_gamma, 154 | lam=FLAGS.lambda_return, 155 | model_cache_size=FLAGS.model_cache_size, 156 | model_cache_prob=FLAGS.model_cache_prob, 157 | prob_latest_opponent=0.0, 158 | init_opponent_pool_filelist=FLAGS.init_oppo_pool_filelist, 159 | freeze_opponent_pool=False, 160 | learner_ip=FLAGS.learner_ip, 161 | port_A=FLAGS.port_A, 162 | port_B=FLAGS.port_B) 163 | actor.run() 164 | env.close() 165 | 166 | 167 | def start_learner(): 
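  # Learner job: the env created below chiefly supplies the observation/action
  # spaces needed to build the policy; training data itself comes from remote
  # actor processes, which push unrolls to port_B and pull the latest model
  # parameters over port_A (cf. PPOActor/PPOLearner in sc2learner/agents/ppo_agent.py).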
168 | tf_config() 169 | env = create_env('1', 0) 170 | policy = {'lstm': LstmPolicy, 171 | 'mlp': MlpPolicy}[FLAGS.policy] 172 | learner = PPOLearner(env=env, 173 | policy=policy, 174 | unroll_length=FLAGS.unroll_length, 175 | lr=FLAGS.learning_rate, 176 | clip_range=FLAGS.clip_range, 177 | batch_size=FLAGS.batch_size, 178 | ent_coef=FLAGS.ent_coef, 179 | vf_coef=FLAGS.vf_coef, 180 | max_grad_norm=0.5, 181 | queue_size=FLAGS.learner_queue_size, 182 | print_interval=FLAGS.print_interval, 183 | save_interval=FLAGS.save_interval, 184 | learn_act_speed_ratio=FLAGS.learn_act_speed_ratio, 185 | save_dir=FLAGS.save_dir, 186 | init_model_path=FLAGS.init_model_path, 187 | port_A=FLAGS.port_A, 188 | port_B=FLAGS.port_B) 189 | learner.run() 190 | env.close() 191 | 192 | 193 | def start_evaluator_against_builtin(): 194 | tf_config(ncpu=2) 195 | random.seed(time.time()) 196 | difficulty = random.choice(FLAGS.difficulties.split(',')) 197 | game_seed = random.randint(0, 2**32 - 1) 198 | print("Game Seed: %d" % game_seed) 199 | env = create_env(difficulty, game_seed) 200 | policy = {'lstm': LstmPolicy, 201 | 'mlp': MlpPolicy}[FLAGS.policy] 202 | actor = PPOActor(env=env, 203 | policy=policy, 204 | unroll_length=FLAGS.unroll_length, 205 | gamma=FLAGS.discount_gamma, 206 | lam=FLAGS.lambda_return, 207 | enable_push=False, 208 | learner_ip=FLAGS.learner_ip, 209 | port_A=FLAGS.port_A, 210 | port_B=FLAGS.port_B) 211 | actor.run() 212 | env.close() 213 | 214 | 215 | def start_evaluator_against_model(): 216 | tf_config(ncpu=2) 217 | random.seed(time.time()) 218 | game_seed = random.randint(0, 2**32 - 1) 219 | print("Game Seed: %d" % game_seed) 220 | env = create_selfplay_env(game_seed) 221 | policy = {'lstm': LstmPolicy, 222 | 'mlp': MlpPolicy}[FLAGS.policy] 223 | actor = PPOSelfplayActor( 224 | env=env, 225 | policy=policy, 226 | unroll_length=FLAGS.unroll_length, 227 | gamma=FLAGS.discount_gamma, 228 | lam=FLAGS.lambda_return, 229 | model_cache_size=1, 230 | model_cache_prob=FLAGS.model_cache_prob, 231 | enable_push=False, 232 | prob_latest_opponent=0.0, 233 | init_opponent_pool_filelist=FLAGS.init_oppo_pool_filelist, 234 | freeze_opponent_pool=True, 235 | learner_ip=FLAGS.learner_ip, 236 | port_A=FLAGS.port_A, 237 | port_B=FLAGS.port_B) 238 | actor.run() 239 | env.close() 240 | 241 | 242 | def main(argv): 243 | logging.set_verbosity(logging.ERROR) 244 | print_arguments(FLAGS) 245 | if FLAGS.job_name == 'actor': start_actor() 246 | elif FLAGS.job_name == 'learner': start_learner() 247 | elif FLAGS.job_name == 'eval': start_evaluator_against_builtin() 248 | else: start_evaluator_against_model() 249 | 250 | 251 | if __name__ == '__main__': 252 | app.run(main) 253 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/zerg_action_wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import platform 6 | 7 | import numpy as np 8 | import gym 9 | from gym.spaces.discrete import Discrete 10 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 11 | from pysc2.lib.typeenums import UPGRADE_ID as UPGRADE 12 | from pysc2.lib import point 13 | from s2clientprotocol import sc2api_pb2 as sc_pb 14 | 15 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation 16 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete 17 | from sc2learner.envs.common.data_context import DataContext 18 | 
from sc2learner.envs.actions.function import Function 19 | from sc2learner.envs.actions.produce import ProduceActions 20 | from sc2learner.envs.actions.build import BuildActions 21 | from sc2learner.envs.actions.upgrade import UpgradeActions 22 | from sc2learner.envs.actions.resource import ResourceActions 23 | from sc2learner.envs.actions.combat import CombatActions 24 | 25 | 26 | class ZergActionWrapper(gym.Wrapper): 27 | 28 | def __init__(self, env, game_version='4.1.2', mask=False, 29 | use_all_combat_actions=False): 30 | super(ZergActionWrapper, self).__init__(env) 31 | # TODO: multiple observation space 32 | #assert isinstance(env.observation_space, PySC2RawObservation) 33 | 34 | self._dc = DataContext() 35 | self._build_mgr = BuildActions(game_version) 36 | self._produce_mgr = ProduceActions(game_version) 37 | self._upgrade_mgr = UpgradeActions(game_version) 38 | self._resource_mgr = ResourceActions() 39 | self._combat_mgr = CombatActions() 40 | 41 | self._actions = [ 42 | self._action_do_nothing(), 43 | self._build_mgr.action("build_extractor", UNIT_TYPE.ZERG_EXTRACTOR.value), 44 | self._build_mgr.action("build_spawning_pool", UNIT_TYPE.ZERG_SPAWNINGPOOL.value), 45 | self._build_mgr.action("build_roach_warren", UNIT_TYPE.ZERG_ROACHWARREN.value), 46 | self._build_mgr.action("build_hydraliskden", UNIT_TYPE.ZERG_HYDRALISKDEN.value), 47 | self._build_mgr.action("build_hatchery", UNIT_TYPE.ZERG_HATCHERY.value), 48 | self._build_mgr.action("build_evolution_chamber", UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value), 49 | self._build_mgr.action("build_baneling_nest", UNIT_TYPE.ZERG_BANELINGNEST.value), 50 | self._build_mgr.action("build_infestation_pit", UNIT_TYPE.ZERG_INFESTATIONPIT.value), 51 | self._build_mgr.action("build_spire", UNIT_TYPE.ZERG_SPIRE.value), 52 | self._build_mgr.action("build_ultralisk_cavern", UNIT_TYPE.ZERG_ULTRALISKCAVERN.value), 53 | #self._build_mgr.action("build_nydus_network", UNIT_TYPE.ZERG_NYDUSNETWORK.value), 54 | self._build_mgr.action("build_spine_crawler", UNIT_TYPE.ZERG_SPINECRAWLER.value), 55 | self._build_mgr.action("build_spore_crawler", UNIT_TYPE.ZERG_SPORECRAWLER.value), 56 | self._produce_mgr.action("morph_lurker_den", UNIT_TYPE.ZERG_LURKERDENMP.value), 57 | self._produce_mgr.action("morph_lair", UNIT_TYPE.ZERG_LAIR.value), 58 | self._produce_mgr.action("morph_hive", UNIT_TYPE.ZERG_HIVE.value), 59 | self._produce_mgr.action("morph_greater_spire", UNIT_TYPE.ZERG_GREATERSPIRE.value), 60 | self._produce_mgr.action("produce_drone", UNIT_TYPE.ZERG_DRONE.value), 61 | self._produce_mgr.action("produce_zergling", UNIT_TYPE.ZERG_ZERGLING.value), 62 | self._produce_mgr.action("morph_baneling", UNIT_TYPE.ZERG_BANELING.value), 63 | self._produce_mgr.action("produce_roach", UNIT_TYPE.ZERG_ROACH.value), 64 | self._produce_mgr.action("morph_ravager", UNIT_TYPE.ZERG_RAVAGER.value), 65 | self._produce_mgr.action("produce_hydralisk", UNIT_TYPE.ZERG_HYDRALISK.value), 66 | self._produce_mgr.action("morph_lurker", UNIT_TYPE.ZERG_LURKERMP.value), 67 | #self._produce_mgr.action("produce_viper", UNIT_TYPE.ZERG_VIPER.value), 68 | self._produce_mgr.action("produce_mutalisk", UNIT_TYPE.ZERG_MUTALISK.value), 69 | self._produce_mgr.action("produce_corruptor", UNIT_TYPE.ZERG_CORRUPTOR.value), 70 | self._produce_mgr.action("morph_broodlord", UNIT_TYPE.ZERG_BROODLORD.value), 71 | #self._produce_mgr.action("produce_swarmhost", UNIT_TYPE.ZERG_SWARMHOSTMP.value), 72 | #self._produce_mgr.action("produce_infestor", UNIT_TYPE.ZERG_INFESTOR.value), 73 | 
self._produce_mgr.action("produce_ultralisk", UNIT_TYPE.ZERG_ULTRALISK.value), 74 | self._produce_mgr.action("produce_overlord", UNIT_TYPE.ZERG_OVERLORD.value), 75 | self._produce_mgr.action("morph_overseer", UNIT_TYPE.ZERG_OVERSEER.value), 76 | self._produce_mgr.action("produce_queen", UNIT_TYPE.ZERG_QUEEN.value), 77 | #self._produce_mgr.action("produce_nydus_worm", UNIT_TYPE.ZERG_NYDUSCANAL.value), 78 | self._upgrade_mgr.action("upgrade_burrow", UPGRADE.BURROW.value), 79 | self._upgrade_mgr.action("upgrade_centrifical_hooks", UPGRADE.CENTRIFICALHOOKS.value), 80 | self._upgrade_mgr.action("upgrade_chitions_plating", UPGRADE.CHITINOUSPLATING.value), 81 | self._upgrade_mgr.action("upgrade_evolve_grooved_spines", UPGRADE.EVOLVEGROOVEDSPINES.value), 82 | self._upgrade_mgr.action("upgrade_evolve_muscular_augments", UPGRADE.EVOLVEMUSCULARAUGMENTS.value), 83 | self._upgrade_mgr.action("upgrade_gliare_constitution", UPGRADE.GLIALRECONSTITUTION.value), 84 | #self._upgrade_mgr.action("upgrade_infestor_energy_upgrade", UPGRADE.INFESTORENERGYUPGRADE.value), 85 | self._upgrade_mgr.action("upgrade_neural_parasite", UPGRADE.NEURALPARASITE.value), 86 | self._upgrade_mgr.action("upgrade_overlord_speed", UPGRADE.OVERLORDSPEED.value), 87 | self._upgrade_mgr.action("upgrade_tunneling_claws", UPGRADE.TUNNELINGCLAWS.value), 88 | self._upgrade_mgr.action("upgrade_flyer_armors_level1", UPGRADE.ZERGFLYERARMORSLEVEL1.value), 89 | self._upgrade_mgr.action("upgrade_flyer_armors_level2", UPGRADE.ZERGFLYERARMORSLEVEL2.value), 90 | self._upgrade_mgr.action("upgrade_flyer_armors_level3", UPGRADE.ZERGFLYERARMORSLEVEL3.value), 91 | self._upgrade_mgr.action("upgrade_flyer_weapons_level1", UPGRADE.ZERGFLYERWEAPONSLEVEL1.value), 92 | self._upgrade_mgr.action("upgrade_flyer_weapons_level2", UPGRADE.ZERGFLYERWEAPONSLEVEL2.value), 93 | self._upgrade_mgr.action("upgrade_flyer_weapons_level3", UPGRADE.ZERGFLYERWEAPONSLEVEL3.value), 94 | self._upgrade_mgr.action("upgrade_ground_armors_level1", UPGRADE.ZERGGROUNDARMORSLEVEL1.value), 95 | self._upgrade_mgr.action("upgrade_ground_armors_level2", UPGRADE.ZERGGROUNDARMORSLEVEL2.value), 96 | self._upgrade_mgr.action("upgrade_ground_armors_level3", UPGRADE.ZERGGROUNDARMORSLEVEL3.value), 97 | self._upgrade_mgr.action("upgrade_zergling_attack_speed", UPGRADE.ZERGLINGATTACKSPEED.value), 98 | self._upgrade_mgr.action("upgrade_zergling_moving_speed", UPGRADE.ZERGLINGMOVEMENTSPEED.value), 99 | self._upgrade_mgr.action("upgrade_melee_weapons_level1", UPGRADE.ZERGMELEEWEAPONSLEVEL1.value), 100 | self._upgrade_mgr.action("upgrade_melee_weapons_level2", UPGRADE.ZERGMELEEWEAPONSLEVEL2.value), 101 | self._upgrade_mgr.action("upgrade_melee_weapons_level3", UPGRADE.ZERGMELEEWEAPONSLEVEL3.value), 102 | self._upgrade_mgr.action("upgrade_missile_weapons_level1", UPGRADE.ZERGMISSILEWEAPONSLEVEL1.value), 103 | self._upgrade_mgr.action("upgrade_missile_weapons_level2", UPGRADE.ZERGMISSILEWEAPONSLEVEL2.value), 104 | self._upgrade_mgr.action("upgrade_missile_weapons_level3", UPGRADE.ZERGMISSILEWEAPONSLEVEL3.value), 105 | self._resource_mgr.action_assign_workers_gather_gas, 106 | self._resource_mgr.action_assign_workers_gather_minerals, 107 | # ZERG_LOCUST, ZERG_CHANGELING not included 108 | ] + ([ 109 | self._combat_mgr.action(0, 0), 110 | self._combat_mgr.action(9, 4), 111 | self._combat_mgr.action(4, 1) 112 | ] if not use_all_combat_actions else [ 113 | self._combat_mgr.action(0, target_region_id) 114 | for target_region_id in range(self._combat_mgr.num_regions) 115 | 
#self._combat_mgr.action(source_region_id, target_region_id) 116 | #for source_region_id in range(self._combat_mgr.num_regions) 117 | #for target_region_id in range(self._combat_mgr.num_regions) 118 | ]) 119 | 120 | self._required_pre_actions = [ 121 | self._resource_mgr.action_idle_workers_gather_minerals, 122 | self._resource_mgr.action_queens_inject_larva 123 | ] 124 | self._required_post_actions = [ 125 | self._combat_mgr.action_rally_new_combat_units, 126 | self._combat_mgr.action_framewise_rally_and_attack 127 | ] 128 | 129 | if mask: self.action_space = MaskDiscrete(len(self._actions)) 130 | else: self.action_space = Discrete(len(self._actions)) 131 | 132 | def step(self, action): 133 | actions = self._actions[action].function(self._dc) 134 | pre_actions, post_actions = self._required_actions() 135 | observation, reward, done, info = self.env.step( 136 | pre_actions + actions + post_actions) 137 | self._dc.update(observation) 138 | if isinstance(self.action_space, MaskDiscrete): 139 | observation['action_mask'] = self._get_valid_action_mask() 140 | return observation, reward, done, info 141 | 142 | def reset(self, **kwargs): 143 | self._combat_mgr.reset() 144 | observation = self.env.reset() 145 | self._dc.reset(observation) 146 | if isinstance(self.action_space, MaskDiscrete): 147 | observation['action_mask'] = self._get_valid_action_mask() 148 | return observation 149 | 150 | @property 151 | def action_names(self): 152 | return [action.name for action in self._actions] 153 | 154 | @property 155 | def player_position(self): 156 | if self._dc.init_base_pos[0] < 100: return 0 157 | else: return 1 158 | 159 | def _required_actions(self): 160 | pre_actions = [] 161 | for fn in self._required_pre_actions: 162 | if fn.is_valid(self._dc): 163 | pre_actions.extend(fn.function(self._dc)) 164 | 165 | post_actions = [] 166 | for fn in self._required_post_actions: 167 | if fn.is_valid(self._dc): 168 | post_actions.extend(fn.function(self._dc)) 169 | 170 | return pre_actions, post_actions 171 | 172 | def _get_valid_action_mask(self): 173 | ids = [i for i, action in enumerate(self._actions) 174 | if action.is_valid(self._dc)] 175 | mask = np.zeros(self.action_space.n) 176 | mask[ids] = 1 177 | return mask 178 | 179 | def _action_do_nothing(self): 180 | return Function(name='do_nothing', 181 | function=lambda dc: [], 182 | is_valid=lambda dc: True) 183 | 184 | 185 | class ZergPlayerActionWrapper(ZergActionWrapper): 186 | 187 | def __init__(self, player, **kwargs): 188 | self._warn_double_wrap = lambda *args: None 189 | self._player = player 190 | super(ZergPlayerActionWrapper, self).__init__(**kwargs) 191 | 192 | def step(self, action): 193 | actions = self._actions[action[self._player]].function(self._dc) 194 | pre_actions, post_actions = self._required_actions() 195 | action[self._player] = pre_actions + actions + post_actions 196 | observation, reward, done, info = self.env.step(action) 197 | self._dc.update(observation[self._player]) 198 | if isinstance(self.action_space, MaskDiscrete): 199 | observation[self._player]['action_mask'] = self._get_valid_action_mask() 200 | return observation, reward, done, info 201 | 202 | def reset(self, **kwargs): 203 | self._combat_mgr.reset() 204 | observation = self.env.reset() 205 | self._dc.reset(observation[self._player]) 206 | if isinstance(self.action_space, MaskDiscrete): 207 | observation[self._player]['action_mask'] = self._get_valid_action_mask() 208 | return observation 209 | 
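The training scripts stack this wrapper directly on SC2RawEnv and then put ZergObservationWrapper on top (see create_env() in sc2learner/bin/train_ppo_selfplay.py earlier in this dump). Below is a minimal illustrative sketch of those first two layers alone, driven by random valid macro actions; the concrete values (difficulty '1', seed 0, step_mul 32) are placeholders, and a local StarCraft II installation with the AbyssalReef map is assumed.

# Illustrative sketch only -- parameter values are placeholders.
import numpy as np
from sc2learner.envs.raw_env import SC2RawEnv
from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper

env = SC2RawEnv(map_name='AbyssalReef', step_mul=32, resolution=16,
                agent_race='zerg', bot_race='zerg', difficulty='1',
                disable_fog=False, tie_to_lose=False,
                game_steps_per_episode=43200, random_seed=0)
env = ZergActionWrapper(env, game_version='4.6', mask=True,
                        use_all_combat_actions=False)

obs, done = env.reset(), False
while not done:
  # With mask=True the action space is MaskDiscrete; sample() expects the
  # indices of the currently valid macro actions, taken from the mask that
  # ZergActionWrapper writes into obs['action_mask'].
  valid_ids = np.nonzero(obs['action_mask'])[0]
  obs, reward, done, info = env.step(env.action_space.sample(valid_ids))
env.close()

The two-player variant (ZergPlayerActionWrapper) follows the same pattern, except that step() takes a per-player list of actions and only the wrapped player's entry is expanded into raw macro actions, as in create_selfplay_env() above.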
-------------------------------------------------------------------------------- /sc2learner/envs/observations/zerg_observation_wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import gym 7 | from gym import spaces 8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 9 | 10 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation 11 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete 12 | from sc2learner.envs.common.data_context import DataContext 13 | from sc2learner.envs.observations.spatial_features import UnitTypeCountMapFeature 14 | from sc2learner.envs.observations.spatial_features import AllianceCountMapFeature 15 | from sc2learner.envs.observations.nonspatial_features import PlayerFeature 16 | from sc2learner.envs.observations.nonspatial_features import ScoreFeature 17 | from sc2learner.envs.observations.nonspatial_features import WorkerFeature 18 | from sc2learner.envs.observations.nonspatial_features import UnitTypeCountFeature 19 | from sc2learner.envs.observations.nonspatial_features import UnitStatCountFeature 20 | from sc2learner.envs.observations.nonspatial_features import GameProgressFeature 21 | from sc2learner.envs.observations.nonspatial_features import ActionSeqFeature 22 | 23 | 24 | class ZergObservationWrapper(gym.Wrapper): 25 | 26 | def __init__(self, env, use_spatial_features=False, use_game_progress=True, 27 | action_seq_len=8, use_regions=False): 28 | super(ZergObservationWrapper, self).__init__(env) 29 | # TODO: multiple observation space 30 | #assert isinstance(env.observation_space, PySC2RawObservation) 31 | self._use_spatial_features = use_spatial_features 32 | self._use_game_progress = use_game_progress 33 | self._dc = DataContext() 34 | 35 | # nonspatial features 36 | self._unit_count_feature = UnitTypeCountFeature( 37 | type_list=[UNIT_TYPE.ZERG_LARVA.value, 38 | UNIT_TYPE.ZERG_DRONE.value, 39 | UNIT_TYPE.ZERG_ZERGLING.value, 40 | UNIT_TYPE.ZERG_BANELING.value, 41 | UNIT_TYPE.ZERG_ROACH.value, 42 | UNIT_TYPE.ZERG_ROACHBURROWED.value, 43 | UNIT_TYPE.ZERG_RAVAGER.value, 44 | UNIT_TYPE.ZERG_HYDRALISK.value, 45 | UNIT_TYPE.ZERG_LURKERMP.value, 46 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value, 47 | #UNIT_TYPE.ZERG_VIPER.value, 48 | UNIT_TYPE.ZERG_MUTALISK.value, 49 | UNIT_TYPE.ZERG_CORRUPTOR.value, 50 | UNIT_TYPE.ZERG_BROODLORD.value, 51 | #UNIT_TYPE.ZERG_SWARMHOSTMP.value, 52 | UNIT_TYPE.ZERG_LOCUSTMP.value, 53 | #UNIT_TYPE.ZERG_INFESTOR.value, 54 | UNIT_TYPE.ZERG_ULTRALISK.value, 55 | UNIT_TYPE.ZERG_BROODLING.value, 56 | UNIT_TYPE.ZERG_OVERLORD.value, 57 | UNIT_TYPE.ZERG_OVERSEER.value, 58 | #UNIT_TYPE.ZERG_CHANGELING.value, 59 | UNIT_TYPE.ZERG_QUEEN.value], 60 | use_regions=use_regions 61 | ) 62 | self._building_count_feature = UnitTypeCountFeature( 63 | type_list=[UNIT_TYPE.ZERG_SPINECRAWLER.value, 64 | UNIT_TYPE.ZERG_SPORECRAWLER.value, 65 | #UNIT_TYPE.ZERG_NYDUSCANAL.value, 66 | UNIT_TYPE.ZERG_EXTRACTOR.value, 67 | UNIT_TYPE.ZERG_SPAWNINGPOOL.value, 68 | UNIT_TYPE.ZERG_ROACHWARREN.value, 69 | UNIT_TYPE.ZERG_HYDRALISKDEN.value, 70 | UNIT_TYPE.ZERG_HATCHERY.value, 71 | UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value, 72 | UNIT_TYPE.ZERG_BANELINGNEST.value, 73 | UNIT_TYPE.ZERG_INFESTATIONPIT.value, 74 | UNIT_TYPE.ZERG_SPIRE.value, 75 | UNIT_TYPE.ZERG_ULTRALISKCAVERN.value, 76 | #UNIT_TYPE.ZERG_NYDUSNETWORK.value, 77 | UNIT_TYPE.ZERG_LURKERDENMP.value, 
78 | UNIT_TYPE.ZERG_LAIR.value, 79 | UNIT_TYPE.ZERG_HIVE.value, 80 | UNIT_TYPE.ZERG_GREATERSPIRE.value], 81 | use_regions=False 82 | ) 83 | self._unit_stat_count_feature = UnitStatCountFeature( 84 | use_regions=use_regions) 85 | self._player_feature = PlayerFeature() 86 | self._score_feature = ScoreFeature() 87 | self._worker_feature = WorkerFeature() 88 | if use_game_progress: 89 | self._game_progress_feature = GameProgressFeature() 90 | self._action_seq_feature = ActionSeqFeature(self.action_space.n, 91 | action_seq_len) 92 | n_dims = sum([ 93 | self._unit_stat_count_feature.num_dims, 94 | self._unit_count_feature.num_dims, 95 | self._building_count_feature.num_dims, 96 | self._player_feature.num_dims, 97 | self._score_feature.num_dims, 98 | self._worker_feature.num_dims, 99 | self._action_seq_feature.num_dims, 100 | self._game_progress_feature.num_dims if use_game_progress else 0, 101 | self.env.action_space.n if isinstance(self.env.action_space, 102 | MaskDiscrete) else 0 103 | ]) 104 | 105 | # spatial features 106 | if use_spatial_features: 107 | resolution = self.env.observation_space.space_attr["minimap"][1] 108 | self._unit_type_count_map_feature = UnitTypeCountMapFeature( 109 | type_map={UNIT_TYPE.ZERG_DRONE.value: 0, 110 | UNIT_TYPE.ZERG_ZERGLING.value: 1, 111 | UNIT_TYPE.ZERG_ROACH.value: 2, 112 | UNIT_TYPE.ZERG_ROACHBURROWED.value: 2, 113 | UNIT_TYPE.ZERG_HYDRALISK.value: 3, 114 | UNIT_TYPE.ZERG_OVERLORD.value: 4, 115 | UNIT_TYPE.ZERG_OVERSEER.value: 4, 116 | UNIT_TYPE.ZERG_HATCHERY.value: 5, 117 | UNIT_TYPE.ZERG_LAIR.value: 5, 118 | UNIT_TYPE.ZERG_HIVE.value: 5, 119 | UNIT_TYPE.ZERG_EXTRACTOR.value: 6, 120 | UNIT_TYPE.ZERG_QUEEN.value: 7, 121 | UNIT_TYPE.ZERG_RAVAGER.value: 8, 122 | UNIT_TYPE.ZERG_BANELING.value: 9, 123 | UNIT_TYPE.ZERG_LURKERMP.value: 10, 124 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value: 10, 125 | UNIT_TYPE.ZERG_VIPER.value: 11, 126 | UNIT_TYPE.ZERG_MUTALISK.value: 12, 127 | UNIT_TYPE.ZERG_CORRUPTOR.value: 13, 128 | UNIT_TYPE.ZERG_BROODLORD.value: 14, 129 | UNIT_TYPE.ZERG_SWARMHOSTMP.value: 15, 130 | UNIT_TYPE.ZERG_INFESTOR.value: 16, 131 | UNIT_TYPE.ZERG_ULTRALISK.value: 17, 132 | UNIT_TYPE.ZERG_CHANGELING.value: 18, 133 | UNIT_TYPE.ZERG_SPINECRAWLER.value: 19, 134 | UNIT_TYPE.ZERG_SPORECRAWLER.value: 20}, 135 | resolution=resolution, 136 | ) 137 | self._alliance_count_map_feature = AllianceCountMapFeature(resolution) 138 | n_channels = sum([self._unit_type_count_map_feature.num_channels, 139 | self._alliance_count_map_feature.num_channels]) 140 | 141 | if use_spatial_features: 142 | if isinstance(self.env.action_space, MaskDiscrete): 143 | self.observation_space = spaces.Tuple([ 144 | spaces.Box(0.0, float('inf'), [n_channels, resolution, resolution], 145 | dtype=np.float32), 146 | spaces.Box(0.0, float('inf'), [n_dims], dtype=np.float32), 147 | spaces.Box(0.0, 1.0, [self.env.action_space.n], dtype=np.float32) 148 | ]) 149 | else: 150 | self.observation_space = spaces.Tuple([ 151 | spaces.Box(0.0, float('inf'), [n_channels, resolution, resolution], 152 | dtype=np.float32), 153 | spaces.Box(0.0, float('inf'), [n_dims], dtype=np.float32) 154 | ]) 155 | else: 156 | if isinstance(self.env.action_space, MaskDiscrete): 157 | self.observation_space = spaces.Tuple([ 158 | spaces.Box(0.0, float('inf'), [n_dims], dtype=np.float32), 159 | spaces.Box(0.0, 1.0, [self.env.action_space.n], dtype=np.float32) 160 | ]) 161 | else: 162 | self.observation_space = spaces.Box(0.0, float('inf'), [n_dims], 163 | dtype=np.float32) 164 | 165 | def step(self, action): 166 | 
self._action_seq_feature.push_action(action) 167 | observation, reward, done, info = self.env.step(action) 168 | self._dc.update(observation) 169 | return self._observation(observation), reward, done, info 170 | 171 | def reset(self, **kwargs): 172 | observation = self.env.reset() 173 | self._dc.reset(observation) 174 | self._action_seq_feature.reset() 175 | return self._observation(observation) 176 | 177 | @property 178 | def action_names(self): 179 | if not hasattr(self.env, 'action_names'): 180 | raise NotImplementedError 181 | return self.env.action_names 182 | 183 | @property 184 | def player_position(self): 185 | if not hasattr(self.env, 'player_position'): 186 | raise NotImplementedError 187 | return self.env.player_position 188 | 189 | def _observation(self, observation): 190 | need_flip = True if self.env.player_position == 0 else False 191 | 192 | # nonspatial features 193 | unit_type_feat = self._unit_count_feature.features(observation, need_flip) 194 | building_type_feat = self._building_count_feature.features(observation, 195 | need_flip) 196 | unit_stat_feat = self._unit_stat_count_feature.features(observation, 197 | need_flip) 198 | player_feat = self._player_feature.features(observation) 199 | score_feat = self._score_feature.features(observation) 200 | worker_feat = self._worker_feature.features(self._dc) 201 | if self._use_game_progress: 202 | game_progress_feat = self._game_progress_feature.features(observation) 203 | action_seq_feat = self._action_seq_feature.features() 204 | nonspatial_feat = np.concatenate([ 205 | unit_type_feat, 206 | building_type_feat, 207 | unit_stat_feat, 208 | player_feat, 209 | score_feat, 210 | worker_feat, 211 | action_seq_feat, 212 | game_progress_feat if self._use_game_progress \ 213 | else np.array([], dtype=np.float32), 214 | np.array(observation['action_mask'], dtype=np.float32) \ 215 | if isinstance(self.env.action_space, MaskDiscrete) \ 216 | else np.array([], dtype=np.float32) 217 | ]) 218 | 219 | # spatial features 220 | if self._use_spatial_features: 221 | ally_map_feat = self._alliance_count_map_feature.features( 222 | observation, need_flip) 223 | type_map_feat = self._unit_type_count_map_feature.features( 224 | observation, need_flip) 225 | spatial_feat = np.concatenate([ally_map_feat, type_map_feat]) 226 | 227 | # return features 228 | if self._use_spatial_features: 229 | if isinstance(self.env.action_space, MaskDiscrete): 230 | return (spatial_feat, nonspatial_feat, observation['action_mask']) 231 | else: 232 | return (spatial_feat, nonspatial_feat) 233 | else: 234 | if isinstance(self.env.action_space, MaskDiscrete): 235 | return (nonspatial_feat, observation['action_mask']) 236 | else: 237 | return nonspatial_feat 238 | 239 | 240 | class ZergPlayerObservationWrapper(ZergObservationWrapper): 241 | 242 | def __init__(self, player, **kwargs): 243 | self._warn_double_wrap = lambda *args: None 244 | self._player = player 245 | super(ZergPlayerObservationWrapper, self).__init__(**kwargs) 246 | 247 | def step(self, action): 248 | self._action_seq_feature.push_action(action[self._player]) 249 | observation, reward, done, info = self.env.step(action) 250 | self._dc.update(observation[self._player]) 251 | observation[self._player] = self._observation(observation[self._player]) 252 | return observation, reward, done, info 253 | 254 | def reset(self, **kwargs): 255 | observation = self.env.reset() 256 | self._dc.reset(observation[self._player]) 257 | self._action_seq_feature.reset() 258 | observation[self._player] = 
self._observation(observation[self._player]) 259 | return observation 260 | -------------------------------------------------------------------------------- /sc2learner/agents/dqn_agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import time 8 | import struct 9 | import random 10 | import math 11 | from copy import deepcopy 12 | import queue 13 | from threading import Thread 14 | from collections import deque 15 | import io 16 | import zmq 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import Variable 23 | import torch.optim as optim 24 | from gym.spaces import prng 25 | from gym.spaces.discrete import Discrete 26 | from gym import spaces 27 | 28 | from sc2learner.agents.replay_memory import Transition 29 | from sc2learner.agents.replay_memory import RemoteReplayMemory 30 | from sc2learner.utils.utils import tprint 31 | 32 | 33 | class DQNAgent(object): 34 | 35 | def __init__(self, network, action_space, init_model_path=None): 36 | assert type(action_space) == spaces.Discrete 37 | self._action_space = action_space 38 | self._network = network 39 | if init_model_path is not None: 40 | self.load_params(torch.load(init_model_path, 41 | map_location=lambda storage, loc: storage)) 42 | if torch.cuda.device_count() > 1: 43 | self._network = nn.DataParallel(self._network) 44 | if torch.cuda.is_available(): self._network.cuda() 45 | self._optimizer = None 46 | self._target_network = None 47 | self._num_optim_steps = 0 48 | 49 | def act(self, observation, eps=0): 50 | self._network.eval() 51 | if random.uniform(0, 1) >= eps: 52 | observation = torch.from_numpy(np.expand_dims(observation, 0)) 53 | if torch.cuda.is_available(): 54 | observation = observation.pin_memory().cuda(non_blocking=True) 55 | with torch.no_grad(): 56 | q = self._network(observation) 57 | action = q.data.max(1)[1].item() 58 | return action 59 | else: 60 | return self._action_space.sample() 61 | 62 | def optimize_step(self, 63 | obs_batch, 64 | next_obs_batch, 65 | action_batch, 66 | reward_batch, 67 | done_batch, 68 | mc_return_batch, 69 | discount, 70 | mmc_beta, 71 | gradient_clipping, 72 | adam_eps, 73 | learning_rate, 74 | target_update_interval): 75 | # create optimizer 76 | if self._optimizer is None: 77 | self._optimizer = optim.Adam(self._network.parameters(), 78 | eps=adam_eps, 79 | lr=learning_rate) 80 | # create target network 81 | if self._target_network is None: 82 | self._target_network = deepcopy(self._network) 83 | if torch.cuda.is_available(): self._target_network.cuda() 84 | self._target_network.eval() 85 | 86 | # update target network 87 | if self._num_optim_steps % target_update_interval == 0: 88 | self._target_network.load_state_dict(self._network.state_dict()) 89 | 90 | # move to gpu 91 | if torch.cuda.is_available(): 92 | obs_batch = obs_batch.cuda(non_blocking=True) 93 | next_obs_batch = next_obs_batch.cuda(non_blocking=True) 94 | action_batch = action_batch.cuda(non_blocking=True) 95 | reward_batch = reward_batch.cuda(non_blocking=True) 96 | mc_return_batch = mc_return_batch.cuda(non_blocking=True) 97 | done_batch = done_batch.cuda(non_blocking=True) 98 | 99 | # compute max-q target 100 | self._network.eval() 101 | with torch.no_grad(): 102 | q_next_target = self._target_network(next_obs_batch) 103 | q_next = 
self._network(next_obs_batch) 104 | futures = q_next_target.gather( 105 | 1, q_next.max(dim=1)[1].view(-1, 1)).squeeze() 106 | futures = futures * (1 - done_batch) 107 | target_q = reward_batch + discount * futures 108 | target_q = target_q * mmc_beta + (1.0 - mmc_beta) * mc_return_batch 109 | 110 | # define loss 111 | self._network.train() 112 | q = self._network(obs_batch).gather(1, action_batch.view(-1, 1)).squeeze() 113 | loss = F.mse_loss(q, target_q.detach()) 114 | 115 | # compute gradient and update parameters 116 | self._optimizer.zero_grad() 117 | loss.backward() 118 | for param in self._network.parameters(): 119 | param.grad.data.clamp_(-gradient_clipping, gradient_clipping) 120 | self._optimizer.step() 121 | self._num_optim_steps += 1 122 | return loss.data.item() 123 | 124 | def reset(self): 125 | pass 126 | 127 | def load_params(self, state_dict): 128 | self._network.load_state_dict(state_dict) 129 | 130 | def read_params(self): 131 | if torch.cuda.device_count() > 1: 132 | return self._network.module.state_dict() 133 | else: 134 | return self._network.state_dict() 135 | 136 | 137 | class DQNActor(object): 138 | 139 | def __init__(self, 140 | memory_size, 141 | memory_warmup_size, 142 | env, 143 | network, 144 | discount, 145 | send_freq=4.0, 146 | ports=("5700", "5701", "5702"), 147 | learner_ip="localhost"): 148 | assert type(env.action_space) == spaces.Discrete 149 | assert len(ports) == 3 150 | self._env = env 151 | self._discount = discount 152 | self._epsilon = 1.0 153 | 154 | self._agent = DQNAgent(network, env.action_space) 155 | self._replay_memory = RemoteReplayMemory( 156 | is_server=False, 157 | memory_size=memory_size, 158 | memory_warmup_size=memory_warmup_size, 159 | send_freq=send_freq, 160 | ports=ports[:2], 161 | server_ip=learner_ip) 162 | 163 | self._zmq_context = zmq.Context() 164 | self._model_requestor = self._zmq_context.socket(zmq.REQ) 165 | self._model_requestor.connect("tcp://%s:%s" % (learner_ip, ports[2])) 166 | 167 | def run(self): 168 | while True: 169 | # fetch model 170 | t = time.time() 171 | self._update_model() 172 | tprint("Update model time: %f eps: %f" % (time.time() - t, self._epsilon)) 173 | # rollout 174 | t = time.time() 175 | self._rollout() 176 | tprint("Rollout time: %f" % (time.time() - t)) 177 | 178 | def _rollout(self): 179 | rollout, done = [], False 180 | observation = self._env.reset() 181 | while not done: 182 | action = self._agent.act(observation, eps=self._epsilon) 183 | next_observation, reward, done, info = self._env.step(action) 184 | rollout.append( 185 | (observation, action, reward, next_observation, done)) 186 | observation = next_observation 187 | 188 | discounted_return = 0 189 | for transition in reversed(rollout): 190 | reward = transition[2] 191 | discounted_return = discounted_return * self._discount + reward 192 | self._replay_memory.push(*transition, discounted_return) 193 | 194 | def _update_model(self): 195 | self._model_requestor.send_string("request model") 196 | file_object = io.BytesIO(self._model_requestor.recv_pyobj()) 197 | self._agent.load_params( 198 | torch.load(file_object, map_location=lambda storage, loc: storage)) 199 | self._epsilon = self._model_requestor.recv_pyobj() 200 | 201 | 202 | class DQNLearner(object): 203 | 204 | def __init__(self, 205 | network, 206 | action_space, 207 | memory_size, 208 | memory_warmup_size, 209 | discount, 210 | eps_start, 211 | eps_end, 212 | eps_decay_steps, 213 | eps_decay_steps2, 214 | batch_size, 215 | mmc_beta, 216 | gradient_clipping, 217 | adam_eps, 
218 | learning_rate, 219 | target_update_interval, 220 | checkpoint_dir, 221 | checkpoint_interval, 222 | print_interval, 223 | ports=("5700", "5701", "5702"), 224 | init_model_path=None): 225 | assert type(action_space) == spaces.Discrete 226 | self._agent = DQNAgent(network, action_space) 227 | self._replay_memory = RemoteReplayMemory( 228 | is_server=True, 229 | memory_size=memory_size, 230 | memory_warmup_size=memory_warmup_size, 231 | ports=ports[:2]) 232 | if init_model_path is not None: 233 | self._agent.load_params( 234 | torch.load(init_model_path, 235 | map_location=lambda storage, loc: storage)) 236 | self._model_params = self._agent.read_params() 237 | 238 | self._batch_size = batch_size 239 | self._mmc_beta = mmc_beta 240 | self._gradient_clipping = gradient_clipping 241 | self._adam_eps = adam_eps 242 | self._learning_rate = learning_rate 243 | self._target_update_interval = target_update_interval 244 | self._checkpoint_dir = checkpoint_dir 245 | self._checkpoint_interval = checkpoint_interval 246 | self._print_interval = print_interval 247 | self._discount = discount 248 | self._eps_start = eps_start 249 | self._eps_end = eps_end 250 | self._eps_decay_steps = eps_decay_steps 251 | self._eps_decay_steps2 = eps_decay_steps2 252 | self._epsilon = eps_start 253 | 254 | self._zmq_context = zmq.Context() 255 | self._reply_model_thread = Thread( 256 | target=self._reply_model, args=(self._zmq_context, ports[2])) 257 | self._reply_model_thread.start() 258 | 259 | def run(self): 260 | batch_queue = queue.Queue(8) 261 | batch_thread = Thread(target=self._prepare_batch, 262 | args=(batch_queue, self._batch_size,)) 263 | batch_thread.start() 264 | 265 | updates, loss, total_rollout_frames = 0, [], 0 266 | time_start = time.time() 267 | while True: 268 | updates += 1 269 | observation, next_observation, action, reward, done, mc_return = \ 270 | batch_queue.get() 271 | self._epsilon = self._schedule_epsilon(updates) 272 | loss.append(self._agent.optimize_step( 273 | obs_batch=observation, 274 | next_obs_batch=next_observation, 275 | action_batch=action, 276 | reward_batch=reward, 277 | done_batch=done, 278 | mc_return_batch=mc_return, 279 | discount=self._discount, 280 | mmc_beta=self._mmc_beta, 281 | gradient_clipping=self._gradient_clipping, 282 | adam_eps=self._adam_eps, 283 | learning_rate=self._learning_rate, 284 | target_update_interval=self._target_update_interval)) 285 | self._model_params = self._agent.read_params() 286 | if updates % self._checkpoint_interval == 0: 287 | ckpt_path = os.path.join(self._checkpoint_dir, 288 | 'checkpoint-%d' % updates) 289 | self._save_checkpoint(ckpt_path) 290 | if updates % self._print_interval == 0: 291 | time_elapsed = time.time() - time_start 292 | train_fps = self._print_interval * self._batch_size / time_elapsed 293 | rollout_fps = (self._replay_memory.total - total_rollout_frames) \ 294 | / time_elapsed 295 | loss_mean = np.mean(loss) 296 | tprint("Update: %d Train-fps: %.1f Rollout-fps: %.1f " 297 | "Loss: %.5f Epsilon: %.5f Time: %.1f" % (updates, train_fps, 298 | rollout_fps, loss_mean, self._epsilon, time_elapsed)) 299 | time_start, loss = time.time(), [] 300 | total_rollout_frames = self._replay_memory.total 301 | 302 | def _prepare_batch(self, batch_queue, batch_size): 303 | while True: 304 | transitions = self._replay_memory.sample(batch_size) 305 | batch = self._transitions_to_batch(transitions) 306 | batch_queue.put(batch) 307 | 308 | def _transitions_to_batch(self, transitions): 309 | batch = Transition(*zip(*transitions)) 310 | 
observation = torch.from_numpy(np.stack(batch.observation)) 311 | next_observation = torch.from_numpy(np.stack(batch.next_observation)) 312 | reward = torch.FloatTensor(batch.reward) 313 | action = torch.LongTensor(batch.action) 314 | done = torch.Tensor(batch.done) 315 | mc_return = torch.FloatTensor(batch.mc_return) 316 | 317 | if torch.cuda.is_available(): 318 | observation = observation.pin_memory() 319 | next_observation = next_observation.pin_memory() 320 | action = action.pin_memory() 321 | reward = reward.pin_memory() 322 | mc_return = mc_return.pin_memory() 323 | done = done.pin_memory() 324 | 325 | return observation, next_observation, action, reward, done, mc_return 326 | 327 | def _save_checkpoint(self, checkpoint_path): 328 | torch.save(self._model_params, checkpoint_path) 329 | 330 | def _schedule_epsilon(self, steps): 331 | if steps < self._eps_decay_steps: 332 | return self._eps_start - (self._eps_start - self._eps_end) * \ 333 | steps / self._eps_decay_steps 334 | elif steps < self._eps_decay_steps2: 335 | return self._eps_end - (self._eps_end - 0.01) * \ 336 | (steps - self._eps_decay_steps) / self._eps_decay_steps2 337 | else: 338 | return 0.01 339 | 340 | def _reply_model(self, zmq_context, port): 341 | receiver = zmq_context.socket(zmq.REP) 342 | receiver.bind("tcp://*:%s" % port) 343 | while True: 344 | assert receiver.recv_string() == "request model" 345 | f = io.BytesIO() 346 | torch.save(self._model_params, f) 347 | receiver.send_pyobj(f.getvalue(), zmq.SNDMORE) 348 | receiver.send_pyobj(self._epsilon) 349 | -------------------------------------------------------------------------------- /sc2learner/envs/actions/combat.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import namedtuple 6 | 7 | from s2clientprotocol import sc2api_pb2 as sc_pb 8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 9 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY 10 | from pysc2.lib.typeenums import UPGRADE_ID as UPGRADE 11 | 12 | from sc2learner.envs.actions.function import Function 13 | import sc2learner.envs.common.utils as utils 14 | from sc2learner.envs.common.const import ATTACK_FORCE 15 | from sc2learner.envs.common.const import ALLY_TYPE 16 | from sc2learner.envs.common.const import PRIORITIZED_ATTACK 17 | 18 | 19 | Region = namedtuple('Region', ('ranges', 'rally_point_a', 'rally_point_b')) 20 | 21 | 22 | class CombatActions(object): 23 | 24 | def __init__(self): 25 | self._regions = [ 26 | Region([(0, 0, 200, 176)], (161.5, 21.5), (38.5, 122.5)), 27 | Region([(0, 88, 80, 176)], (68, 108), (68, 108)), 28 | Region([(80, 88, 120, 176)], (100, 113.5), (100, 113.5)), 29 | Region([(120, 88, 200, 176)], (147.5, 113.5), (147.5, 113.5)), 30 | Region([(0, 55, 80, 88)], (36.5, 76.5), (36.5, 76.5)), 31 | Region([(80, 55, 120, 88)], (100, 71.5), (100, 71.5)), 32 | Region([(120, 55, 200, 88)], (163.5, 66.5), (163.5, 66.5)), 33 | Region([(0, 0, 80, 55)], (52.5, 30), (52.5, 30)), 34 | Region([(80, 0, 120, 55)], (100, 30), (100, 30)), 35 | Region([(120, 0, 200, 55)], (133, 36), (133, 36)) 36 | ] 37 | self._flip_region = lambda r_id: 10 - r_id if r_id > 0 else r_id 38 | 39 | self._attack_tasks = {} 40 | 41 | def reset(self): 42 | self._attack_tasks.clear() 43 | 44 | def action(self, source_region_id, target_region_id): 45 | assert source_region_id < len(self._regions) 46 | assert target_region_id < 
len(self._regions) 47 | return Function( 48 | name=("combats_in_region_%d_attack_region_%d" % 49 | (source_region_id, target_region_id)), 50 | function=self._attack_region(source_region_id, target_region_id), 51 | is_valid=self._is_valid_attack_region(source_region_id, target_region_id) 52 | ) 53 | 54 | @property 55 | def num_regions(self): 56 | return len(self._regions) 57 | 58 | @property 59 | def action_rally_new_combat_units(self): 60 | return Function(name="rally_new_combat_units", 61 | function=self._rally_new_combat_units, 62 | is_valid=self._is_valid_rally_new_combat_units) 63 | 64 | @property 65 | def action_framewise_rally_and_attack(self): 66 | return Function(name="framewise_rally_and_attack", 67 | function=self._framewise_rally_and_attack, 68 | is_valid=lambda dc: True) 69 | 70 | def _attack_region(self, source_region_id, target_region_id): 71 | 72 | def act(dc): 73 | flip = True if self._player_position(dc) == 0 else False 74 | src_id = self._flip_region(source_region_id) if flip else source_region_id 75 | tgt_id = self._flip_region(target_region_id) if flip else target_region_id 76 | combat_unit = [u for u in dc.combat_units if self._is_in_region(u, src_id)] 77 | self._set_attack_task(combat_unit, tgt_id) 78 | return [] 79 | 80 | return act 81 | 82 | def _is_valid_attack_region(self, source_region_id, target_region_id): 83 | 84 | def is_valid(dc): 85 | flip = True if self._player_position(dc) == 0 else False 86 | src_id = self._flip_region(source_region_id) if flip else source_region_id 87 | combat_unit = [u for u in dc.combat_units if self._is_in_region(u, src_id)] 88 | return len(combat_unit) >= 3 89 | 90 | return is_valid 91 | 92 | def _rally_new_combat_units(self, dc): 93 | new_combat_units = [u for u in dc.combat_units if dc.is_new_unit(u)] 94 | if self._player_position(dc) == 0: 95 | self._set_attack_task(new_combat_units, 1) 96 | else: 97 | self._set_attack_task(new_combat_units, 9) 98 | return [] 99 | 100 | def _is_valid_rally_new_combat_units(self, dc): 101 | new_combat_units = [u for u in dc.combat_units if dc.is_new_unit(u)] 102 | if len(new_combat_units) > 0: return True 103 | else: return False 104 | 105 | def _framewise_rally_and_attack(self, dc): 106 | actions = [] 107 | for region_id in range(len(self._regions)): 108 | units_with_task = [u for u in dc.combat_units 109 | if (u.tag in self._attack_tasks and 110 | self._attack_tasks[u.tag] == region_id)] 111 | if len(units_with_task) > 0: 112 | target_enemies = [ 113 | u for u in dc.units_of_alliance(ALLY_TYPE.ENEMY.value) 114 | if self._is_in_region(u, region_id) 115 | ] 116 | if len(target_enemies) > 0: 117 | actions.extend( 118 | self._micro_attack(units_with_task, target_enemies, dc)) 119 | else: 120 | if self._player_position(dc) == 0: 121 | rally_point = self._regions[region_id].rally_point_a 122 | else: 123 | rally_point = self._regions[region_id].rally_point_b 124 | actions.extend(self._micro_rally(units_with_task, rally_point, dc)) 125 | return actions 126 | 127 | def _micro_attack(self, combat_units, enemy_units, dc): 128 | 129 | def prioritized_attack(unit, target_units): 130 | assert len(target_units) > 0 131 | prioritized_target_units = [u for u in target_units 132 | if u.unit_type in PRIORITIZED_ATTACK] 133 | if len(prioritized_target_units) > 0: 134 | closest_target = utils.closest_unit(unit, prioritized_target_units) 135 | else: 136 | closest_target = utils.closest_unit(unit, target_units) 137 | target_pos = (closest_target.float_attr.pos_x, 138 | closest_target.float_attr.pos_y) 139 | return 
self._unit_attack(unit, target_pos, dc) 140 | 141 | def flee_or_fight(unit, target_units): 142 | assert len(target_units) > 0 143 | closest_target = utils.closest_unit(unit, target_units) 144 | closest_dist = utils.closest_distance(unit, enemy_units) 145 | strongest_health = utils.strongest_health(combat_units) 146 | if (closest_dist < 5.0 and 147 | unit.float_attr.health / unit.float_attr.health_max < 0.3 and 148 | strongest_health > 0.9): 149 | x = unit.float_attr.pos_x + (unit.float_attr.pos_x - \ 150 | closest_target.float_attr.pos_x) * 0.2 151 | y = unit.float_attr.pos_y + (unit.float_attr.pos_y - \ 152 | closest_target.float_attr.pos_y) * 0.2 153 | target_pos = (x, y) 154 | return self._unit_move(unit, target_pos, dc) 155 | else: 156 | target_pos = (closest_target.float_attr.pos_x, 157 | closest_target.float_attr.pos_y) 158 | return self._unit_attack(unit, target_pos, dc) 159 | 160 | air_combat_units = [ 161 | u for u in combat_units 162 | if (ATTACK_FORCE[u.unit_type].can_attack_air and 163 | not ATTACK_FORCE[u.unit_type].can_attack_ground) 164 | ] 165 | ground_combat_units = [ 166 | u for u in combat_units 167 | if (not ATTACK_FORCE[u.unit_type].can_attack_air and 168 | ATTACK_FORCE[u.unit_type].can_attack_ground) 169 | ] 170 | air_ground_combat_units = [ 171 | u for u in combat_units 172 | if (ATTACK_FORCE[u.unit_type].can_attack_air and 173 | ATTACK_FORCE[u.unit_type].can_attack_ground) 174 | ] 175 | air_enemy_units = [u for u in enemy_units if u.bool_attr.is_flying] 176 | ground_enemy_units = [u for u in enemy_units if not u.bool_attr.is_flying] 177 | actions = [] 178 | for unit in air_combat_units: 179 | if len(air_enemy_units) > 0: 180 | actions.extend(prioritized_attack(unit, air_enemy_units)) 181 | for unit in ground_combat_units: 182 | if len(ground_enemy_units) > 0: 183 | actions.extend(prioritized_attack(unit, ground_enemy_units)) 184 | for unit in air_ground_combat_units: 185 | if len(enemy_units) > 0: 186 | actions.extend(prioritized_attack(unit, enemy_units)) 187 | return actions 188 | 189 | def _micro_rally(self, units, rally_point, dc): 190 | actions = [] 191 | for unit in units: 192 | actions.extend(self._unit_attack(unit, rally_point, dc)) 193 | return actions 194 | 195 | def _unit_attack(self, unit, target_pos, dc): 196 | # move with attack 197 | if unit.unit_type == UNIT_TYPE.ZERG_RAVAGER.value: 198 | return self._ravager_unit_attack(unit, target_pos, dc) 199 | #elif (unit.unit_type == UNIT_TYPE.ZERG_ROACH.value or 200 | #unit.unit_type == UNIT_TYPE.ZERG_ROACHBURROWED.value): 201 | #return self._roach_unit_attack(unit, target_pos, dc) 202 | elif (unit.unit_type == UNIT_TYPE.ZERG_LURKERMP.value or 203 | unit.unit_type == UNIT_TYPE.ZERG_LURKERMPBURROWED.value): 204 | return self._lurker_unit_attack(unit, target_pos, dc) 205 | else: 206 | return self._normal_unit_attack(unit, target_pos) 207 | 208 | def _unit_move(self, unit, target_pos, dc): 209 | # move without attack 210 | if unit.unit_type == UNIT_TYPE.ZERG_LURKERMPBURROWED.value: 211 | return self._lurker_unit_move(unit, target_pos) 212 | #elif unit.unit_type == UNIT_TYPE.ZERG_ROACHBURROWED.value: 213 | #return self._roach_unit_move(unit, target_pos, dc) 214 | else: 215 | return self._normal_unit_move(unit, target_pos) 216 | 217 | def _normal_unit_attack(self, unit, target_pos): 218 | action = sc_pb.Action() 219 | action.action_raw.unit_command.unit_tags.append(unit.tag) 220 | action.action_raw.unit_command.ability_id = ABILITY.ATTACK_ATTACK.value 221 | action.action_raw.unit_command.target_world_space_pos.x = 
target_pos[0] 222 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1] 223 | return [action] 224 | 225 | def _normal_unit_move(self, unit, target_pos): 226 | action = sc_pb.Action() 227 | action.action_raw.unit_command.unit_tags.append(unit.tag) 228 | action.action_raw.unit_command.ability_id = ABILITY.MOVE.value 229 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0] 230 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1] 231 | return [action] 232 | 233 | def _roach_unit_attack(self, unit, target_pos, dc): 234 | actions = [] 235 | ground_enemies = [u for u in dc.units_of_alliance(ALLY_TYPE.ENEMY.value) 236 | if not u.bool_attr.is_flying] 237 | if len(utils.units_nearby(unit, ground_enemies, max_distance=4)) > 0: 238 | if unit.unit_type == UNIT_TYPE.ZERG_ROACHBURROWED.value: 239 | action = sc_pb.Action() 240 | action.action_raw.unit_command.unit_tags.append(unit.tag) 241 | action.action_raw.unit_command.ability_id = ABILITY.BURROWUP_ROACH.value 242 | actions.append(action) 243 | actions.extend(self._normal_unit_attack(unit, target_pos)) 244 | else: 245 | actions.extend(self._roach_unit_move(unit, target_pos, dc)) 246 | return actions 247 | 248 | def _roach_unit_move(self, unit, target_pos, dc): 249 | actions = [] 250 | if (UPGRADE.TUNNELINGCLAWS.value in dc.upgraded_techs and 251 | UPGRADE.BURROW.value in dc.upgraded_techs and 252 | unit.unit_type == UNIT_TYPE.ZERG_ROACH.value): 253 | action = sc_pb.Action() 254 | action.action_raw.unit_command.unit_tags.append(unit.tag) 255 | action.action_raw.unit_command.ability_id = ABILITY.BURROWDOWN_ROACH.value 256 | actions.append(action) 257 | actions.extend(self._normal_unit_move(unit, target_pos)) 258 | return actions 259 | 260 | def _lurker_unit_attack(self, unit, target_pos, dc): 261 | actions = [] 262 | ground_enemies = [u for u in dc.units_of_alliance(ALLY_TYPE.ENEMY.value) 263 | if not u.bool_attr.is_flying] 264 | if len(utils.units_nearby(unit, ground_enemies, max_distance=8)) > 0: 265 | if unit.unit_type == UNIT_TYPE.ZERG_LURKERMP.value: 266 | action = sc_pb.Action() 267 | action.action_raw.unit_command.unit_tags.append(unit.tag) 268 | action.action_raw.unit_command.ability_id = \ 269 | ABILITY.BURROWDOWN_LURKER.value 270 | actions.append(action) 271 | else: 272 | actions.extend(self._lurker_unit_move(unit, target_pos)) 273 | return actions 274 | 275 | def _lurker_unit_move(self, unit, target_pos): 276 | actions = [] 277 | if unit.unit_type == UNIT_TYPE.ZERG_LURKERMPBURROWED.value: 278 | action = sc_pb.Action() 279 | action.action_raw.unit_command.unit_tags.append(unit.tag) 280 | action.action_raw.unit_command.ability_id = ABILITY.BURROWUP_LURKER.value 281 | actions.append(action) 282 | actions.extend(self._normal_unit_move(unit, target_pos)) 283 | return actions 284 | 285 | def _ravager_unit_attack(self, unit, target_pos, dc): 286 | actions = [] 287 | ground_units = [u for u in dc.units_of_alliance(ALLY_TYPE.SELF.value) 288 | if not u.bool_attr.is_flying] 289 | if len(utils.units_nearby(target_pos, ground_units, max_distance=2)) == 0: 290 | action = sc_pb.Action() 291 | action.action_raw.unit_command.unit_tags.append(unit.tag) 292 | action.action_raw.unit_command.ability_id = \ 293 | ABILITY.EFFECT_CORROSIVEBILE.value 294 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0] 295 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1] 296 | actions.append(action) 297 | actions.extend(self._normal_unit_attack(unit, target_pos)) 298 | 
return actions 299 | 300 | def _set_attack_task(self, units, target_region_id): 301 | for u in units: 302 | self._attack_tasks[u.tag] = target_region_id 303 | 304 | def _is_in_region(self, unit, region_id): 305 | return any([(unit.float_attr.pos_x >= r[0] and 306 | unit.float_attr.pos_x < r[2] and 307 | unit.float_attr.pos_y >= r[1] and 308 | unit.float_attr.pos_y < r[3]) 309 | for r in self._regions[region_id].ranges]) 310 | 311 | def _player_position(self, dc): 312 | if dc.init_base_pos[0] < 100: return 0 313 | else: return 1 314 | -------------------------------------------------------------------------------- /sc2learner/agents/ppo_agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import time 7 | from collections import deque 8 | from queue import Queue 9 | import queue 10 | from threading import Thread 11 | import time 12 | import random 13 | import joblib 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | import zmq 18 | from gym import spaces 19 | 20 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete 21 | from sc2learner.agents.utils_tf import explained_variance 22 | from sc2learner.utils.utils import tprint 23 | 24 | 25 | class Model(object): 26 | def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train, 27 | unroll_length, ent_coef, vf_coef, max_grad_norm, scope_name, 28 | value_clip=False): 29 | sess = tf.get_default_session() 30 | 31 | act_model = policy(sess, scope_name, ob_space, ac_space, nbatch_act, 1, 32 | reuse=False) 33 | train_model = policy(sess, scope_name, ob_space, ac_space, nbatch_train, 34 | unroll_length, reuse=True) 35 | 36 | A = tf.placeholder(shape=(nbatch_train,), dtype=tf.int32) 37 | ADV = tf.placeholder(tf.float32, [None]) 38 | R = tf.placeholder(tf.float32, [None]) 39 | OLDNEGLOGPAC = tf.placeholder(tf.float32, [None]) 40 | OLDVPRED = tf.placeholder(tf.float32, [None]) 41 | LR = tf.placeholder(tf.float32, []) 42 | CLIPRANGE = tf.placeholder(tf.float32, []) 43 | 44 | neglogpac = train_model.pd.neglogp(A) 45 | entropy = tf.reduce_mean(train_model.pd.entropy()) 46 | 47 | vpred = train_model.vf 48 | vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, 49 | -CLIPRANGE, CLIPRANGE) 50 | vf_losses1 = tf.square(vpred - R) 51 | if value_clip: 52 | vf_losses2 = tf.square(vpredclipped - R) 53 | vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2)) 54 | else: 55 | vf_loss = .5 * tf.reduce_mean(vf_losses1) 56 | ratio = tf.exp(OLDNEGLOGPAC - neglogpac) 57 | pg_losses = -ADV * ratio 58 | pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 59 | 1.0 + CLIPRANGE) 60 | pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2)) 61 | approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC)) 62 | clipfrac = tf.reduce_mean( 63 | tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE))) 64 | loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef 65 | params = tf.trainable_variables(scope=scope_name) 66 | grads = tf.gradients(loss, params) 67 | if max_grad_norm is not None: 68 | grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm) 69 | grads = list(zip(grads, params)) 70 | trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5) 71 | _train = trainer.apply_gradients(grads) 72 | new_params = [tf.placeholder(p.dtype, shape=p.get_shape()) for p in params] 73 | param_assign_ops = [p.assign(new_p) 
for p, new_p in zip(params, new_params)] 74 | 75 | def train(lr, cliprange, obs, returns, dones, actions, values, neglogpacs, 76 | states=None): 77 | advs = returns - values 78 | advs = (advs - advs.mean()) / (advs.std() + 1e-8) 79 | if isinstance(ac_space, MaskDiscrete): 80 | td_map = {train_model.X:obs[0], train_model.MASK:obs[-1], A:actions, 81 | ADV:advs, R:returns, LR:lr, CLIPRANGE:cliprange, 82 | OLDNEGLOGPAC:neglogpacs, OLDVPRED:values} 83 | else: 84 | td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr, 85 | CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values} 86 | if states is not None: 87 | td_map[train_model.STATE] = states 88 | td_map[train_model.DONE] = dones 89 | return sess.run( 90 | [pg_loss, vf_loss, entropy, approxkl, clipfrac, _train], 91 | td_map 92 | )[:-1] 93 | self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 94 | 'approxkl', 'clipfrac'] 95 | 96 | def save(save_path): 97 | joblib.dump(read_params(), save_path) 98 | 99 | def load(load_path): 100 | loaded_params = joblib.load(load_path) 101 | load_params(loaded_params) 102 | 103 | def read_params(): 104 | return sess.run(params) 105 | 106 | def load_params(loaded_params): 107 | sess.run(param_assign_ops, 108 | feed_dict={p : v for p, v in zip(new_params, loaded_params)}) 109 | 110 | self.train = train 111 | self.train_model = train_model 112 | self.act_model = act_model 113 | self.step = act_model.step 114 | self.value = act_model.value 115 | self.initial_state = act_model.initial_state 116 | self.save = save 117 | self.load = load 118 | self.read_params = read_params 119 | self.load_params = load_params 120 | 121 | tf.global_variables_initializer().run(session=sess) 122 | 123 | 124 | class PPOActor(object): 125 | 126 | def __init__(self, env, policy, unroll_length, gamma, lam, queue_size=1, 127 | enable_push=True, learner_ip="localhost", port_A="5700", 128 | port_B="5701"): 129 | self._env = env 130 | self._unroll_length = unroll_length 131 | self._lam = lam 132 | self._gamma = gamma 133 | self._enable_push = enable_push 134 | 135 | self._model = Model(policy=policy, 136 | scope_name="model", 137 | ob_space=env.observation_space, 138 | ac_space=env.action_space, 139 | nbatch_act=1, 140 | nbatch_train=unroll_length, 141 | unroll_length=unroll_length, 142 | ent_coef=0.01, 143 | vf_coef=0.5, 144 | max_grad_norm=0.5) 145 | self._obs = env.reset() 146 | self._state = self._model.initial_state 147 | self._done = False 148 | self._cum_reward = 0 149 | 150 | self._zmq_context = zmq.Context() 151 | self._model_requestor = self._zmq_context.socket(zmq.REQ) 152 | self._model_requestor.connect("tcp://%s:%s" % (learner_ip, port_A)) 153 | if enable_push: 154 | self._data_queue = Queue(queue_size) 155 | self._push_thread = Thread(target=self._push_data, args=( 156 | self._zmq_context, learner_ip, port_B, self._data_queue)) 157 | self._push_thread.start() 158 | 159 | def run(self): 160 | while True: 161 | # fetch model 162 | t = time.time() 163 | self._update_model() 164 | tprint("Update model time: %f" % (time.time() - t)) 165 | t = time.time() 166 | # rollout 167 | unroll = self._nstep_rollout() 168 | if self._enable_push: 169 | if self._data_queue.full(): tprint("[WARN]: Actor's queue is full.") 170 | self._data_queue.put(unroll) 171 | tprint("Rollout time: %f" % (time.time() - t)) 172 | 173 | def _nstep_rollout(self): 174 | mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = \ 175 | [],[],[],[],[],[] 176 | mb_states, episode_infos = self._state, [] 177 | for _ in 
range(self._unroll_length): 178 | action, value, self._state, neglogpac = self._model.step( 179 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)), 180 | self._state, 181 | np.expand_dims(self._done, 0)) 182 | mb_obs.append(transform_tuple(self._obs, lambda x: x.copy())) 183 | mb_actions.append(action[0]) 184 | mb_values.append(value[0]) 185 | mb_neglogpacs.append(neglogpac[0]) 186 | mb_dones.append(self._done) 187 | self._obs, reward, self._done, info = self._env.step(action[0]) 188 | self._cum_reward += reward 189 | if self._done: 190 | self._obs = self._env.reset() 191 | self._state = self._model.initial_state 192 | episode_infos.append({'r': self._cum_reward}) 193 | self._cum_reward = 0 194 | mb_rewards.append(reward) 195 | if isinstance(self._obs, tuple): 196 | mb_obs = tuple(np.asarray(obs, dtype=self._obs[0].dtype) 197 | for obs in zip(*mb_obs)) 198 | else: 199 | mb_obs = np.asarray(mb_obs, dtype=self._obs.dtype) 200 | mb_rewards = np.asarray(mb_rewards, dtype=np.float32) 201 | mb_actions = np.asarray(mb_actions) 202 | mb_values = np.asarray(mb_values, dtype=np.float32) 203 | mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32) 204 | mb_dones = np.asarray(mb_dones, dtype=np.bool) 205 | last_values = self._model.value( 206 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)), 207 | self._state, 208 | np.expand_dims(self._done, 0)) 209 | mb_returns = np.zeros_like(mb_rewards) 210 | mb_advs = np.zeros_like(mb_rewards) 211 | last_gae_lam = 0 212 | for t in reversed(range(self._unroll_length)): 213 | if t == self._unroll_length - 1: 214 | next_nonterminal = 1.0 - self._done 215 | next_values = last_values[0] 216 | else: 217 | next_nonterminal = 1.0 - mb_dones[t + 1] 218 | next_values = mb_values[t + 1] 219 | delta = mb_rewards[t] + self._gamma * next_values * next_nonterminal - \ 220 | mb_values[t] 221 | mb_advs[t] = last_gae_lam = delta + self._gamma * self._lam * \ 222 | next_nonterminal * last_gae_lam 223 | mb_returns = mb_advs + mb_values 224 | return (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, 225 | mb_states, episode_infos) 226 | 227 | def _push_data(self, zmq_context, learner_ip, port_B, data_queue): 228 | sender = zmq_context.socket(zmq.PUSH) 229 | sender.setsockopt(zmq.SNDHWM, 1) 230 | sender.setsockopt(zmq.RCVHWM, 1) 231 | sender.connect("tcp://%s:%s" % (learner_ip, port_B)) 232 | while True: 233 | data = data_queue.get() 234 | sender.send_pyobj(data) 235 | 236 | def _update_model(self): 237 | self._model_requestor.send_string("request model") 238 | self._model.load_params(self._model_requestor.recv_pyobj()) 239 | 240 | 241 | class PPOLearner(object): 242 | 243 | def __init__(self, env, policy, unroll_length, lr, clip_range, batch_size, 244 | ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, queue_size=8, 245 | print_interval=100, save_interval=10000, learn_act_speed_ratio=0, 246 | unroll_split=8, save_dir=None, init_model_path=None, 247 | port_A="5700", port_B="5701"): 248 | assert isinstance(env.action_space, spaces.Discrete) 249 | if isinstance(lr, float): lr = constfn(lr) 250 | else: assert callable(lr) 251 | if isinstance(clip_range, float): clip_range = constfn(clip_range) 252 | else: assert callable(clip_range) 253 | self._lr = lr 254 | self._clip_range=clip_range 255 | self._batch_size = batch_size 256 | self._unroll_length = unroll_length 257 | self._print_interval = print_interval 258 | self._save_interval = save_interval 259 | self._learn_act_speed_ratio = learn_act_speed_ratio 260 | self._save_dir = save_dir 261 | 262 | 
self._model = Model(policy=policy, 263 | scope_name="model", 264 | ob_space=env.observation_space, 265 | ac_space=env.action_space, 266 | nbatch_act=1, 267 | nbatch_train=batch_size * unroll_length, 268 | unroll_length=unroll_length, 269 | ent_coef=ent_coef, 270 | vf_coef=vf_coef, 271 | max_grad_norm=max_grad_norm) 272 | if init_model_path is not None: self._model.load(init_model_path) 273 | self._model_params = self._model.read_params() 274 | self._unroll_split = unroll_split if self._model.initial_state is None else 1 275 | assert self._unroll_length % self._unroll_split == 0 276 | self._data_queue = deque(maxlen=queue_size * self._unroll_split) 277 | self._data_timesteps = deque(maxlen=200) 278 | self._episode_infos = deque(maxlen=5000) 279 | self._num_unrolls = 0 280 | 281 | self._zmq_context = zmq.Context() 282 | self._pull_data_thread = Thread( 283 | target=self._pull_data, 284 | args=(self._zmq_context, self._data_queue, self._episode_infos, 285 | self._unroll_split, port_B) 286 | ) 287 | self._pull_data_thread.start() 288 | self._reply_model_thread = Thread( 289 | target=self._reply_model, args=(self._zmq_context, port_A)) 290 | self._reply_model_thread.start() 291 | 292 | def run(self): 293 | #while len(self._data_queue) < self._data_queue.maxlen: time.sleep(1) 294 | while len(self._episode_infos) < self._episode_infos.maxlen / 2: 295 | time.sleep(1) 296 | 297 | batch_queue = Queue(4) 298 | batch_threads = [ 299 | Thread(target=self._prepare_batch, 300 | args=(self._data_queue, batch_queue, 301 | self._batch_size * self._unroll_split)) 302 | for _ in range(8) 303 | ] 304 | for thread in batch_threads: 305 | thread.start() 306 | 307 | updates, loss = 0, [] 308 | time_start = time.time() 309 | while True: 310 | while (self._learn_act_speed_ratio > 0 and 311 | updates * self._batch_size >= \ 312 | self._num_unrolls * self._learn_act_speed_ratio): 313 | time.sleep(0.001) 314 | updates += 1 315 | lr_now = self._lr(updates) 316 | clip_range_now = self._clip_range(updates) 317 | 318 | batch = batch_queue.get() 319 | obs, returns, dones, actions, values, neglogpacs, states = batch 320 | loss.append(self._model.train(lr_now, clip_range_now, obs, returns, dones, 321 | actions, values, neglogpacs, states)) 322 | self._model_params = self._model.read_params() 323 | 324 | if updates % self._print_interval == 0: 325 | loss_mean = np.mean(loss, axis=0) 326 | batch_steps = self._batch_size * self._unroll_length 327 | time_elapsed = time.time() - time_start 328 | train_fps = self._print_interval * batch_steps / time_elapsed 329 | rollout_fps = len(self._data_timesteps) * self._unroll_length / \ 330 | (time.time() - self._data_timesteps[0]) 331 | var = explained_variance(values, returns) 332 | avg_reward = safemean([info['r'] for info in self._episode_infos]) 333 | tprint("Update: %d Train-fps: %.1f Rollout-fps: %.1f " 334 | "Explained-var: %.5f Avg-reward %.2f Policy-loss: %.5f " 335 | "Value-loss: %.5f Policy-entropy: %.5f Approx-KL: %.5f " 336 | "Clip-frac: %.3f Time: %.1f" % (updates, train_fps, rollout_fps, 337 | var, avg_reward, *loss_mean[:5], time_elapsed)) 338 | time_start, loss = time.time(), [] 339 | 340 | if self._save_dir is not None and updates % self._save_interval == 0: 341 | os.makedirs(self._save_dir, exist_ok=True) 342 | save_path = os.path.join(self._save_dir, 'checkpoint-%d' % updates) 343 | self._model.save(save_path) 344 | tprint('Saved to %s.' 
% save_path) 345 | 346 | def _prepare_batch(self, data_queue, batch_queue, batch_size): 347 | while True: 348 | batch = random.sample(data_queue, batch_size) 349 | obs, returns, dones, actions, values, neglogpacs, states = zip(*batch) 350 | if isinstance(obs[0], tuple): 351 | obs = tuple(np.concatenate(ob) for ob in zip(*obs)) 352 | else: 353 | obs = np.concatenate(obs) 354 | returns = np.concatenate(returns) 355 | dones = np.concatenate(dones) 356 | actions = np.concatenate(actions) 357 | values = np.concatenate(values) 358 | neglogpacs = np.concatenate(neglogpacs) 359 | states = np.concatenate(states) if states[0] is not None else None 360 | batch_queue.put((obs, returns, dones, actions, values, neglogpacs, states)) 361 | 362 | def _pull_data(self, zmq_context, data_queue, episode_infos, unroll_split, 363 | port_B): 364 | receiver = zmq_context.socket(zmq.PULL) 365 | receiver.setsockopt(zmq.RCVHWM, 1) 366 | receiver.setsockopt(zmq.SNDHWM, 1) 367 | receiver.bind("tcp://*:%s" % port_B) 368 | while True: 369 | data = receiver.recv_pyobj() 370 | if unroll_split > 1: 371 | data_queue.extend(list(zip(*( 372 | [list(zip(*transform_tuple( 373 | data[0], lambda x: np.split(x, unroll_split))))] + \ 374 | [np.split(arr, unroll_split) for arr in data[1:-2]] + \ 375 | [[data[-2] for _ in range(unroll_split)]] 376 | )))) 377 | else: 378 | data_queue.append(data[:-1]) 379 | episode_infos.extend(data[-1]) 380 | self._data_timesteps.append(time.time()) 381 | self._num_unrolls += 1 382 | 383 | def _reply_model(self, zmq_context, port_A): 384 | receiver = zmq_context.socket(zmq.REP) 385 | receiver.bind("tcp://*:%s" % port_A) 386 | while True: 387 | msg = receiver.recv_string() 388 | assert msg == "request model" 389 | receiver.send_pyobj(self._model_params) 390 | 391 | 392 | class PPOAgent(object): 393 | 394 | def __init__(self, env, policy, model_path=None): 395 | assert isinstance(env.action_space, spaces.Discrete) 396 | self._model = Model(policy=policy, 397 | scope_name="model", 398 | ob_space=env.observation_space, 399 | ac_space=env.action_space, 400 | nbatch_act=1, 401 | nbatch_train=1, 402 | unroll_length=1, 403 | ent_coef=0.01, 404 | vf_coef=0.5, 405 | max_grad_norm=0.5) 406 | if model_path is not None: 407 | self._model.load(model_path) 408 | self._state = self._model.initial_state 409 | self._done = False 410 | 411 | def act(self, observation): 412 | action, value, self._state, _ = self._model.step( 413 | transform_tuple(observation, lambda x: np.expand_dims(x, 0)), 414 | self._state, 415 | np.expand_dims(self._done, 0)) 416 | return action[0] 417 | 418 | def reset(self): 419 | self._state = self._model.initial_state 420 | 421 | 422 | class PPOSelfplayActor(object): 423 | 424 | def __init__(self, env, policy, unroll_length, gamma, lam, model_cache_size, 425 | model_cache_prob, queue_size=1, prob_latest_opponent=0.0, 426 | init_opponent_pool_filelist=None, freeze_opponent_pool=False, 427 | enable_push=True, learner_ip="localhost", port_A="5700", 428 | port_B="5701"): 429 | assert isinstance(env.action_space, spaces.Discrete) 430 | self._env = env 431 | self._unroll_length = unroll_length 432 | self._lam = lam 433 | self._gamma = gamma 434 | self._prob_latest_opponent = prob_latest_opponent 435 | self._freeze_opponent_pool = freeze_opponent_pool 436 | self._enable_push = enable_push 437 | self._model_cache_prob = model_cache_prob 438 | 439 | self._model = Model(policy=policy, 440 | scope_name="model", 441 | ob_space=env.observation_space, 442 | ac_space=env.action_space, 443 | nbatch_act=1, 444 | 
nbatch_train=unroll_length, 445 | unroll_length=unroll_length, 446 | ent_coef=0.01, 447 | vf_coef=0.5, 448 | max_grad_norm=0.5) 449 | self._oppo_model = Model(policy=policy, 450 | scope_name="oppo_model", 451 | ob_space=env.observation_space, 452 | ac_space=env.action_space, 453 | nbatch_act=1, 454 | nbatch_train=unroll_length, 455 | unroll_length=unroll_length, 456 | ent_coef=0.01, 457 | vf_coef=0.5, 458 | max_grad_norm=0.5) 459 | self._obs, self._oppo_obs = env.reset() 460 | self._state = self._model.initial_state 461 | self._oppo_state = self._oppo_model.initial_state 462 | self._done = False 463 | self._cum_reward = 0 464 | 465 | self._model_cache = deque(maxlen=model_cache_size) 466 | if init_opponent_pool_filelist is not None: 467 | with open(init_opponent_pool_filelist, 'r') as f: 468 | for model_path in f.readlines(): 469 | print(model_path) 470 | self._model_cache.append(joblib.load(model_path.strip())) 471 | self._latest_model = self._oppo_model.read_params() 472 | if len(self._model_cache) == 0: 473 | self._model_cache.append(self._latest_model) 474 | self._update_opponent() 475 | 476 | self._zmq_context = zmq.Context() 477 | self._model_requestor = self._zmq_context.socket(zmq.REQ) 478 | self._model_requestor.connect("tcp://%s:%s" % (learner_ip, port_A)) 479 | if enable_push: 480 | self._data_queue = Queue(queue_size) 481 | self._push_thread = Thread(target=self._push_data, args=( 482 | self._zmq_context, learner_ip, port_B, self._data_queue)) 483 | self._push_thread.start() 484 | 485 | def run(self): 486 | while True: 487 | t = time.time() 488 | self._update_model() 489 | tprint("Time update model: %f" % (time.time() - t)) 490 | t = time.time() 491 | unroll = self._nstep_rollout() 492 | if self._enable_push: 493 | if self._data_queue.full(): tprint("[WARN]: Actor's queue is full.") 494 | self._data_queue.put(unroll) 495 | tprint("Time rollout: %f" % (time.time() - t)) 496 | 497 | def _nstep_rollout(self): 498 | mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = \ 499 | [],[],[],[],[],[] 500 | mb_states, episode_infos = self._state, [] 501 | for _ in range(self._unroll_length): 502 | action, value, self._state, neglogpac = self._model.step( 503 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)), 504 | self._state, 505 | np.expand_dims(self._done, 0)) 506 | oppo_action, _, self._oppo_state, _ = self._oppo_model.step( 507 | transform_tuple(self._oppo_obs, lambda x: np.expand_dims(x, 0)), 508 | self._oppo_state, 509 | np.expand_dims(self._done, 0)) 510 | mb_obs.append(transform_tuple(self._obs, lambda x: x.copy())) 511 | mb_actions.append(action[0]) 512 | mb_values.append(value[0]) 513 | mb_neglogpacs.append(neglogpac[0]) 514 | mb_dones.append(self._done) 515 | (self._obs, self._oppo_obs), reward, self._done, info = self._env.step( 516 | [action[0], oppo_action[0]]) 517 | self._cum_reward += reward 518 | if self._done: 519 | self._obs, self._oppo_obs = self._env.reset() 520 | self._state = self._model.initial_state 521 | self._oppo_state = self._oppo_model.initial_state 522 | self._update_opponent() 523 | episode_infos.append({'r': self._cum_reward}) 524 | self._cum_reward = 0 525 | mb_rewards.append(reward) 526 | if isinstance(self._obs, tuple): 527 | mb_obs = tuple(np.asarray(obs, dtype=self._obs[0].dtype) 528 | for obs in zip(*mb_obs)) 529 | else: 530 | mb_obs = np.asarray(mb_obs, dtype=self._obs.dtype) 531 | mb_rewards = np.asarray(mb_rewards, dtype=np.float32) 532 | mb_actions = np.asarray(mb_actions) 533 | mb_values = np.asarray(mb_values, 
dtype=np.float32) 534 | mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32) 535 | mb_dones = np.asarray(mb_dones, dtype=np.bool) 536 | last_values = self._model.value( 537 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)), 538 | self._state, 539 | np.expand_dims(self._done, 0)) 540 | mb_returns = np.zeros_like(mb_rewards) 541 | mb_advs = np.zeros_like(mb_rewards) 542 | last_gae_lam = 0 543 | for t in reversed(range(self._unroll_length)): 544 | if t == self._unroll_length - 1: 545 | next_nonterminal = 1.0 - self._done 546 | next_values = last_values[0] 547 | else: 548 | next_nonterminal = 1.0 - mb_dones[t + 1] 549 | next_values = mb_values[t + 1] 550 | delta = mb_rewards[t] + self._gamma * next_values * next_nonterminal - \ 551 | mb_values[t] 552 | mb_advs[t] = last_gae_lam = delta + self._gamma * self._lam * \ 553 | next_nonterminal * last_gae_lam 554 | mb_returns = mb_advs + mb_values 555 | return (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, 556 | mb_states, episode_infos) 557 | 558 | def _push_data(self, zmq_context, learner_ip, port_B, data_queue): 559 | sender = zmq_context.socket(zmq.PUSH) 560 | sender.setsockopt(zmq.SNDHWM, 1) 561 | sender.setsockopt(zmq.RCVHWM, 1) 562 | sender.connect("tcp://%s:%s" % (learner_ip, port_B)) 563 | while True: 564 | data = data_queue.get() 565 | sender.send_pyobj(data) 566 | 567 | def _update_model(self): 568 | self._model_requestor.send_string("request model") 569 | model_params = self._model_requestor.recv_pyobj() 570 | self._model.load_params(model_params) 571 | if (not self._freeze_opponent_pool and 572 | random.uniform(0, 1.0) < self._model_cache_prob): 573 | self._model_cache.append(model_params) 574 | self._latest_model = model_params 575 | 576 | def _update_opponent(self): 577 | if (random.uniform(0, 1.0) < self._prob_latest_opponent or 578 | len(self._model_cache) == 0): 579 | self._oppo_model.load_params(self._latest_model) 580 | tprint("Opponent updated with the current model.") 581 | else: 582 | model_params = random.choice(self._model_cache) 583 | self._oppo_model.load_params(model_params) 584 | tprint("Opponent updated with the previous model. %d models cached." % 585 | len(self._model_cache)) 586 | 587 | 588 | def constfn(val): 589 | def f(_): 590 | return val 591 | return f 592 | 593 | 594 | def safemean(xs): 595 | return np.nan if len(xs) == 0 else np.mean(xs) 596 | 597 | 598 | def transform_tuple(x, transformer): 599 | if isinstance(x, tuple): 600 | return tuple(transformer(a) for a in x) 601 | else: 602 | return transformer(x) 603 | --------------------------------------------------------------------------------
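Usage sketch (illustrative only, not part of the repository): ppo_agent.py implements a distributed PPO loop in which a single PPOLearner binds a ZMQ REP socket on port_A to serve the latest parameters and a PULL socket on port_B to collect rollouts, while any number of PPOActor processes connect to those ports, run unroll_length environment steps per iteration, compute GAE(lambda) advantages, and push the unrolls back. The real command-line entry points are sc2learner/bin/train_ppo.py and train_ppo_selfplay.py (listed in the tree above but not reproduced in this excerpt); the snippet below is a minimal, hypothetical wiring, and make_env, PolicyCls, and the hyperparameter values are placeholders rather than the repository's actual flags.

"""Hypothetical launcher for the PPOLearner/PPOActor pair defined in
sc2learner/agents/ppo_agent.py. Run `python launcher.py learner` in one
process and `python launcher.py actor <learner_ip>` in the others."""
import sys

import tensorflow as tf

from sc2learner.agents.ppo_agent import PPOActor, PPOLearner


def make_env():
  # Placeholder: the real SC2 macro-action environment is assembled from
  # the wrappers under sc2learner/envs (see bin/train_ppo.py); it must
  # expose a Discrete (or MaskDiscrete) action space.
  raise NotImplementedError


PolicyCls = None  # placeholder for a policy class, e.g. one of those
                  # defined in sc2learner/agents/ppo_policies.py


def run_learner():
  env = make_env()  # only its observation/action spaces are used here
  with tf.Session():  # Model picks this up via tf.get_default_session()
    learner = PPOLearner(env=env,
                         policy=PolicyCls,
                         unroll_length=128,
                         lr=1e-4,            # float or callable schedule
                         clip_range=0.1,     # float or callable schedule
                         batch_size=32,
                         save_dir="./checkpoints",
                         port_A="5700",
                         port_B="5701")
    learner.run()  # blocks: pull unrolls, train, serve parameters


def run_actor(learner_ip):
  env = make_env()
  with tf.Session():
    actor = PPOActor(env=env,
                     policy=PolicyCls,
                     unroll_length=128,
                     gamma=0.999,
                     lam=0.95,
                     learner_ip=learner_ip,
                     port_A="5700",
                     port_B="5701")
    actor.run()  # blocks: fetch params -> rollout with GAE -> push


if __name__ == "__main__":
  role = sys.argv[1] if len(sys.argv) > 1 else "actor"
  if role == "learner":
    run_learner()
  else:
    run_actor(sys.argv[2] if len(sys.argv) > 2 else "localhost")

PPOSelfplayActor follows the same request/push protocol but additionally maintains a cache of past opponent parameters (controlled by model_cache_size, model_cache_prob, and prob_latest_opponent), sampling a new opponent from that cache at the end of each episode.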