├── sc2learner
│   ├── __init__.py
│   ├── bin
│   │   ├── __init__.py
│   │   ├── play_vs_ppo_agent.py
│   │   ├── evaluate.py
│   │   ├── train_dqn.py
│   │   ├── train_ppo.py
│   │   └── train_ppo_selfplay.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── actions
│   │   │   ├── __init__.py
│   │   │   ├── function.py
│   │   │   ├── upgrade.py
│   │   │   ├── produce.py
│   │   │   ├── build.py
│   │   │   ├── placer.py
│   │   │   ├── resource.py
│   │   │   ├── zerg_action_wrappers.py
│   │   │   └── combat.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── data_context.py
│   │   │   └── const.py
│   │   ├── rewards
│   │   │   ├── __init__.py
│   │   │   └── reward_wrappers.py
│   │   ├── spaces
│   │   │   ├── __init__.py
│   │   │   ├── pysc2_raw.py
│   │   │   └── mask_discrete.py
│   │   ├── observations
│   │   │   ├── __init__.py
│   │   │   ├── spatial_features.py
│   │   │   ├── nonspatial_features.py
│   │   │   └── zerg_observation_wrappers.py
│   │   ├── lan_raw_env.py
│   │   ├── selfplay_raw_env.py
│   │   └── raw_env.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── utils.py
│   └── agents
│       ├── __init__.py
│       ├── random_agent.py
│       ├── keyboard_agent.py
│       ├── utils_tf.py
│       ├── dqn_networks.py
│       ├── ppo_policies.py
│       ├── replay_memory.py
│       ├── dqn_agent.py
│       └── ppo_agent.py
├── docs
│   └── images
│       └── overview.png
├── setup.py
└── README.md
/sc2learner/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/bin/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/envs/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/envs/rewards/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/envs/spaces/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sc2learner/envs/observations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/TStarBot1/HEAD/docs/images/overview.png
--------------------------------------------------------------------------------
/sc2learner/envs/actions/function.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from collections import namedtuple
6 |
7 |
8 | Function = namedtuple('Function', ['name', 'function', 'is_valid'])
9 |
--------------------------------------------------------------------------------
/sc2learner/envs/spaces/pysc2_raw.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gym
6 |
7 |
8 | class PySC2RawAction(gym.Space):
9 | pass
10 |
11 |
12 | class PySC2RawObservation(gym.Space):
13 |
14 | def __init__(self, observation_spec_fn):
15 | self._feature_layers = observation_spec_fn()
16 |
17 | @property
18 | def space_attr(self):
19 | return self._feature_layers
20 |
--------------------------------------------------------------------------------
/sc2learner/envs/spaces/mask_discrete.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | from gym.spaces.discrete import Discrete
7 |
8 |
9 | class MaskDiscrete(Discrete):
10 |
11 | def sample(self, availables):
12 | x = np.random.choice(availables).item()
13 | assert self.contains(x, availables)
14 | return x
15 |
16 | def contains(self, x, availables):
17 | return super(MaskDiscrete, self).contains(x) and x in availables
18 |
19 | def __repr__(self):
20 | return "MaskDiscrete(%d)" % self.n
21 |
--------------------------------------------------------------------------------
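
A minimal usage sketch (illustrative, not a file in the repository): MaskDiscrete restricts sampling to the currently available action IDs, which the agents below extract from the action mask with np.nonzero(action_mask)[0].

  import numpy as np

  from sc2learner.envs.spaces.mask_discrete import MaskDiscrete

  space = MaskDiscrete(5)                   # 5 macro actions in total
  action_mask = np.array([1, 0, 1, 1, 0])   # hypothetical availability mask
  availables = np.nonzero(action_mask)[0]   # array([0, 2, 3])
  action = space.sample(availables)         # one of 0, 2, 3
  assert space.contains(action, availables)
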
/sc2learner/agents/random_agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
8 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction
9 |
10 |
11 | class RandomAgent(object):
12 | '''Random agent.'''
13 |
14 | def __init__(self, action_space):
15 | self._action_space = action_space
16 |
17 | def act(self, observation, eps=0):
18 | if (isinstance(self._action_space, MaskDiscrete) or
19 | isinstance(self._action_space, PySC2RawAction)):
20 | action_mask = observation[-1]
21 | return self._action_space.sample(np.nonzero(action_mask)[0])
22 | else:
23 | return self._action_space.sample()
24 |
25 | def reset(self):
26 | pass
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from setuptools import setup
6 |
7 |
8 | description = """Macro-action-based StarCraft-II learning environment."""
9 |
10 | setup(
11 | name='sc2learner',
12 | version='0.1',
13 | description='Macro-action-based StarCraft-II learning environment.',
14 | long_description=description,
15 | author='Tencent AI Lab',
16 | author_email='xinghaisun@tencent.com',
17 | keywords='sc2learner StarCraft AI TStarBot',
18 | packages=[
19 | 'sc2learner',
20 | 'sc2learner.agents',
21 | 'sc2learner.envs',
22 | 'sc2learner.utils',
23 | 'sc2learner.bin',
24 | ],
25 | install_requires=[
26 | 'gym==0.10.5',
27 | 'torch==0.4.0',
28 | 'tensorflow>=1.4.1',
29 | 'joblib',
30 | 'pyzmq'
31 | ]
32 | )
33 |
--------------------------------------------------------------------------------
/sc2learner/envs/common/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from pysc2.lib.unit_controls import Unit
6 |
7 |
8 | def distance(a, b):
9 |
10 | def l2_dist(pos_a, pos_b):
11 | return ((pos_a[0] - pos_b[0]) ** 2 + (pos_a[1] - pos_b[1]) ** 2) ** 0.5
12 |
13 | if isinstance(a, Unit) and isinstance(b, Unit):
14 | return l2_dist((a.float_attr.pos_x, a.float_attr.pos_y),
15 | (b.float_attr.pos_x, b.float_attr.pos_y))
16 | elif not isinstance(a, Unit) and isinstance(b, Unit):
17 | return l2_dist(a, (b.float_attr.pos_x, b.float_attr.pos_y))
18 | elif isinstance(a, Unit) and not isinstance(b, Unit):
19 | return l2_dist((a.float_attr.pos_x, a.float_attr.pos_y), b)
20 | else:
21 | return l2_dist(a, b)
22 |
23 |
24 | def closest_unit(unit, target_units):
25 | assert len(target_units) > 0
26 | return min(target_units, key=lambda u: distance(unit, u))
27 |
28 |
29 | def closest_units(unit, target_units, num):
30 | assert len(target_units) > 0
31 | return sorted(target_units, key=lambda u: distance(unit, u))[:num]
32 |
33 |
34 | def closest_distance(unit, target_units):
35 | return min(distance(unit, u) for u in target_units) \
36 | if len(target_units) > 0 else float('inf')
37 |
38 |
39 | def units_nearby(unit_center, target_units, max_distance):
40 | return [u for u in target_units if distance(unit_center, u) <= max_distance]
41 |
42 |
43 | def strongest_health(units):
44 | return max(u.float_attr.health / u.float_attr.health_max for u in units)
45 |
--------------------------------------------------------------------------------
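
A short sketch (illustrative, not part of the repository) of the helpers above using plain (x, y) tuples; distance() equally accepts pysc2 Unit objects via their float_attr.pos_x / pos_y fields.

  from sc2learner.envs.common.utils import (closest_distance, closest_unit,
                                            distance, units_nearby)

  base = (30.0, 40.0)                                  # hypothetical positions
  drones = [(32.0, 41.0), (55.0, 60.0), (31.0, 39.0)]

  print(distance(base, drones[0]))       # ~2.24
  print(closest_unit(base, drones))      # (31.0, 39.0)
  print(closest_distance(base, drones))  # ~1.41
  print(units_nearby(base, drones, 5))   # the two drones within distance 5
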
/sc2learner/envs/lan_raw_env.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gym
6 | from pysc2.env import sc2_env
7 | from pysc2.env import lan_sc2_env
8 |
9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction
10 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation
11 |
12 |
13 | class LanSC2RawEnv(gym.Env):
14 |
15 | def __init__(self,
16 | host,
17 | config_port,
18 | agent_race,
19 | step_mul=8,
20 | resolution=32,
21 | visualize_feature_map=False):
22 | agent_interface_format=sc2_env.parse_agent_interface_format(
23 | feature_screen=resolution, feature_minimap=resolution)
24 | self._sc2_env = lan_sc2_env.LanSC2Env(
25 | host=host,
26 | config_port=config_port,
27 | race=sc2_env.Race[agent_race],
28 | step_mul=step_mul,
29 | agent_interface_format=agent_interface_format,
30 | visualize=visualize_feature_map)
31 | self.observation_space = PySC2RawObservation(self._sc2_env.observation_spec)
32 | self.action_space = PySC2RawAction()
33 | self._reseted = False
34 |
35 | def step(self, actions):
36 | assert self._reseted
37 | timestep = self._sc2_env.step([actions])[0]
38 | observation = timestep.observation
39 | reward = float(timestep.reward)
40 | done = timestep.last()
41 | if done: self._reseted = False
42 | info = {}
43 | return (observation, reward, done, info)
44 |
45 | def reset(self):
46 | timestep = self._sc2_env.reset()[0]
47 | observation = timestep.observation
48 | self._reseted = True
49 | return observation
50 |
51 | def close(self):
52 | self._sc2_env.close()
53 |
--------------------------------------------------------------------------------
/sc2learner/agents/keyboard_agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import time
6 | import queue
7 | import threading
8 |
9 | from absl import logging
10 | import numpy as np
11 |
12 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
13 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction
14 |
15 |
16 | def add_input(action_queue, n):
17 | while True:
18 | if action_queue.empty():
19 | cmds = input("Input Action ID: ")
20 | if not cmds.isdigit():
21 |         print("Input should be an integer. Skipped.")
22 | continue
23 | action = int(cmds)
24 | if action >=0 and action < n: action_queue.put(action)
25 | else: print("Invalid action. Skipped.")
26 |
27 |
28 | class KeyboardAgent(object):
29 |   """A keyboard-controlled agent for StarCraft II."""
30 | def __init__(self, action_space):
31 | super(KeyboardAgent, self).__init__()
32 | logging.set_verbosity(logging.ERROR)
33 | self._action_space = action_space
34 | self._action_queue = queue.Queue()
35 | self._cmd_thread = threading.Thread(
36 | target=add_input, args=(self._action_queue, action_space.n))
37 | self._cmd_thread.daemon = True
38 | self._cmd_thread.start()
39 |
40 | def act(self, observation, eps=0):
41 | time.sleep(0.1)
42 | if not self._action_queue.empty():
43 | action = self._action_queue.get()
44 | if (isinstance(self._action_space, MaskDiscrete) or
45 | isinstance(self._action_space, PySC2RawAction)):
46 | action_mask = observation[-1]
47 | if action_mask[action] == 0:
48 | print("Action not available. Availables: %s" %
49 | np.nonzero(action_mask))
50 | action = 0
51 | return action
52 | else:
53 | return 0
54 |
55 | def reset(self):
56 | pass
57 |
--------------------------------------------------------------------------------
/sc2learner/utils/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from absl import flags
6 | from datetime import datetime
7 |
8 |
9 | def print_arguments(flags_FLAGS):
10 |   arg_name_list = dir(flags_FLAGS)
11 | black_set = set(['alsologtostderr',
12 | 'log_dir',
13 | 'logtostderr',
14 | 'showprefixforinfo',
15 | 'stderrthreshold',
16 | 'v',
17 | 'verbosity',
18 | '?',
19 | 'use_cprofile_for_profiling',
20 | 'help',
21 | 'helpfull',
22 | 'helpshort',
23 | 'helpxml',
24 | 'profile_file',
25 | 'run_with_profiling',
26 | 'only_check_args',
27 | 'pdb_post_mortem',
28 | 'run_with_pdb'])
29 | print("--------------------- Configuration Arguments --------------------")
30 | for arg_name in arg_name_list:
31 | if not arg_name.startswith('sc2_') and arg_name not in black_set:
32 | print("%s: %s" % (arg_name, flags_FLAGS[arg_name].value))
33 | print("-------------------------------------------------------------------")
34 |
35 |
36 | def tprint(x):
37 | print("[%s] %s" % (str(datetime.now().strftime('%Y-%m-%d %H:%M:%S')), x))
38 |
39 |
40 | def print_actions(env):
41 | print("----------------------------- Actions -----------------------------")
42 | for action_id, action_name in enumerate(env.action_names):
43 | print("Action ID: %d Action Name: %s" % (action_id, action_name))
44 | print("-------------------------------------------------------------------")
45 |
46 |
47 | def print_action_distribution(env, action_counts):
48 | print("----------------------- Action Distribution -----------------------")
49 | for action_id, action_name in enumerate(env.action_names):
50 | print("Action ID: %d Count: %d Name: %s" %
51 | (action_id, action_counts[action_id], action_name))
52 | print("-------------------------------------------------------------------")
53 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/upgrade.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 |
7 | from s2clientprotocol import sc2api_pb2 as sc_pb
8 | from pysc2.lib.tech_tree import TechTree
9 |
10 | from sc2learner.envs.actions.function import Function
11 |
12 |
13 | class UpgradeActions(object):
14 |
15 | def __init__(self, game_version='4.1.2'):
16 | self._tech_tree = TechTree()
17 | self._tech_tree.update_version(game_version)
18 |
19 | def action(self, func_name, upgrade_id):
20 | return Function(name=func_name,
21 | function=self._upgrade_unit(upgrade_id),
22 | is_valid=self._is_valid_upgrade_unit(upgrade_id))
23 |
24 | def _upgrade_unit(self, upgrade_id):
25 |
26 | def act(dc):
27 | tech = self._tech_tree.getUpgradeData(upgrade_id)
28 | if len(dc.idle_units_of_types(tech.whatBuilds)) == 0: return []
29 | upgrader = random.choice(dc.idle_units_of_types(tech.whatBuilds))
30 | action = sc_pb.Action()
31 | action.action_raw.unit_command.unit_tags.append(upgrader.tag)
32 | action.action_raw.unit_command.ability_id = tech.buildAbility
33 | return [action]
34 |
35 | return act
36 |
37 | def _is_valid_upgrade_unit(self, upgrade_id):
38 |
39 | def is_valid(dc):
40 | tech = self._tech_tree.getUpgradeData(upgrade_id)
41 | has_required_units = any([len(dc.mature_units_of_type(u)) > 0
42 | for u in tech.requiredUnits]) \
43 | if len(tech.requiredUnits) > 0 else True
44 | has_required_upgrades = all([t in dc.upgraded_techs
45 | for t in tech.requiredUpgrades])
46 | if (has_required_units and
47 | has_required_upgrades and
48 | upgrade_id not in dc.upgraded_techs and
49 | len(dc.units_with_task(tech.buildAbility)) == 0 and
50 | dc.mineral_count >= tech.mineralCost and
51 | dc.gas_count >= tech.gasCost and
52 | dc.supply_count >= tech.supplyCost and
53 | len(dc.idle_units_of_types(tech.whatBuilds)) > 0):
54 | return True
55 | else:
56 | return False
57 |
58 | return is_valid
59 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/produce.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 |
7 | from s2clientprotocol import sc2api_pb2 as sc_pb
8 | from pysc2.lib.tech_tree import TechTree
9 |
10 | from sc2learner.envs.actions.function import Function
11 | from sc2learner.envs.common.const import MAXIMUM_NUM
12 |
13 |
14 | class ProduceActions(object):
15 |
16 | def __init__(self, game_version='4.1.2'):
17 | self._tech_tree = TechTree()
18 | self._tech_tree.update_version(game_version)
19 |
20 | def action(self, func_name, type_id):
21 | return Function(name=func_name,
22 | function=self._produce_unit(type_id),
23 | is_valid=self._is_valid_produce_unit(type_id))
24 |
25 | def _produce_unit(self, type_id):
26 |
27 | def act(dc):
28 | tech = self._tech_tree.getUnitData(type_id)
29 | if len(dc.idle_units_of_types(tech.whatBuilds)) == 0: return []
30 | producer = random.choice(dc.idle_units_of_types(tech.whatBuilds))
31 | action = sc_pb.Action()
32 | action.action_raw.unit_command.unit_tags.append(producer.tag)
33 | action.action_raw.unit_command.ability_id = tech.buildAbility
34 | return [action]
35 |
36 | return act
37 |
38 | def _is_valid_produce_unit(self, type_id):
39 |
40 | def is_valid(dc):
41 | tech = self._tech_tree.getUnitData(type_id)
42 | has_required_units = any([len(dc.mature_units_of_type(u)) > 0
43 | for u in tech.requiredUnits]) \
44 | if len(tech.requiredUnits) > 0 else True
45 | has_required_upgrades = all([t in dc.upgraded_techs
46 | for t in tech.requiredUpgrades])
47 | current_num = len(dc.units_of_type(type_id)) + \
48 | len(dc.units_with_task(tech.buildAbility))
49 | overquota = current_num >= MAXIMUM_NUM[type_id] \
50 | if type_id in MAXIMUM_NUM else False
51 | if (has_required_units and
52 | has_required_upgrades and
53 | not overquota and
54 | dc.mineral_count >= tech.mineralCost and
55 | dc.gas_count >= tech.gasCost and
56 | dc.supply_count >= tech.supplyCost and
57 | len(dc.idle_units_of_types(tech.whatBuilds)) > 0):
58 | return True
59 | else:
60 | return False
61 |
62 | return is_valid
63 |
--------------------------------------------------------------------------------
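
A sketch (illustrative, not part of the repository) of how a Function built by ProduceActions is meant to be consumed. The DataContext is kept in sync with the raw observation by the action wrapper (see data_context.py later in this dump), so only the calling convention is shown here; the helper name is made up.

  from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE

  from sc2learner.envs.actions.produce import ProduceActions


  def maybe_produce_drone(dc):
    """Return raw SC2 actions that produce a Drone, or [] if not valid now."""
    produce = ProduceActions(game_version='4.1.2')
    fn = produce.action('produce_drone', UNIT_TYPE.ZERG_DRONE.value)
    return fn.function(dc) if fn.is_valid(dc) else []
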
/sc2learner/envs/observations/spatial_features.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 | from sc2learner.envs.common.const import ALLY_TYPE
8 | from sc2learner.envs.common.const import MAP
9 |
10 |
11 | class UnitTypeCountMapFeature(object):
12 |
13 | def __init__(self, type_map, resolution):
14 | self._type_map = type_map
15 | self._resolution = resolution
16 |
17 | def features(self, observation, need_flip=False):
18 | self_units = [u for u in observation['units']
19 | if u.int_attr.alliance == ALLY_TYPE.SELF.value]
20 | enemy_units = [u for u in observation['units']
21 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value]
22 | self_features = self._generate_features(self_units)
23 | enemy_features = self._generate_features(enemy_units)
24 | features = np.concatenate((self_features, enemy_features))
25 | if need_flip: features = np.flip(np.flip(features, axis=1), axis=2).copy()
26 | return features
27 |
28 | @property
29 | def num_channels(self):
30 | return (max(self._type_map.values()) + 1) * 2
31 |
32 | def _generate_features(self, units):
33 | num_channels = max(self._type_map.values()) + 1
34 | features = np.zeros((num_channels, self._resolution, self._resolution),
35 | dtype=np.float32)
36 | grid_width = (MAP.WIDTH - MAP.LEFT - MAP.RIGHT) / self._resolution
37 | grid_height = (MAP.HEIGHT - MAP.TOP - MAP.BOTTOM) / self._resolution
38 | for u in units:
39 | if u.unit_type in self._type_map:
40 | c = self._type_map[u.unit_type]
41 | x = (u.float_attr.pos_x - MAP.LEFT) // grid_width
42 | y = self._resolution - 1 - \
43 | (u.float_attr.pos_y - MAP.BOTTOM) // grid_height
44 | features[c, int(y), int(x)] += 1.0
45 | return features / 5.0
46 |
47 |
48 | class AllianceCountMapFeature(object):
49 |
50 | def __init__(self, resolution):
51 | self._resolution = resolution
52 |
53 | def features(self, observation, need_flip=False):
54 | self_units = [u for u in observation['units']
55 | if u.int_attr.alliance == ALLY_TYPE.SELF.value]
56 | enemy_units = [u for u in observation['units']
57 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value]
58 | neutral_units = [u for u in observation['units']
59 | if u.int_attr.alliance == ALLY_TYPE.NEUTRAL.value]
60 | self_features = self._generate_features(self_units)
61 | enemy_features = self._generate_features(enemy_units)
62 | neutral_features = self._generate_features(neutral_units)
63 | features = np.concatenate((self_features, enemy_features, neutral_features))
64 | if need_flip: features = np.flip(np.flip(features, axis=1), axis=2).copy()
65 | return features
66 |
67 | @property
68 | def num_channels(self):
69 | return 3
70 |
71 | def _generate_features(self, units):
72 | features = np.zeros((1, self._resolution, self._resolution),
73 | dtype=np.float32)
74 | grid_width = (MAP.WIDTH - MAP.LEFT - MAP.RIGHT) / self._resolution
75 | grid_height = (MAP.HEIGHT - MAP.TOP - MAP.BOTTOM) / self._resolution
76 | for u in units:
77 | x = (u.float_attr.pos_x - MAP.LEFT) // grid_width
78 | y = self._resolution - 1 - \
79 | (u.float_attr.pos_y - MAP.BOTTOM) // grid_height
80 | features[0, int(y), int(x)] += 1.0
81 | return features / 5.0
82 |
--------------------------------------------------------------------------------
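
To illustrate the world-to-grid binning used above (an editor's sketch with made-up bounds; the real MAP constants live in sc2learner/envs/common/const.py, which is not shown in this section):

  # Hypothetical playable-area bounds standing in for the MAP constants.
  WIDTH, HEIGHT, LEFT, RIGHT, TOP, BOTTOM = 200.0, 176.0, 24.0, 24.0, 24.0, 24.0
  resolution = 32

  grid_width = (WIDTH - LEFT - RIGHT) / resolution
  grid_height = (HEIGHT - TOP - BOTTOM) / resolution

  pos_x, pos_y = 100.0, 80.0                   # a unit's world-space position
  x = (pos_x - LEFT) // grid_width             # column index
  y = resolution - 1 - (pos_y - BOTTOM) // grid_height  # row index, y flipped
  print(int(y), int(x))                        # -> 17 16
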
/sc2learner/envs/actions/build.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from s2clientprotocol import sc2api_pb2 as sc_pb
6 | from pysc2.lib.tech_tree import TechTree
7 | from pysc2.lib.unit_controls import Unit
8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
9 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY
10 |
11 | from sc2learner.envs.actions.function import Function
12 | from sc2learner.envs.actions.placer import Placer
13 | import sc2learner.envs.common.utils as utils
14 | from sc2learner.envs.common.const import MAXIMUM_NUM
15 |
16 |
17 | class BuildActions(object):
18 |
19 | def __init__(self, game_version='4.1.2'):
20 | self._placer = Placer()
21 | self._tech_tree = TechTree()
22 | self._tech_tree.update_version(game_version)
23 |
24 | def action(self, func_name, type_id):
25 | return Function(name=func_name,
26 | function=self._build_unit(type_id),
27 | is_valid=self._is_valid_build_unit(type_id))
28 |
29 | def _build_unit(self, type_id):
30 |
31 | def act(dc):
32 | tech = self._tech_tree.getUnitData(type_id)
33 | pos = self._placer.get_building_position(type_id, dc)
34 |       if pos is None: return []
35 | extractor_tags = set(u.tag for u in dc.units_of_type(
36 | UNIT_TYPE.ZERG_EXTRACTOR.value))
37 | builders = dc.units_of_types(tech.whatBuilds)
38 | prefered_builders = [
39 | u for u in builders
40 | if (u.unit_type != UNIT_TYPE.ZERG_DRONE.value or
41 | len(u.orders) == 0 or
42 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and
43 | u.orders[0].target_tag not in extractor_tags))
44 | ]
45 | if len(prefered_builders) > 0:
46 | builder = utils.closest_unit(pos, prefered_builders)
47 | else:
48 | if len(builders) == 0: return []
49 | builder = utils.closest_unit(pos, builders)
50 | action = sc_pb.Action()
51 | action.action_raw.unit_command.unit_tags.append(builder.tag)
52 | action.action_raw.unit_command.ability_id = tech.buildAbility
53 | if isinstance(pos, Unit):
54 | action.action_raw.unit_command.target_unit_tag = pos.tag
55 | else:
56 | action.action_raw.unit_command.target_world_space_pos.x = pos[0]
57 | action.action_raw.unit_command.target_world_space_pos.y = pos[1]
58 | return [action]
59 |
60 | return act
61 |
62 | def _is_valid_build_unit(self, type_id):
63 |
64 | def is_valid(dc):
65 | tech = self._tech_tree.getUnitData(type_id)
66 | has_required_units = any([len(dc.mature_units_of_type(u)) > 0
67 | for u in tech.requiredUnits]) \
68 | if len(tech.requiredUnits) > 0 else True
69 | has_required_upgrades = all([t in dc.upgraded_techs
70 | for t in tech.requiredUpgrades])
71 | current_num = len(dc.units_of_type(type_id)) + \
72 | len(dc.units_with_task(tech.buildAbility))
73 | overquota = current_num >= MAXIMUM_NUM[type_id] \
74 | if type_id in MAXIMUM_NUM else False
75 |
76 | if (has_required_units and
77 | has_required_upgrades and
78 | not overquota and
79 | dc.mineral_count >= tech.mineralCost and
80 | dc.gas_count >= tech.gasCost and
81 | dc.supply_count >= tech.supplyCost and
82 | len(dc.units_of_types(tech.whatBuilds)) > 0 and
83 | len(dc.units_with_task(tech.buildAbility)) == 0 and
84 | self._placer.can_build(type_id, dc)):
85 | return True
86 | else:
87 | return False
88 |
89 | return is_valid
90 |
--------------------------------------------------------------------------------
/sc2learner/agents/utils_tf.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 |
9 | class Pd(object):
10 |
11 | def neglogp(self, x):
12 | raise NotImplementedError
13 |
14 | def entropy(self):
15 | raise NotImplementedError
16 |
17 | def sample(self):
18 | raise NotImplementedError
19 |
20 | @classmethod
21 | def fromlogits(cls, logits):
22 | return cls(logits)
23 |
24 |
25 | class CategoricalPd(Pd):
26 |
27 | def __init__(self, logits):
28 | self.logits = logits
29 |
30 | def neglogp(self, x):
31 | one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
32 | return tf.nn.softmax_cross_entropy_with_logits(logits=self.logits,
33 | labels=one_hot_actions)
34 |
35 | def entropy(self):
36 | a = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
37 | ea = tf.exp(a)
38 | z = tf.reduce_sum(ea, axis=-1, keep_dims=True)
39 | p = ea / z
40 | return tf.reduce_sum(p * (tf.log(z) - a), axis=-1)
41 |
42 | def sample(self):
43 | u = tf.random_uniform(tf.shape(self.logits))
44 | return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
45 |
46 |
47 | def fc(x, scope, nh, init_scale=1.0, init_bias=0.0):
48 | with tf.variable_scope(scope):
49 | nin = x.get_shape()[1].value
50 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
51 | b = tf.get_variable("b", [nh],
52 | initializer=tf.constant_initializer(init_bias))
53 | return tf.matmul(x, w) + b
54 |
55 |
56 | def lstm(xs, ms, s, scope, nh, init_scale=1.0):
57 | nbatch, nin = [v.value for v in xs[0].get_shape()]
58 | nsteps = len(xs)
59 | with tf.variable_scope(scope):
60 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
61 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
62 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
63 |
64 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
65 | for idx, (x, m) in enumerate(zip(xs, ms)):
66 | c = c * (1 - m)
67 | h = h * (1 - m)
68 | z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
69 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
70 | i = tf.nn.sigmoid(i)
71 | f = tf.nn.sigmoid(f)
72 | o = tf.nn.sigmoid(o)
73 | u = tf.tanh(u)
74 | c = f * c + i * u
75 | h = o * tf.tanh(c)
76 | xs[idx] = h
77 | s = tf.concat(axis=1, values=[c, h])
78 | return xs, s
79 |
80 |
81 | def batch_to_seq(h, nbatch, nsteps, flat=False):
82 | if flat: h = tf.reshape(h, [nbatch, nsteps])
83 | else: h = tf.reshape(h, [nbatch, nsteps, -1])
84 | return [tf.squeeze(v, [1])
85 | for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]
86 |
87 |
88 | def seq_to_batch(h, flat = False):
89 | shape = h[0].get_shape().as_list()
90 | if not flat:
91 | assert len(shape) > 1
92 | nh = h[0].get_shape()[-1].value
93 | return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
94 | else:
95 | return tf.reshape(tf.stack(values=h, axis=1), [-1])
96 |
97 |
98 | def ortho_init(scale=1.0):
99 |
100 | def _ortho_init(shape, dtype, partition_info=None):
101 | shape = tuple(shape)
102 | if len(shape) == 2: flat_shape = shape
103 | elif len(shape) == 4: flat_shape = (np.prod(shape[:-1]), shape[-1]) #NHWC
104 | else: raise NotImplementedError
105 | a = np.random.normal(0.0, 1.0, flat_shape)
106 | u, _, v = np.linalg.svd(a, full_matrices=False)
107 | q = u if u.shape == flat_shape else v
108 | q = q.reshape(shape)
109 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
110 |
111 | return _ortho_init
112 |
113 |
114 | def explained_variance(ypred,y):
115 | assert y.ndim == 1 and ypred.ndim == 1
116 | var_y = np.var(y)
117 | return np.nan if var_y == 0 else 1 - np.var(y - ypred) / var_y
118 |
--------------------------------------------------------------------------------
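
A minimal sketch (illustrative, not part of the repository) of CategoricalPd under a TensorFlow 1.x runtime, as pinned in setup.py:

  import numpy as np
  import tensorflow as tf

  from sc2learner.agents.utils_tf import CategoricalPd

  logits_ph = tf.placeholder(tf.float32, [None, 4], name="logits")
  pd = CategoricalPd(logits_ph)
  action = pd.sample()          # Gumbel-max sampling over the 4 categories
  neglogp = pd.neglogp(action)
  entropy = pd.entropy()

  with tf.Session() as sess:
    logits = np.array([[2.0, 0.0, 0.0, -1.0]], dtype=np.float32)
    print(sess.run([action, neglogp, entropy], feed_dict={logits_ph: logits}))
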
/sc2learner/envs/selfplay_raw_env.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gym
6 | from pysc2.env import sc2_env
7 |
8 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction
9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation
10 | from sc2learner.utils.utils import tprint
11 |
12 |
13 | DIFFICULTIES= {
14 | "1": sc2_env.Difficulty.very_easy,
15 | "2": sc2_env.Difficulty.easy,
16 | "3": sc2_env.Difficulty.medium,
17 | "4": sc2_env.Difficulty.medium_hard,
18 | "5": sc2_env.Difficulty.hard,
19 |     "6": sc2_env.Difficulty.harder,
20 | "7": sc2_env.Difficulty.very_hard,
21 | "8": sc2_env.Difficulty.cheat_vision,
22 | "9": sc2_env.Difficulty.cheat_money,
23 | "A": sc2_env.Difficulty.cheat_insane,
24 | }
25 |
26 |
27 | class SC2SelfplayRawEnv(gym.Env):
28 |
29 | def __init__(self,
30 | map_name,
31 | step_mul=8,
32 | resolution=32,
33 | disable_fog=False,
34 | agent_race='random',
35 | opponent_race='random',
36 | game_steps_per_episode=None,
37 | tie_to_lose=False,
38 | score_index=None,
39 | random_seed=None):
40 | self._map_name = map_name
41 | self._step_mul = step_mul
42 | self._resolution = resolution
43 | self._disable_fog = disable_fog
44 | self._agent_race = agent_race
45 | self._opponent_race = opponent_race
46 | self._game_steps_per_episode = game_steps_per_episode
47 | self._tie_to_lose = tie_to_lose
48 | self._score_index = score_index
49 | self._random_seed = random_seed
50 | self._reseted = False
51 | self._first_create = True
52 |
53 | self._sc2_env = self._safe_create_env()
54 | self.observation_space = PySC2RawObservation(self._sc2_env.observation_spec)
55 | self.action_space = PySC2RawAction()
56 |
57 | def step(self, actions):
58 | assert self._reseted
59 | assert len(actions) == 2
60 | timesteps = self._sc2_env.step(actions)
61 | observation = [timesteps[0].observation, timesteps[1].observation]
62 | reward = float(timesteps[0].reward)
63 | done = timesteps[0].last()
64 | if done:
65 | self._reseted = False
66 | if self._tie_to_lose and reward == 0:
67 | reward = -1.0
68 | tprint("Episode Done. Outcome %f" % reward)
69 | info = {}
70 | return (observation, reward, done, info)
71 |
72 | def reset(self):
73 | timesteps = self._safe_reset()
74 | self._reseted = True
75 | return [timesteps[0].observation, timesteps[1].observation]
76 |
77 | def _reset(self):
78 | if not self._first_create:
79 | self._sc2_env.close()
80 | self._sc2_env = self._create_env()
81 | self._first_create = False
82 | return self._sc2_env.reset()
83 |
84 | def _safe_reset(self, max_retry=10):
85 | for _ in range(max_retry - 1):
86 | try: return self._reset()
87 | except: pass
88 | return self._reset()
89 |
90 | def close(self):
91 | self._sc2_env.close()
92 |
93 | def _create_env(self):
94 | self._random_seed = (self._random_seed + 1) & 0xFFFFFFFF
95 | players=[sc2_env.Agent(sc2_env.Race[self._agent_race]),
96 | sc2_env.Agent(sc2_env.Race[self._opponent_race])]
97 | agent_interface_format=sc2_env.parse_agent_interface_format(
98 | feature_screen=self._resolution, feature_minimap=self._resolution)
99 | return sc2_env.SC2Env(
100 | map_name=self._map_name,
101 | step_mul=self._step_mul,
102 | players=players,
103 | agent_interface_format=agent_interface_format,
104 | disable_fog=self._disable_fog,
105 | game_steps_per_episode=self._game_steps_per_episode,
106 | visualize=False,
107 | score_index=self._score_index,
108 | random_seed=self._random_seed)
109 |
110 | def _safe_create_env(self, max_retry=10):
111 | for _ in range(max_retry - 1):
112 | try: return self._create_env()
113 | except: pass
114 | return self._create_env()
115 |
--------------------------------------------------------------------------------
/sc2learner/envs/raw_env.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gym
6 | from pysc2.env import sc2_env
7 |
8 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawAction
9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation
10 | from sc2learner.utils.utils import tprint
11 |
12 |
13 | DIFFICULTIES= {
14 | "1": sc2_env.Difficulty.very_easy,
15 | "2": sc2_env.Difficulty.easy,
16 | "3": sc2_env.Difficulty.medium,
17 | "4": sc2_env.Difficulty.medium_hard,
18 | "5": sc2_env.Difficulty.hard,
19 |     "6": sc2_env.Difficulty.harder,
20 | "7": sc2_env.Difficulty.very_hard,
21 | "8": sc2_env.Difficulty.cheat_vision,
22 | "9": sc2_env.Difficulty.cheat_money,
23 | "A": sc2_env.Difficulty.cheat_insane,
24 | }
25 |
26 |
27 | class SC2RawEnv(gym.Env):
28 |
29 | def __init__(self,
30 | map_name,
31 | step_mul=8,
32 | resolution=32,
33 | disable_fog=False,
34 | agent_race='random',
35 | bot_race='random',
36 | difficulty='1',
37 | game_steps_per_episode=None,
38 | tie_to_lose=False,
39 | score_index=None,
40 | random_seed=None):
41 | self._map_name = map_name
42 | self._step_mul = step_mul
43 | self._resolution = resolution
44 | self._disable_fog = disable_fog
45 | self._agent_race = agent_race
46 | self._bot_race = bot_race
47 | self._difficulty = difficulty
48 | self._game_steps_per_episode = game_steps_per_episode
49 | self._tie_to_lose = tie_to_lose
50 | self._score_index = score_index
51 | self._random_seed = random_seed
52 | self._reseted = False
53 | self._first_create = True
54 |
55 | self._sc2_env = self._safe_create_env()
56 | self.observation_space = PySC2RawObservation(self._sc2_env.observation_spec)
57 | self.action_space = PySC2RawAction()
58 |
59 | def step(self, actions):
60 | assert self._reseted
61 | timestep = self._sc2_env.step([actions])[0]
62 | observation = timestep.observation
63 | reward = float(timestep.reward)
64 | done = timestep.last()
65 | if done:
66 | self._reseted = False
67 | if self._tie_to_lose and reward == 0:
68 | reward = -1.0
69 | tprint("Episode Done. Difficulty: %s Outcome %f" %
70 | (self._difficulty, reward))
71 | info = {}
72 | return (observation, reward, done, info)
73 |
74 | def reset(self):
75 | timesteps = self._safe_reset()
76 | self._reseted = True
77 | return timesteps[0].observation
78 |
79 | def _reset(self):
80 | if not self._first_create:
81 | self._sc2_env.close()
82 | self._sc2_env = self._create_env()
83 | self._first_create = False
84 | return self._sc2_env.reset()
85 |
86 | def _safe_reset(self, max_retry=10):
87 | for _ in range(max_retry - 1):
88 | try: return self._reset()
89 | except: pass
90 | return self._reset()
91 |
92 | def close(self):
93 | self._sc2_env.close()
94 |
95 | def _create_env(self):
96 | self._random_seed = (self._random_seed + 11) & 0xFFFFFFFF
97 | players=[sc2_env.Agent(sc2_env.Race[self._agent_race]),
98 | sc2_env.Bot(sc2_env.Race[self._bot_race],
99 | DIFFICULTIES[self._difficulty])]
100 | agent_interface_format=sc2_env.parse_agent_interface_format(
101 | feature_screen=self._resolution, feature_minimap=self._resolution)
102 | tprint("Creating game with seed %d." % self._random_seed)
103 | return sc2_env.SC2Env(
104 | map_name=self._map_name,
105 | step_mul=self._step_mul,
106 | players=players,
107 | agent_interface_format=agent_interface_format,
108 | disable_fog=self._disable_fog,
109 | game_steps_per_episode=self._game_steps_per_episode,
110 | visualize=False,
111 | score_index=self._score_index,
112 | random_seed=self._random_seed)
113 |
114 | def _safe_create_env(self, max_retry=10):
115 | for _ in range(max_retry - 1):
116 | try: return self._create_env()
117 | except: pass
118 | return self._create_env()
119 |
--------------------------------------------------------------------------------
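
A sketch (illustrative, not a file in the repository) of SC2RawEnv wrapped the same way sc2learner/bin/play_vs_ppo_agent.py wraps its LAN environment; it assumes a local StarCraft II install, and the map name is only an example.

  from sc2learner.envs.raw_env import SC2RawEnv
  from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
  from sc2learner.envs.observations.zerg_observation_wrappers \
      import ZergObservationWrapper
  from sc2learner.agents.random_agent import RandomAgent

  env = SC2RawEnv(map_name='AbyssalReef',
                  step_mul=32,
                  agent_race='zerg',
                  bot_race='zerg',
                  difficulty='1',
                  random_seed=0)
  env = ZergActionWrapper(env,
                          game_version='4.6',
                          mask=True,
                          use_all_combat_actions=False)
  env = ZergObservationWrapper(env,
                               use_spatial_features=False,
                               use_game_progress=True,
                               action_seq_len=8,
                               use_regions=False)
  agent = RandomAgent(env.action_space)

  observation = env.reset()
  done = False
  while not done:
    observation, reward, done, _ = env.step(agent.act(observation))
  env.close()
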
/sc2learner/agents/dqn_networks.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class DuelingQNet(nn.Module):
11 |
12 | def __init__(self,
13 | resolution,
14 | n_channels,
15 | n_dims,
16 | n_out,
17 | batchnorm=False):
18 | super(DuelingQNet, self).__init__()
19 | assert resolution == 16
20 | self.conv1 = nn.Conv2d(in_channels=n_channels,
21 | out_channels=32,
22 | kernel_size=5,
23 | stride=1,
24 | padding=2)
25 | self.conv2 = nn.Conv2d(in_channels=32,
26 | out_channels=32,
27 | kernel_size=3,
28 | stride=1,
29 | padding=1)
30 | self.conv3 = nn.Conv2d(in_channels=32,
31 | out_channels=16,
32 | kernel_size=3,
33 | stride=2,
34 | padding=1)
35 | if batchnorm:
36 | self.bn1 = nn.BatchNorm2d(32)
37 | self.bn2 = nn.BatchNorm2d(32)
38 | self.bn3 = nn.BatchNorm2d(16)
39 |
40 | self.value_sp_fc = nn.Linear(16 * 8 * 8, 256)
41 | self.value_nonsp_fc1 = nn.Linear(n_dims, 512)
42 | self.value_nonsp_fc2 = nn.Linear(512, 512)
43 | self.value_nonsp_fc3 = nn.Linear(512, 256)
44 | self.value_final_fc = nn.Linear(512, 1)
45 |
46 | self.adv_sp_fc = nn.Linear(16 * 8 * 8, 256)
47 | self.adv_nonsp_fc1 = nn.Linear(n_dims, 512)
48 | self.adv_nonsp_fc2 = nn.Linear(512, 512)
49 | self.adv_nonsp_fc3 = nn.Linear(512, 256)
50 | self.adv_final_fc = nn.Linear(512, n_out)
51 | self._batchnorm = batchnorm
52 |
53 | def forward(self, x):
54 | spatial, nonspatial = x
55 | if self._batchnorm:
56 | spatial = F.relu(self.bn1(self.conv1(spatial)))
57 | spatial = F.relu(self.bn2(self.conv2(spatial)))
58 | spatial = F.relu(self.bn3(self.conv3(spatial)))
59 | else:
60 | spatial = F.relu(self.conv1(spatial))
61 | spatial = F.relu(self.conv2(spatial))
62 | spatial = F.relu(self.conv3(spatial))
63 | spatial = spatial.view(spatial.size(0), -1)
64 |
65 | value_sp_state = F.relu(self.value_sp_fc(spatial))
66 | value_nonsp_state = F.relu(self.value_nonsp_fc1(nonspatial))
67 | value_nonsp_state = F.relu(self.value_nonsp_fc2(value_nonsp_state))
68 | value_nonsp_state = F.relu(self.value_nonsp_fc3(value_nonsp_state))
69 | value_state = torch.cat((value_sp_state, value_nonsp_state), 1)
70 | value = self.value_final_fc(value_state)
71 |
72 | adv_sp_state = F.relu(self.adv_sp_fc(spatial))
73 | adv_nonsp_state = F.relu(self.adv_nonsp_fc1(nonspatial))
74 | adv_nonsp_state = F.relu(self.adv_nonsp_fc2(adv_nonsp_state))
75 | adv_nonsp_state = F.relu(self.adv_nonsp_fc3(adv_nonsp_state))
76 | adv_state = torch.cat((adv_sp_state, adv_nonsp_state), 1)
77 | adv = self.adv_final_fc(adv_state)
78 | adv_subtract = adv - adv.mean(dim=1, keepdim=True)
79 | return value + adv_subtract
80 |
81 |
82 | class NonspatialDuelingQNet(nn.Module):
83 |
84 | def __init__(self, n_dims, n_out):
85 | super(NonspatialDuelingQNet, self).__init__()
86 | self.value_fc1 = nn.Linear(n_dims, 512)
87 | self.value_fc2 = nn.Linear(512, 512)
88 | self.value_fc3 = nn.Linear(512, 256)
89 | self.value_fc4 = nn.Linear(256, 1)
90 |
91 | self.adv_fc1 = nn.Linear(n_dims, 512)
92 | self.adv_fc2 = nn.Linear(512, 512)
93 | self.adv_fc3 = nn.Linear(512, 256)
94 | self.adv_fc4 = nn.Linear(256, n_out)
95 |
96 | def forward(self, x):
97 | value = F.relu(self.value_fc1(x))
98 | value = F.relu(self.value_fc2(value))
99 | value = F.relu(self.value_fc3(value))
100 | value = self.value_fc4(value)
101 |
102 | adv = F.relu(self.adv_fc1(x))
103 | adv = F.relu(self.adv_fc2(adv))
104 | adv = F.relu(self.adv_fc3(adv))
105 | adv = self.adv_fc4(adv)
106 |
107 | adv_subtract = adv - adv.mean(dim=1, keepdim=True)
108 | return value + adv_subtract
109 |
--------------------------------------------------------------------------------
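
A quick shape check (illustrative, not part of the repository) for DuelingQNet; the channel and feature counts are made up, since the real values come from the observation wrappers.

  import torch

  from sc2learner.agents.dqn_networks import DuelingQNet

  n_channels, n_dims, n_actions = 20, 64, 30       # hypothetical sizes
  net = DuelingQNet(resolution=16, n_channels=n_channels,
                    n_dims=n_dims, n_out=n_actions)

  spatial = torch.randn(4, n_channels, 16, 16)     # batch of spatial features
  nonspatial = torch.randn(4, n_dims)              # batch of non-spatial features
  q_values = net((spatial, nonspatial))
  print(q_values.shape)                            # torch.Size([4, 30])
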
/sc2learner/bin/play_vs_ppo_agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import sys
6 | import traceback
7 | import multiprocessing
8 |
9 | from absl import app
10 | from absl import flags
11 | from absl import logging
12 | import tensorflow as tf
13 |
14 | from sc2learner.envs.lan_raw_env import LanSC2RawEnv
15 | from sc2learner.envs.observations.zerg_observation_wrappers \
16 | import ZergObservationWrapper
17 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
18 | from sc2learner.utils.utils import print_arguments
19 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy
20 | from sc2learner.agents.ppo_agent import PPOAgent
21 |
22 |
23 | FLAGS = flags.FLAGS
24 | flags.DEFINE_string("game_version", '4.6', "Game core version.")
25 | flags.DEFINE_string("model_path", None, "Filepath to load initial model.")
26 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.")
27 | flags.DEFINE_string("host", "127.0.0.1", "Game Host. Can be 127.0.0.1 or ::1")
28 | flags.DEFINE_integer(
29 | "config_port", 14380,
30 | "Where to set/find the config port. The host starts a tcp server to share "
31 | "the config with the client, and to proxy udp traffic if played over an "
32 | "ssh tunnel. This sets that port, and is also the start of the range of "
33 | "ports used for LAN play.")
34 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.")
35 | flags.DEFINE_boolean("use_region_features", False, "Use region features")
36 | flags.DEFINE_boolean("use_action_mask", True, "Use action mask or not.")
37 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Policy type.")
38 |
39 |
40 | def print_actions(env):
41 | print("----------------------------- Actions -----------------------------")
42 | for action_id, action_name in enumerate(env.action_names):
43 | print("Action ID: %d Action Name: %s" % (action_id, action_name))
44 | print("-------------------------------------------------------------------")
45 |
46 |
47 | def print_action_distribution(env, action_counts):
48 | print("----------------------- Action Distribution -----------------------")
49 | for action_id, action_name in enumerate(env.action_names):
50 | print("Action ID: %d Count: %d Name: %s" %
51 | (action_id, action_counts[action_id], action_name))
52 | print("-------------------------------------------------------------------")
53 |
54 |
55 | def tf_config(ncpu=None):
56 | if ncpu is None:
57 | ncpu = multiprocessing.cpu_count()
58 | if sys.platform == 'darwin': ncpu //= 2
59 | config = tf.ConfigProto(allow_soft_placement=True,
60 | intra_op_parallelism_threads=ncpu,
61 | inter_op_parallelism_threads=ncpu)
62 | config.gpu_options.allow_growth = True
63 | tf.Session(config=config).__enter__()
64 |
65 |
66 | def start_lan_agent():
67 | """Run the agent, connecting to a host started independently."""
68 | tf_config()
69 | env = LanSC2RawEnv(host=FLAGS.host,
70 | config_port=FLAGS.config_port,
71 | agent_race='zerg',
72 | step_mul=FLAGS.step_mul,
73 | visualize_feature_map=False)
74 | env = ZergActionWrapper(env,
75 | game_version=FLAGS.game_version,
76 | mask=FLAGS.use_action_mask,
77 | use_all_combat_actions=FLAGS.use_all_combat_actions)
78 | env = ZergObservationWrapper(
79 | env,
80 | use_spatial_features=False,
81 | use_game_progress=(not FLAGS.policy == 'lstm'),
82 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8,
83 | use_regions=FLAGS.use_region_features)
84 | print_actions(env)
85 | policy = {'lstm': LstmPolicy,
86 | 'mlp': MlpPolicy}[FLAGS.policy]
87 | agent = PPOAgent(env=env,
88 | policy=policy,
89 | model_path=FLAGS.model_path)
90 | try:
91 | action_counts = [0] * env.action_space.n
92 | observation = env.reset()
93 | done, step_id = False, 0
94 | while not done:
95 | action = agent.act(observation)
96 | print("Step ID: %d Take Action: %d" % (step_id, action))
97 | observation, reward, done, _ = env.step(action)
98 | action_counts[action] += 1
99 | step_id += 1
100 | print_action_distribution(env, action_counts)
101 | except KeyboardInterrupt: pass
102 | except: traceback.print_exc()
103 | env.close()
104 |
105 |
106 | def main(unused_argv):
107 | logging.set_verbosity(logging.ERROR)
108 | print_arguments(FLAGS)
109 | start_lan_agent()
110 |
111 |
112 | if __name__ == "__main__":
113 | app.run(main)
114 |
--------------------------------------------------------------------------------
/sc2learner/agents/ppo_policies.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
9 | from sc2learner.agents.utils_tf import CategoricalPd
10 | from sc2learner.agents.utils_tf import fc, lstm, batch_to_seq, seq_to_batch
11 |
12 |
13 | class MlpPolicy(object):
14 | def __init__(self, sess, scope_name, ob_space, ac_space, nbatch, nsteps,
15 | reuse=False):
16 | if isinstance(ac_space, MaskDiscrete):
17 | ob_space, mask_space = ob_space.spaces
18 |
19 | X = tf.placeholder(
20 | shape=(nbatch,) + ob_space.shape, dtype=tf.float32, name="x_screen")
21 | if isinstance(ac_space, MaskDiscrete):
22 | MASK = tf.placeholder(
23 | shape=(nbatch,) + mask_space.shape, dtype=tf.float32, name="mask")
24 |
25 | with tf.variable_scope(scope_name, reuse=reuse):
26 | x = tf.layers.flatten(X)
27 | pi_h1 = tf.tanh(fc(x, 'pi_fc1', nh=512, init_scale=np.sqrt(2)))
28 | pi_h2 = tf.tanh(fc(pi_h1, 'pi_fc2', nh=512, init_scale=np.sqrt(2)))
29 | pi_h3 = tf.tanh(fc(pi_h2, 'pi_fc3', nh=512, init_scale=np.sqrt(2)))
30 | vf_h1 = tf.tanh(fc(x, 'vf_fc1', nh=512, init_scale=np.sqrt(2)))
31 | vf_h2 = tf.tanh(fc(vf_h1, 'vf_fc2', nh=512, init_scale=np.sqrt(2)))
32 | vf_h3 = tf.tanh(fc(vf_h2, 'vf_fc3', nh=512, init_scale=np.sqrt(2)))
33 | vf = fc(vf_h3, 'vf', 1)[:,0]
34 | pi_logit = fc(pi_h3, 'pi', ac_space.n, init_scale=0.01, init_bias=0.0)
35 | if isinstance(ac_space, MaskDiscrete):
36 | pi_logit -= (1 - MASK) * 1e30
37 | self.pd = CategoricalPd(pi_logit)
38 |
39 | action = self.pd.sample()
40 | neglogp = self.pd.neglogp(action)
41 | self.initial_state = None
42 |
43 | def step(ob, *_args, **_kwargs):
44 | if isinstance(ac_space, MaskDiscrete):
45 | a, v, nl = sess.run([action, vf, neglogp], {X:ob[0], MASK:ob[-1]})
46 | else:
47 | a, v, nl = sess.run([action, vf, neglogp], {X:ob})
48 | return a, v, self.initial_state, nl
49 |
50 | def value(ob, *_args, **_kwargs):
51 | if isinstance(ac_space, MaskDiscrete):
52 | return sess.run(vf, {X:ob[0], MASK:ob[-1]})
53 | else:
54 | return sess.run(vf, {X:ob})
55 |
56 | self.X = X
57 | if isinstance(ac_space, MaskDiscrete):
58 | self.MASK = MASK
59 | self.vf = vf
60 | self.step = step
61 | self.value = value
62 |
63 |
64 | class LstmPolicy(object):
65 |
66 | def __init__(self, sess, scope_name, ob_space, ac_space, nbatch,
67 | unroll_length, nlstm=512, reuse=False):
68 | nenv = nbatch // unroll_length
69 | if isinstance(ac_space, MaskDiscrete):
70 | ob_space, mask_space = ob_space.spaces
71 |
72 | DONE = tf.placeholder(tf.float32, [nbatch])
73 | STATE = tf.placeholder(tf.float32, [nenv, nlstm * 2])
74 | X = tf.placeholder(
75 | shape=(nbatch,) + ob_space.shape, dtype=tf.float32, name="x_screen")
76 | if isinstance(ac_space, MaskDiscrete):
77 | MASK = tf.placeholder(
78 | shape=(nbatch,) + mask_space.shape, dtype=tf.float32, name="mask")
79 |
80 | with tf.variable_scope(scope_name, reuse=reuse):
81 | x = tf.layers.flatten(X)
82 | fc1 = tf.nn.relu(fc(x, 'fc1', 512))
83 | h = tf.nn.relu(fc(fc1, 'fc2', 512))
84 | xs = batch_to_seq(h, nenv, unroll_length)
85 | ms = batch_to_seq(DONE, nenv, unroll_length)
86 | h5, snew = lstm(xs, ms, STATE, 'lstm1', nh=nlstm)
87 | h5 = seq_to_batch(h5)
88 | vf = fc(h5, 'v', 1)
89 | pi_logit = fc(h5, 'pi', ac_space.n, init_scale=1.0, init_bias=0.0)
90 | if isinstance(ac_space, MaskDiscrete):
91 | pi_logit -= (1 - MASK) * 1e30
92 | self.pd = CategoricalPd(pi_logit)
93 |
94 | action = self.pd.sample()
95 | neglogp = self.pd.neglogp(action)
96 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
97 |
98 | def step(ob, state, done):
99 | if isinstance(ac_space, MaskDiscrete):
100 | return sess.run([action, vf, snew, neglogp],
101 | {X:ob[0], MASK:ob[-1], STATE:state, DONE:done})
102 | else:
103 | return sess.run([action, vf, snew, neglogp],
104 | {X:ob, STATE:state, DONE:done})
105 |
106 | def value(ob, state, done):
107 | if isinstance(ac_space, MaskDiscrete):
108 |         return sess.run(vf, {X:ob[0], MASK:ob[-1], STATE:state, DONE:done})
109 | else:
110 |         return sess.run(vf, {X:ob, STATE:state, DONE:done})
111 |
112 | self.X = X
113 | if isinstance(ac_space, MaskDiscrete):
114 | self.MASK = MASK
115 | self.DONE = DONE
116 | self.STATE = STATE
117 | self.vf = vf
118 | self.step = step
119 | self.value = value
120 |
--------------------------------------------------------------------------------
/sc2learner/agents/replay_memory.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from collections import namedtuple
6 | from collections import deque
7 | from threading import Thread
8 | import random
9 | import time
10 |
11 | import zmq
12 |
13 |
14 | Transition = namedtuple('Transition',
15 | ('observation', 'action', 'reward', 'next_observation',
16 | 'done', 'mc_return'))
17 |
18 |
19 | class LocalReplayMemory(object):
20 | def __init__(self, capacity):
21 | self._memory = deque(maxlen=capacity)
22 | self._total = 0
23 |
24 | def push(self, *args):
25 | self._memory.append(Transition(*args))
26 | self._total += 1
27 |
28 | def sample(self, batch_size):
29 | return random.sample(self._memory, batch_size)
30 |
31 | @property
32 | def total(self):
33 | return self._total
34 |
35 |
36 | class RemoteReplayMemory(object):
37 | def __init__(self,
38 | is_server,
39 | memory_size,
40 | memory_warmup_size,
41 | block_size=128,
42 | send_freq=1.0,
43 | num_pull_threads=4,
44 | ports=("5700", "5701"),
45 | server_ip="localhost"):
46 | assert len(ports) == 2
47 | assert memory_warmup_size <= memory_size
48 | self._is_server = is_server
49 | self._memory_warmup_size = memory_warmup_size
50 | self._block_size = block_size
51 |
52 | if is_server:
53 | self._num_received, self._num_used, self._total = 0, 0, 0
54 | self._cache_blocks = deque(maxlen=memory_size // block_size)
55 | self._zmq_context = zmq.Context()
56 |
57 | self._receiver_threads = [Thread(target=self._server_proxy_worker,
58 | args=(self._zmq_context, ports,))]
59 | self._receiver_threads += [Thread(target=self._server_receiver_worker,
60 | args=(self._zmq_context, ports[1],))
61 | for _ in range(num_pull_threads)]
62 | for thread in self._receiver_threads: thread.start()
63 | else:
64 | self._memory = LocalReplayMemory(memory_size)
65 | self._memory_total_last = 0
66 | self._send_interval = int(block_size / send_freq)
67 |
68 | self._zmq_context = zmq.Context()
69 | self._sender = self._zmq_context.socket(zmq.PUSH)
70 | self._sender.connect("tcp://%s:%s" % (server_ip, ports[0]))
71 |
72 | def push(self, *args):
73 | assert not self._is_server, "push() cannot be called when is_server=True."
74 | self._memory.push(*args)
75 | if (self._memory.total >= self._memory_warmup_size and
76 | self._memory.total >= self._block_size and
77 | self._memory.total % self._send_interval == 0):
78 | block = self._memory.sample(self._block_size)
79 | memory_total = self._memory.total
80 | memory_delta = memory_total - self._memory_total_last
81 | self._memory_total_last = memory_total
82 | self._sender.send_pyobj((block, memory_delta))
83 |
84 | def sample(self, batch_size, reuse_ratio=1.0):
85 | assert self._is_server, "sample() cannot be called when is_server=False."
86 | while (self._num_used / reuse_ratio >= self._num_received or
87 | self._memory_warmup_size > len(self._cache_blocks) * self._block_size):
88 | time.sleep(0.001)
89 | batch = [random.choice(random.choice(self._cache_blocks))
90 | for _ in range(batch_size)]
91 | self._num_used += batch_size
92 | return batch
93 |
94 | @property
95 | def total(self):
96 | if self._is_server:
97 | return self._total
98 | else:
99 | return self._memory.total
100 |
101 | def _server_receiver_worker(self, zmq_context, port):
102 | receiver = zmq_context.socket(zmq.PULL)
103 | receiver.connect("tcp://localhost:%s" % port)
104 | while True:
105 | block, delta = receiver.recv_pyobj()
106 | self._cache_blocks.append(block)
107 | self._total += delta
108 | self._num_received += len(block)
109 |
110 | def _server_proxy_worker(self, zmq_context, ports):
111 | assert len(ports) == 2
112 | frontend = zmq_context.socket(zmq.PULL)
113 | frontend.bind("tcp://*:%s" % ports[0])
114 | backend = self._zmq_context.socket(zmq.PUSH)
115 | backend.bind("tcp://*:%s" % ports[1])
116 | zmq.proxy(frontend, backend)
117 |
118 |
119 | if __name__ == '__main__':
120 | import sys
121 | import numpy as np
122 |
123 | job_name = sys.argv[1]
124 | if job_name == 'client':
125 | replay_memory = RemoteReplayMemory(
126 | is_server=False,
127 | memory_size=10000,
128 | memory_warmup_size=16)
129 | while True:
130 | obs, next_obs = np.array([1,2,3]), np.array([3,4,5])
131 | action, reward, done, mc_return = 1, 0.5, False, 0.01
132 | replay_memory.push(obs, action, reward, next_obs, done, mc_return)
133 | else:
134 | replay_memory = RemoteReplayMemory(
135 | is_server=True,
136 | memory_size=10000,
137 | memory_warmup_size=10000)
138 | while True:
139 | print(replay_memory.sample(8))
140 |
--------------------------------------------------------------------------------
/sc2learner/envs/common/data_context.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import itertools
6 |
7 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
8 |
9 | from sc2learner.envs.common.const import ALLY_TYPE
10 | from sc2learner.envs.common.const import PLAYER_FEATURE
11 | from sc2learner.envs.common.const import COMBAT_TYPES
12 | import sc2learner.envs.common.utils as utils
13 |
14 |
15 | class DataContext(object):
16 |
17 | def __init__(self):
18 | self._units = []
19 | self._player = None
20 | self._raw_data = None
21 | self._existed_tags = set()
22 |
23 | def update(self, observation):
24 | for u in self._units:
25 | self._existed_tags.add(u.tag)
26 | self._units = observation['units']
27 | self._player = observation['player']
28 | self._raw_data = observation['raw_data']
29 | self._combat_units = self.units_of_types(COMBAT_TYPES)
30 |
31 | def reset(self, observation):
32 | self._existed_tags.clear()
33 | self.update(observation)
34 | init_base = self.units_of_type(UNIT_TYPE.ZERG_HATCHERY.value)[0]
35 | self._init_base_pos = (init_base.float_attr.pos_x,
36 | init_base.float_attr.pos_y)
37 |
38 | def units_of_alliance(self, ally):
39 | return [u for u in self._units if u.int_attr.alliance == ally]
40 |
41 | def units_of_type(self, type_id, ally=ALLY_TYPE.SELF.value):
42 | return [u for u in self.units_of_alliance(ally) if u.unit_type == type_id]
43 |
44 | def mature_units_of_type(self, type_id, ally=ALLY_TYPE.SELF.value):
45 | return [u for u in self.units_of_type(type_id, ally)
46 | if u.float_attr.build_progress >= 1.0]
47 |
48 | def idle_units_of_type(self, type_id, ally=ALLY_TYPE.SELF.value):
49 | return [u for u in self.mature_units_of_type(type_id, ally)
50 | if len(u.orders) == 0]
51 |
52 | def units_of_types(self, type_list, ally=ALLY_TYPE.SELF.value):
53 | type_set = set(type_list)
54 | return [u for u in self.units_of_alliance(ally) if u.unit_type in type_set]
55 |
56 | def mature_units_of_types(self, type_list, ally=ALLY_TYPE.SELF.value):
57 | return [u for u in self.units_of_types(type_list, ally)
58 | if u.float_attr.build_progress >= 1.0]
59 |
60 | def idle_units_of_types(self, type_list, ally=ALLY_TYPE.SELF.value):
61 | return [u for u in self.mature_units_of_types(type_list, ally)
62 | if len(u.orders) == 0]
63 |
64 | def units_with_task(self, ability_id, ally=ALLY_TYPE.SELF.value):
65 | return [u for u in self.units_of_alliance(ally)
66 | if ability_id in set([order.ability_id for order in u.orders])]
67 |
68 | def is_new_unit(self, unit):
69 | return unit.tag not in self._existed_tags
70 |
71 | @property
72 | def units(self):
73 | return self._units
74 |
75 | @property
76 | def combat_units(self):
77 | return self._combat_units
78 |
79 | @property
80 | def minerals(self):
81 | return [u for u in self._units
82 | if (u.unit_type == UNIT_TYPE.NEUTRAL_MINERALFIELD.value or
83 | u.unit_type == UNIT_TYPE.NEUTRAL_MINERALFIELD750.value)]
84 |
85 | @property
86 | def unexploited_minerals(self):
87 | self_bases = self.units_of_types([UNIT_TYPE.ZERG_HATCHERY.value,
88 | UNIT_TYPE.ZERG_LAIR.value,
89 | UNIT_TYPE.ZERG_HIVE.value])
90 | enemy_bases = self.units_of_types([UNIT_TYPE.ZERG_HATCHERY.value,
91 | UNIT_TYPE.ZERG_LAIR.value,
92 | UNIT_TYPE.ZERG_HIVE.value],
93 | ALLY_TYPE.ENEMY.value)
94 | return [u for u in self.minerals
95 | if utils.closest_distance(u, self_bases + enemy_bases) > 15]
96 |
97 | @property
98 | def gas(self):
99 | return [u for u in self._units
100 | if u.unit_type == UNIT_TYPE.NEUTRAL_VESPENEGEYSER.value]
101 |
102 | @property
103 | def exploitable_gas(self):
104 | extractors = self.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value) + \
105 |         self.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value, ALLY_TYPE.ENEMY.value)
106 | bases = self.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value,
107 | UNIT_TYPE.ZERG_LAIR.value,
108 | UNIT_TYPE.ZERG_HIVE.value])
109 | return [u for u in self.gas if (utils.closest_distance(u, bases) < 10 and
110 | utils.closest_distance(u, extractors) > 3)]
111 |
112 | @property
113 | def mineral_count(self):
114 | return self._player[PLAYER_FEATURE.MINERALS.value]
115 |
116 | @property
117 | def gas_count(self):
118 | return self._player[PLAYER_FEATURE.VESPENE.value]
119 |
120 | @property
121 | def supply_count(self):
122 | return self._player[PLAYER_FEATURE.FOOD_CAP.value] - \
123 | self._player[PLAYER_FEATURE.FOOD_USED.value]
124 |
125 | @property
126 | def upgraded_techs(self):
127 | return set(self._raw_data.player.upgrade_ids)
128 |
129 | @property
130 | def init_base_pos(self):
131 | return self._init_base_pos
132 |
--------------------------------------------------------------------------------
/sc2learner/bin/evaluate.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import sys
6 | import random
7 |
8 | from absl import app
9 | from absl import flags
10 | from absl import logging
11 |
12 | from sc2learner.envs.raw_env import SC2RawEnv
13 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
14 | from sc2learner.envs.observations.zerg_observation_wrappers \
15 | import ZergObservationWrapper
16 | from sc2learner.utils.utils import print_arguments
17 | from sc2learner.utils.utils import print_actions
18 | from sc2learner.utils.utils import print_action_distribution
19 | from sc2learner.agents.random_agent import RandomAgent
20 | from sc2learner.agents.keyboard_agent import KeyboardAgent
21 |
22 |
23 | FLAGS = flags.FLAGS
24 | flags.DEFINE_integer("num_episodes", 10, "Number of episodes to evaluate.")
25 | flags.DEFINE_enum("agent", 'ppo', ['ppo', 'dqn', 'random', 'keyboard'],
26 | "Agent name.")
27 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Policy type.")
28 | flags.DEFINE_string("game_version", '4.6', "Game core version.")
29 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.")
30 | flags.DEFINE_enum("difficulty", '1',
31 | ['1', '2', '3', '4', '5', '6', '7', '8', '9', 'A'],
32 | "Bot's strength.")
33 | flags.DEFINE_string("model_path", None, "Filepath to load initial model.")
34 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.")
35 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.")
36 | flags.DEFINE_boolean("use_region_features", False, "Use region features")
37 | flags.DEFINE_boolean("use_action_mask", True, "Use action mask or not.")
38 | flags.FLAGS(sys.argv)
39 |
40 |
41 | def create_env(random_seed=None):
42 | env = SC2RawEnv(map_name='AbyssalReef',
43 | step_mul=FLAGS.step_mul,
44 | agent_race='zerg',
45 | bot_race='zerg',
46 | difficulty=FLAGS.difficulty,
47 | disable_fog=FLAGS.disable_fog,
48 | random_seed=random_seed)
49 | env = ZergActionWrapper(env,
50 | game_version=FLAGS.game_version,
51 | mask=FLAGS.use_action_mask,
52 | use_all_combat_actions=FLAGS.use_all_combat_actions)
53 | env = ZergObservationWrapper(
54 | env,
55 | use_spatial_features=False,
56 | use_game_progress=(not FLAGS.policy == 'lstm'),
57 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8,
58 | use_regions=FLAGS.use_region_features)
59 | print_actions(env)
60 | return env
61 |
62 |
63 | def create_dqn_agent(env):
64 | from sc2learner.agents.dqn_agent import DQNAgent
65 | from sc2learner.agents.dqn_networks import NonspatialDuelingQNet
66 |
67 | assert FLAGS.policy == 'mlp'
68 | assert not FLAGS.use_action_mask
69 | network = NonspatialDuelingQNet(n_dims=env.observation_space.shape[0],
70 | n_out=env.action_space.n)
71 | agent = DQNAgent(network, env.action_space, FLAGS.model_path)
72 | return agent
73 |
74 |
75 | def create_ppo_agent(env):
76 | import tensorflow as tf
77 | import multiprocessing
78 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy
79 | from sc2learner.agents.ppo_agent import PPOAgent
80 |
81 | ncpu = multiprocessing.cpu_count()
82 | if sys.platform == 'darwin': ncpu //= 2
83 | config = tf.ConfigProto(allow_soft_placement=True,
84 | intra_op_parallelism_threads=ncpu,
85 | inter_op_parallelism_threads=ncpu)
86 | config.gpu_options.allow_growth = True
87 | tf.Session(config=config).__enter__()
88 |
89 | policy = {'lstm': LstmPolicy, 'mlp': MlpPolicy}[FLAGS.policy]
90 | agent = PPOAgent(env=env, policy=policy, model_path=FLAGS.model_path)
91 | return agent
92 |
93 |
94 | def evaluate():
95 | game_seed = random.randint(0, 2**32 - 1)
96 | print("Game Seed: %d" % game_seed)
97 | env = create_env(game_seed)
98 |
99 | if FLAGS.agent == 'ppo':
100 | agent = create_ppo_agent(env)
101 | elif FLAGS.agent == 'dqn':
102 | agent = create_dqn_agent(env)
103 | elif FLAGS.agent == 'random':
104 | agent = RandomAgent(action_space=env.action_space)
105 | elif FLAGS.agent == 'keyboard':
106 | agent = KeyboardAgent(action_space=env.action_space)
107 | else:
108 | raise NotImplementedError
109 |
110 | try:
111 | cum_return = 0.0
112 | action_counts = [0] * env.action_space.n
113 | for i in range(FLAGS.num_episodes):
114 | observation = env.reset()
115 | agent.reset()
116 | done, step_id = False, 0
117 | while not done:
118 | action = agent.act(observation)
119 | print("Step ID: %d Take Action: %d" % (step_id, action))
120 | observation, reward, done, _ = env.step(action)
121 | action_counts[action] += 1
122 | cum_return += reward
123 | step_id += 1
124 | print_action_distribution(env, action_counts)
125 | print("Evaluated %d/%d Episodes Avg Return %f Avg Winning Rate %f" % (
126 | i + 1, FLAGS.num_episodes, cum_return / (i + 1),
127 | ((cum_return / (i + 1)) + 1) / 2.0))
128 | except KeyboardInterrupt: pass
129 | finally: env.close()
130 |
131 |
132 | def main(argv):
133 | logging.set_verbosity(logging.ERROR)
134 | print_arguments(FLAGS)
135 | evaluate()
136 |
137 |
138 | if __name__ == '__main__':
139 | app.run(main)
140 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SC2Learner (TStarBot1) - Macro-Action-Based StarCraft-II Reinforcement Learning Environment
2 |
3 |
4 |
5 |
6 |
7 |
8 | *[SC2Learner](https://github.com/Tencent/TStarBot1)* is a *macro-action*-based [StarCraft-II](https://en.wikipedia.org/wiki/StarCraft_II:_Wings_of_Liberty) reinforcement learning research platform.
9 | It exposes the re-designed StarCraft-II action space, which has more than one hundred discrete macro actions, based on the raw APIs exposed by DeepMind and Blizzard's [PySC2](https://github.com/deepmind/pysc2).
10 | The macro action space relieves the learning algorithms from a disastrous burden of directly handling a massive number of atomic keyboard and mouse operations, making learning more tractable.
11 | The environments and wrappers strictly follow the interface of [OpenAI Gym](https://github.com/openai/gym), making it easy to adapt them to many off-the-shelf reinforcement learning algorithms and implementations.
12 |
13 | [*TStarBot1*](https://arxiv.org/pdf/1809.07193.pdf), a reinforcement learning agent, is also released, together with two off-the-shelf reinforcement learning algorithms, *Dueling Double Deep Q Network* (DDQN) and *Proximal Policy Optimization* (PPO), as examples.
14 | **Distributed** versions of both algorithms are released, enabling rollout experience collection to scale across thousands of CPU cores on a cluster of machines.
15 | *TStarBot1* is able to beat **level-9** built-in AI (cheating resources) with **97%** win-rate and **level-10** (cheating insane) with **81%** win-rate.
16 |
17 | A whitepaper on *TStarBots* is available [here](https://arxiv.org/pdf/1809.07193.pdf).
18 |
19 | ## Table of Contents
20 | - [Installations](#installations)
21 | - [Getting Started](#getting-started)
22 | - [Run Random Agent](#run-random-agent)
23 | - [Train PPO Agent](#train-ppo-agent)
24 | - [Evaluate PPO Agent](#evaluate-ppo-agent)
25 |   - [Play vs. PPO Agent](#play-vs-ppo-agent)
26 | - [Train via Self-play](#train-via-self-play)
27 | - [Environments and Wrappers](#environments-and-wrappers)
28 | - [Questions and Help](#questions-and-help)
29 |
30 |
31 | ## Installations
32 |
33 | ### Prerequisites
34 | - Python >= 3.5 required.
35 | - [PySC2 Extension](https://github.com/Tencent/PySC2TencentExtension) required.
36 |
37 | ### Setup
38 | Git clone this repository and then install it with
39 | ```bash
40 | pip3 install -e sc2learner
41 | ```
42 |
43 | ## Getting Started
44 |
45 | ### Run Random Agent
46 | Run a random agent playing against a built-in AI of difficulty level 1.
47 | ```bash
48 | python3 -m sc2learner.bin.evaluate --agent random --difficulty '1'
49 | ```
50 |
51 | ### Train PPO Agent
52 |
53 | To train an agent with the PPO algorithm, the actor workers and the learner worker must be started separately.
54 | They can run either locally or across separate machines (e.g. actors usually run on a CPU cluster consisting of hundreds of machines with tens of thousands of CPU cores, while the learner runs on a GPU machine).
55 | Rollout trajectories and model parameters are exchanged between the actors and the learner through the designated ports and the learner's IP address.
56 | - Start 48 actor workers (run the same script in all actor machines)
57 | ```bash
58 | for i in $(seq 0 47); do
59 | CUDA_VISIBLE_DEVICES= python3 -m sc2learner.bin.train_ppo --job_name=actor --learner_ip localhost &
60 | done;
61 | ```
62 |
63 | - Start a learner worker
64 | ```bash
65 | CUDA_VISIBLE_DEVICES=0 python3 -m sc2learner.bin.train_ppo --job_name learner
66 | ```
67 |
68 | Similarly, the DQN algorithm can be tried with `sc2learner.bin.train_dqn`.
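Note that the DQN trainer additionally relies on a distributed replay memory, so besides `--learner_ip` it takes a `--ports` flag with three ports (default `5700,5701,5702`); see the flag definitions in `sc2learner/bin/train_dqn.py` for the full list.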
69 |
70 | ### Evaluate PPO Agent
71 | After training, the agent's in-game performance can be observed by letting it play the game against a built-in AI of a certain difficulty level.
72 | The win-rate is also estimated over multiple such games, each initialized with a different game seed.
73 | ```bash
74 | python3 -m sc2learner.bin.evaluate --agent ppo --difficulty 1 --model_path REPLACE_WITH_YOUR_OWN_MODEL_PATH
75 | ```
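The reported average win-rate is derived from the average episode return as `(avg_return + 1) / 2` (see `sc2learner/bin/evaluate.py`), which assumes the outcome reward is +1 for a win and -1 for a loss.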
76 |
77 |
78 | ### Play vs. PPO Agent
79 | You can also play against the learned agent yourself by first starting a human player client and then the agent.
80 | They can run either locally or remotely.
81 | When they run on two different machines, the `--remote` argument needs to be set on the human player side to create an SSH tunnel to the remote agent's machine, and SSH keys must be used for authentication.
82 |
83 | - Start a human player client
84 | ```bash
85 | CUDA_VISIBLE_DEVICES= python3 -m pysc2.bin.play_vs_agent --human --map AbyssalReef --user_race zerg
86 | ```
87 |
88 | - Start a PPO agent
89 | ```bash
90 | python3 -m sc2learner.bin.play_vs_ppo_agent --model_path REPLACE_WITH_YOUR_OWN_MODEL_PATH
91 | ```
92 |
93 | ### Train via Self-play
94 |
95 | In addition, self-play training (playing against past versions of the agent) is also provided, making it possible to learn more diversified strategies.
96 |
97 | - Start Actors
98 | ```bash
99 | for i in $(seq 0 48); do
100 | CUDA_VISIBLE_DEVICES= python3 -m sc2learner.bin.train_ppo_selfplay --job_name=actor --learner_ip localhost &
101 | done;
102 | ```
103 |
104 | - Start Learner
105 | ```bash
106 | CUDA_VISIBLE_DEVICES=0 python3 -m sc2learner.bin.train_ppo_selfplay --job_name learner
107 | ```
108 |
109 | ## Environments and Wrappers
110 |
111 | The environments and wrappers strictly follow the interface of [OpenAI Gym](https://github.com/openai/gym).
112 | The macro action space is defined in [`ZergActionWrapper`](https://github.com/Tencent/TStarBot1/blob/dev-open/sc2learner/envs/actions/zerg_action_wrappers.py#L26) and the observation space is defined in [`ZergObservationWrapper`](https://github.com/Tencent/TStarBot1/blob/dev-open/sc2learner/envs/observations/zerg_observation_wrappers.py#L24); users can easily modify them and restart the training to see the effect.
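As a quick illustration of the Gym-style interface, here is a minimal sketch that composes the same wrappers as `sc2learner/bin/evaluate.py` and runs one episode with the bundled random agent; the constructor arguments mirror the defaults used there and may need adjusting (e.g. `game_version`, `difficulty`) for your setup.

```python
from sc2learner.agents.random_agent import RandomAgent
from sc2learner.envs.raw_env import SC2RawEnv
from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
from sc2learner.envs.observations.zerg_observation_wrappers import ZergObservationWrapper

# Compose the environment the same way create_env() does in sc2learner/bin/evaluate.py.
env = SC2RawEnv(map_name='AbyssalReef',
                step_mul=32,
                agent_race='zerg',
                bot_race='zerg',
                difficulty='1',
                disable_fog=False,
                random_seed=0)
env = ZergActionWrapper(env,
                        game_version='4.6',
                        mask=False,
                        use_all_combat_actions=False)
env = ZergObservationWrapper(env,
                             use_spatial_features=False,
                             use_game_progress=True,
                             action_seq_len=8,
                             use_regions=False)

# Standard Gym-style interaction loop driven by the bundled random agent.
agent = RandomAgent(action_space=env.action_space)
observation = env.reset()
agent.reset()
done = False
while not done:
    action = agent.act(observation)
    observation, reward, done, _ = env.step(action)
env.close()
```

As in `evaluate.py`, an agent only needs to expose `act(observation)` and `reset()`, so the random agent above can be swapped for any custom policy.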
113 |
114 | ## Questions and Help
115 | You are welcome to submit questions and bug reports in [GitHub Issues](https://github.com/Tencent/TStarBot1/issues).
116 | You are also welcome to contribute to this project.
117 |
--------------------------------------------------------------------------------
/sc2learner/envs/common/const.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from enum import Enum
6 | from enum import unique
7 | from collections import namedtuple
8 |
9 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
10 |
11 |
12 | @unique
13 | class ALLY_TYPE(Enum):
14 | SELF = 1
15 | ALLY = 2
16 | NEUTRAL = 3
17 | ENEMY = 4
18 |
19 |
20 | @unique
21 | class PLAYER_FEATURE(Enum):
22 | PLAYER_ID = 0
23 | MINERALS = 1
24 | VESPENE = 2
25 | FOOD_USED = 3
26 | FOOD_CAP = 4
27 | FOOD_ARMY = 5
28 | FOOD_WORKER = 6
29 | IDLE_WORKER_COUNT = 7
30 | ARMY_COUNT = 8
31 | WARP_GATE_COUNT = 9
32 | LARVA_COUNT = 10
33 |
34 |
35 | NEUTRAL_DESTRUCTABLEROCKEX14X4 = 638
36 | NEUTRAL_UNBUILDABLEROCKSDESTRUCIBLE = 472
37 |
38 |
39 | PLACE_COLLISION_BUILDINGS = {
40 | UNIT_TYPE.NEUTRAL_MINERALFIELD.value,
41 | UNIT_TYPE.NEUTRAL_MINERALFIELD750.value,
42 | UNIT_TYPE.NEUTRAL_VESPENEGEYSER.value,
43 | UNIT_TYPE.NEUTRAL_DESTRUCTIBLEROCK6X6.value,
44 | UNIT_TYPE.NEUTRAL_DESTRUCTIBLEROCKEX1DIAGONALHUGEBLUR.value,
45 | NEUTRAL_DESTRUCTABLEROCKEX14X4,
46 | NEUTRAL_UNBUILDABLEROCKSDESTRUCIBLE,
47 | UNIT_TYPE.ZERG_EXTRACTOR.value,
48 | UNIT_TYPE.ZERG_SPAWNINGPOOL.value,
49 | UNIT_TYPE.ZERG_ROACHWARREN.value,
50 | UNIT_TYPE.ZERG_HYDRALISKDEN.value,
51 | UNIT_TYPE.ZERG_HATCHERY.value,
52 | UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value,
53 | UNIT_TYPE.ZERG_BANELINGNEST.value,
54 | UNIT_TYPE.ZERG_INFESTATIONPIT.value,
55 | UNIT_TYPE.ZERG_SPIRE.value,
56 | UNIT_TYPE.ZERG_ULTRALISKCAVERN.value,
57 | UNIT_TYPE.ZERG_NYDUSNETWORK.value,
58 | UNIT_TYPE.ZERG_SPINECRAWLER.value,
59 | UNIT_TYPE.ZERG_SPORECRAWLER.value,
60 | UNIT_TYPE.ZERG_LURKERDENMP.value,
61 | UNIT_TYPE.ZERG_LAIR.value,
62 | UNIT_TYPE.ZERG_HIVE.value,
63 | UNIT_TYPE.ZERG_GREATERSPIRE.value
64 | }
65 |
66 |
67 | MAXIMUM_NUM = {
68 | UNIT_TYPE.ZERG_SPAWNINGPOOL.value: 1,
69 | UNIT_TYPE.ZERG_ROACHWARREN.value: 1,
70 | UNIT_TYPE.ZERG_HYDRALISKDEN.value: 1,
71 | UNIT_TYPE.ZERG_HATCHERY.value: 6,
72 | UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value: 2,
73 | UNIT_TYPE.ZERG_BANELINGNEST.value: 1,
74 | UNIT_TYPE.ZERG_INFESTATIONPIT.value: 1,
75 | UNIT_TYPE.ZERG_SPIRE.value: 1,
76 | UNIT_TYPE.ZERG_ULTRALISKCAVERN.value: 1,
77 | UNIT_TYPE.ZERG_NYDUSNETWORK.value: 1,
78 | UNIT_TYPE.ZERG_LURKERDENMP.value: 1,
79 | UNIT_TYPE.ZERG_LAIR.value: 1,
80 | UNIT_TYPE.ZERG_HIVE.value: 1,
81 | UNIT_TYPE.ZERG_GREATERSPIRE.value: 1
82 | }
83 |
84 |
85 | AttackAttr = namedtuple('AttackAttr', ('can_attack_ground', 'can_attack_air'))
86 |
87 |
88 | ATTACK_FORCE = {
89 | UNIT_TYPE.ZERG_LARVA.value: AttackAttr(False, False),
90 | UNIT_TYPE.ZERG_DRONE.value: AttackAttr(True, False),
91 | UNIT_TYPE.ZERG_ZERGLING.value: AttackAttr(True, False),
92 | UNIT_TYPE.ZERG_BANELING.value: AttackAttr(True, False),
93 | UNIT_TYPE.ZERG_ROACH.value: AttackAttr(True, False),
94 | UNIT_TYPE.ZERG_ROACHBURROWED.value: AttackAttr(True, False),
95 | UNIT_TYPE.ZERG_RAVAGER.value: AttackAttr(True, False),
96 | UNIT_TYPE.ZERG_HYDRALISK.value: AttackAttr(True, True),
97 | UNIT_TYPE.ZERG_LURKERMP.value: AttackAttr(True, False),
98 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value: AttackAttr(True, False),
99 | UNIT_TYPE.ZERG_VIPER.value: AttackAttr(False, False),
100 | UNIT_TYPE.ZERG_MUTALISK.value: AttackAttr(True, True),
101 | UNIT_TYPE.ZERG_CORRUPTOR.value: AttackAttr(False, True),
102 | UNIT_TYPE.ZERG_BROODLORD.value: AttackAttr(True, False),
103 | UNIT_TYPE.ZERG_SWARMHOSTMP.value: AttackAttr(False, False),
104 | UNIT_TYPE.ZERG_LOCUSTMP.value: AttackAttr(True, False),
105 | UNIT_TYPE.ZERG_INFESTOR.value: AttackAttr(False, False),
106 | UNIT_TYPE.ZERG_ULTRALISK.value: AttackAttr(True, False),
107 | UNIT_TYPE.ZERG_BROODLING.value: AttackAttr(True, False),
108 | UNIT_TYPE.ZERG_OVERLORD.value: AttackAttr(False, False),
109 | UNIT_TYPE.ZERG_OVERSEER.value: AttackAttr(False, False),
110 | UNIT_TYPE.ZERG_QUEEN.value: AttackAttr(True, True),
111 | UNIT_TYPE.ZERG_CHANGELING.value: AttackAttr(False, False),
112 | UNIT_TYPE.ZERG_SPINECRAWLER.value: AttackAttr(True, False),
113 | UNIT_TYPE.ZERG_SPORECRAWLER.value: AttackAttr(False, True),
114 | UNIT_TYPE.ZERG_NYDUSCANAL.value: AttackAttr(False, False)
115 | }
116 |
117 |
118 | PRIORITIZED_ATTACK = set([
119 | UNIT_TYPE.ZERG_ZERGLING.value,
120 | UNIT_TYPE.ZERG_BANELING.value,
121 | UNIT_TYPE.ZERG_ROACH.value,
122 | UNIT_TYPE.ZERG_ROACHBURROWED.value,
123 | UNIT_TYPE.ZERG_RAVAGER.value,
124 | UNIT_TYPE.ZERG_HYDRALISK.value,
125 | UNIT_TYPE.ZERG_LURKERMP.value,
126 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value,
127 | UNIT_TYPE.ZERG_VIPER.value,
128 | UNIT_TYPE.ZERG_MUTALISK.value,
129 | UNIT_TYPE.ZERG_CORRUPTOR.value,
130 | UNIT_TYPE.ZERG_BROODLORD.value,
131 | UNIT_TYPE.ZERG_SWARMHOSTMP.value,
132 | UNIT_TYPE.ZERG_LOCUSTMP.value,
133 | UNIT_TYPE.ZERG_INFESTOR.value,
134 | UNIT_TYPE.ZERG_ULTRALISK.value,
135 | UNIT_TYPE.ZERG_BROODLING.value,
136 | UNIT_TYPE.ZERG_QUEEN.value,
137 | UNIT_TYPE.ZERG_CHANGELING.value,
138 | UNIT_TYPE.ZERG_SPINECRAWLER.value,
139 | UNIT_TYPE.ZERG_SPORECRAWLER.value
140 | ])
141 |
142 |
143 | COMBAT_TYPES = {
144 | UNIT_TYPE.ZERG_ZERGLING.value,
145 | UNIT_TYPE.ZERG_BANELING.value,
146 | UNIT_TYPE.ZERG_ROACH.value,
147 | UNIT_TYPE.ZERG_ROACHBURROWED.value,
148 | UNIT_TYPE.ZERG_RAVAGER.value,
149 | UNIT_TYPE.ZERG_HYDRALISK.value,
150 | UNIT_TYPE.ZERG_LURKERMP.value,
151 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value,
152 | UNIT_TYPE.ZERG_MUTALISK.value,
153 | UNIT_TYPE.ZERG_CORRUPTOR.value,
154 | UNIT_TYPE.ZERG_BROODLORD.value,
155 | UNIT_TYPE.ZERG_ULTRALISK.value
156 | #UNIT_TYPE.ZERG_LOCUSTMP.value,
157 | #UNIT_TYPE.ZERG_BROODLING.value
158 | }
159 |
160 |
161 | class MAP(object):
162 | WIDTH = 200.0
163 | HEIGHT = 176.0
164 | LEFT = 24.0
165 | RIGHT = 24.0
166 | TOP = 37.0
167 | BOTTOM = 4.0
168 |
--------------------------------------------------------------------------------
/sc2learner/envs/rewards/reward_wrappers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import gym
7 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
8 |
9 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation
10 | from sc2learner.envs.common.const import ALLY_TYPE
11 |
12 |
13 | class RewardShapingWrapperV1(gym.Wrapper):
14 |
15 | def __init__(self, env):
16 | super(RewardShapingWrapperV1, self).__init__(env)
17 | assert isinstance(env.observation_space, PySC2RawObservation)
18 | self._combat_unit_types = set([UNIT_TYPE.ZERG_ZERGLING.value,
19 | UNIT_TYPE.ZERG_ROACH.value,
20 | UNIT_TYPE.ZERG_HYDRALISK.value])
21 | self.reward_range = (-np.inf, np.inf)
22 |
23 | def step(self, action):
24 | observation, outcome, done, info = self.env.step(action)
25 | n_enemies, n_self_combats = self._get_unit_counts(observation)
26 | if n_self_combats - n_enemies > self._n_self_combats - self._n_enemies:
27 | reward = 1
28 | elif n_self_combats - n_enemies < self._n_self_combats - self._n_enemies:
29 | reward = -1
30 | else:
31 | reward = 0
32 | if not done: reward += outcome * 10
33 | else: reward = outcome * 10
34 | self._n_enemies = n_enemies
35 | self._n_self_combats = n_self_combats
36 | return observation, reward, done, info
37 |
38 | def reset(self, **kwargs):
39 | observation = self.env.reset()
40 | self._n_enemies, self._n_self_combats = self._get_unit_counts(observation)
41 | return observation
42 |
43 | @property
44 | def action_names(self):
45 | if not hasattr(self.env, 'action_names'): raise NotImplementedError
46 | return self.env.action_names
47 |
48 | @property
49 | def player_position(self):
50 | if not hasattr(self.env, 'player_position'): raise NotImplementedError
51 | return self.env.player_position
52 |
53 | def _get_unit_counts(self, observation):
54 | num_enemy_units, num_self_combat_units = 0, 0
55 | for u in observation['units']:
56 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value:
57 | num_enemy_units += 1
58 | elif u.int_attr.alliance == ALLY_TYPE.SELF.value:
59 | if u.unit_type in self._combat_unit_types:
60 | num_self_combat_units += 1
61 | return num_enemy_units, num_self_combat_units
62 |
63 |
64 | class RewardShapingWrapperV2(gym.Wrapper):
65 |
66 | def __init__(self, env):
67 | super(RewardShapingWrapperV2, self).__init__(env)
68 | assert isinstance(env.observation_space, PySC2RawObservation)
69 | self._combat_unit_types = set([UNIT_TYPE.ZERG_ZERGLING.value,
70 | UNIT_TYPE.ZERG_ROACH.value,
71 | UNIT_TYPE.ZERG_HYDRALISK.value,
72 | UNIT_TYPE.ZERG_RAVAGER.value,
73 | UNIT_TYPE.ZERG_BANELING.value,
74 | UNIT_TYPE.ZERG_BROODLING.value])
75 | self.reward_range = (-np.inf, np.inf)
76 |
77 | def step(self, action):
78 | observation, reward, done, info = self.env.step(action)
79 | n_enemies, n_selves = self._get_unit_counts(observation)
80 | diff_selves = n_selves - self._n_selves
81 | diff_enemies = n_enemies - self._n_enemies
82 | if not done: reward += (diff_selves - diff_enemies) * 0.02
83 | self._n_enemies = n_enemies
84 | self._n_selves = n_selves
85 | return observation, reward, done, info
86 |
87 | def reset(self, **kwargs):
88 | observation = self.env.reset()
89 | self._n_enemies, self._n_selves = self._get_unit_counts(observation)
90 | return observation
91 |
92 | @property
93 | def action_names(self):
94 | if not hasattr(self.env, 'action_names'): raise NotImplementedError
95 | return self.env.action_names
96 |
97 | @property
98 | def player_position(self):
99 | if not hasattr(self.env, 'player_position'): raise NotImplementedError
100 | return self.env.player_position
101 |
102 | def _get_unit_counts(self, observation):
103 | num_enemy_units, num_self_units = 0, 0
104 | for u in observation['units']:
105 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value:
106 | if u.unit_type in self._combat_unit_types:
107 | num_enemy_units += 1
108 | elif u.int_attr.alliance == ALLY_TYPE.SELF.value:
109 | if u.unit_type in self._combat_unit_types:
110 | num_self_units += 1
111 | return num_enemy_units, num_self_units
112 |
113 |
114 | class KillingRewardWrapper(gym.Wrapper):
115 |
116 | def __init__(self, env):
117 | super(KillingRewardWrapper, self).__init__(env)
118 | assert isinstance(env.observation_space, PySC2RawObservation)
119 | self.reward_range = (-np.inf, np.inf)
120 | self._last_kill_value = 0
121 |
122 | def step(self, action):
123 | observation, reward, done, info = self.env.step(action)
124 | kill_value = observation.score_cumulative[5] + \
125 | observation.score_cumulative[6]
126 | if not done:
127 | reward += (kill_value - self._last_kill_value) * 1e-5
128 | self._last_kill_value = kill_value
129 | return observation, reward, done, info
130 |
131 | def reset(self):
132 | observation = self.env.reset()
133 | kill_value = observation.score_cumulative[5] + \
134 | observation.score_cumulative[6]
135 | self._last_kill_value = kill_value
136 | return observation
137 |
138 | @property
139 | def action_names(self):
140 | if not hasattr(self.env, 'action_names'): raise NotImplementedError
141 | return self.env.action_names
142 |
143 | @property
144 | def player_position(self):
145 | if not hasattr(self.env, 'player_position'): raise NotImplementedError
146 | return self.env.player_position
147 |
--------------------------------------------------------------------------------
/sc2learner/bin/train_dqn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import sys
6 | import os
7 | import random
8 | import time
9 |
10 | import torch
11 | from absl import app
12 | from absl import flags
13 | from absl import logging
14 |
15 | from sc2learner.envs.raw_env import SC2RawEnv
16 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
17 | from sc2learner.envs.observations.zerg_observation_wrappers \
18 | import ZergObservationWrapper
19 | from sc2learner.agents.dqn_agent import DQNActor
20 | from sc2learner.agents.dqn_agent import DQNLearner
21 | from sc2learner.agents.dqn_networks import NonspatialDuelingQNet
22 | from sc2learner.utils.utils import print_arguments
23 |
24 |
25 | FLAGS = flags.FLAGS
26 | flags.DEFINE_enum("job_name", 'actor', ['actor', 'learner'], "Job type.")
27 | flags.DEFINE_string("learner_ip", "localhost", "Learner IP address.")
28 | flags.DEFINE_string("ports", "5700,5701,5702",
29 | "3 ports for distributed replay memory.")
30 | flags.DEFINE_integer("client_memory_size", 50000,
31 | "Total size of client memory.")
32 | flags.DEFINE_integer("client_memory_warmup_size", 2000,
33 | "Memory warmup size for client.")
34 | flags.DEFINE_integer("server_memory_size", 1000000,
35 | "Total size of server memory.")
36 | flags.DEFINE_integer("server_memory_warmup_size", 100000,
37 |                      "Memory warmup size for server.")
38 | flags.DEFINE_string("game_version", '4.6', "Game core version.")
39 | flags.DEFINE_float("discount", 0.995, "Discount factor.")
40 | flags.DEFINE_float("send_freq", 4.0, "Frequency of sending experiences to the learner.")
41 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.")
42 | flags.DEFINE_string("difficulties", '1,2,4,6,9,A', "Bot's strengths.")
43 | flags.DEFINE_float("eps_start", 1.0, "Max greedy epsilon for exploration.")
44 | flags.DEFINE_float("eps_end", 0.1, "Min greedy epsilon for exploration.")
45 | flags.DEFINE_integer("eps_decay_steps", 1000000, "Greedy epsilon decay step.")
46 | flags.DEFINE_integer("eps_decay_steps2", 10000000, "Greedy epsilon decay step.")
47 | flags.DEFINE_float("learning_rate", 1e-6, "Learning rate.")
48 | flags.DEFINE_float("adam_eps", 1e-7, "Adam optimizer's epsilon.")
49 | flags.DEFINE_float("gradient_clipping", 10.0, "Gradient clipping threshold.")
50 | flags.DEFINE_integer("batch_size", 256, "Batch size.")
51 | flags.DEFINE_float("mmc_beta", 0.9, "Mixed Monte-Carlo return coefficient.")
52 | flags.DEFINE_integer("target_update_interval", 10000,
53 | "Target net update interval.")
54 | flags.DEFINE_string("init_model_path", None, "Checkpoint to initialize model.")
55 | flags.DEFINE_string("checkpoint_dir", "./checkpoints", "Dir to save models to")
56 | flags.DEFINE_integer("checkpoint_interval", 500000, "Model saving frequency.")
57 | flags.DEFINE_integer("print_interval", 10000, "Print train cost frequency.")
58 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.")
59 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.")
60 | flags.DEFINE_boolean("use_region_features", True, "Use region features")
61 | flags.FLAGS(sys.argv)
62 |
63 |
64 | def create_env(difficulty, random_seed=None):
65 | env = SC2RawEnv(map_name='AbyssalReef',
66 | step_mul=FLAGS.step_mul,
67 | resolution=16,
68 | agent_race='zerg',
69 | bot_race='zerg',
70 | difficulty=difficulty,
71 | disable_fog=FLAGS.disable_fog,
72 | random_seed=random_seed)
73 | env = ZergActionWrapper(env,
74 | game_version=FLAGS.game_version,
75 | mask=False,
76 | use_all_combat_actions=FLAGS.use_all_combat_actions)
77 | env = ZergObservationWrapper(env,
78 | use_spatial_features=False,
79 | use_regions=FLAGS.use_region_features)
80 | return env
81 |
82 |
83 | def create_network(env):
84 | return NonspatialDuelingQNet(n_dims=env.observation_space.shape[0],
85 | n_out=env.action_space.n)
86 |
87 |
88 | def start_actor_job():
89 | random.seed(time.time())
90 | difficulty = random.choice(FLAGS.difficulties.split(','))
91 | game_seed = random.randint(0, 2**32 - 1)
92 | print("Game Seed: %d Difficulty: %s" % (game_seed, difficulty))
93 | env = create_env(difficulty, game_seed)
94 | network = create_network(env)
95 | actor = DQNActor(memory_size=FLAGS.client_memory_size,
96 | memory_warmup_size=FLAGS.client_memory_warmup_size,
97 | env=env,
98 | network=network,
99 | discount=FLAGS.discount,
100 | send_freq=FLAGS.send_freq,
101 | ports=FLAGS.ports.split(','),
102 | learner_ip=FLAGS.learner_ip)
103 | actor.run()
104 | env.close()
105 |
106 |
107 | def start_learner_job():
108 | if not os.path.exists(FLAGS.checkpoint_dir):
109 | os.makedirs(FLAGS.checkpoint_dir)
110 |
111 | env = create_env('1', 0)
112 | network = create_network(env)
113 | learner = DQNLearner(network=network,
114 | action_space=env.action_space,
115 | memory_size=FLAGS.server_memory_size,
116 | memory_warmup_size=FLAGS.server_memory_warmup_size,
117 | discount=FLAGS.discount,
118 | eps_start=FLAGS.eps_start,
119 | eps_end=FLAGS.eps_end,
120 | eps_decay_steps=FLAGS.eps_decay_steps,
121 | eps_decay_steps2=FLAGS.eps_decay_steps2,
122 | batch_size=FLAGS.batch_size,
123 | mmc_beta=FLAGS.mmc_beta,
124 | gradient_clipping=FLAGS.gradient_clipping,
125 | adam_eps=FLAGS.adam_eps,
126 | learning_rate=FLAGS.learning_rate,
127 | target_update_interval=FLAGS.target_update_interval,
128 | checkpoint_dir=FLAGS.checkpoint_dir,
129 | checkpoint_interval=FLAGS.checkpoint_interval,
130 | print_interval=FLAGS.print_interval,
131 | ports=FLAGS.ports.split(','),
132 | init_model_path=FLAGS.init_model_path)
133 | learner.run()
134 | env.close()
135 |
136 |
137 | def main(argv):
138 | logging.set_verbosity(logging.ERROR)
139 | print_arguments(FLAGS)
140 | if FLAGS.job_name == 'actor': start_actor_job()
141 | else: start_learner_job()
142 |
143 |
144 | if __name__ == '__main__':
145 | app.run(main)
146 |
--------------------------------------------------------------------------------
/sc2learner/bin/train_ppo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import sys
6 | from threading import Thread
7 | import os
8 | import multiprocessing
9 | import random
10 | import time
11 |
12 | from absl import app
13 | from absl import flags
14 | from absl import logging
15 | import tensorflow as tf
16 |
17 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy
18 | from sc2learner.agents.ppo_agent import PPOActor, PPOLearner
19 | from sc2learner.envs.raw_env import SC2RawEnv
20 | from sc2learner.envs.rewards.reward_wrappers import KillingRewardWrapper
21 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
22 | from sc2learner.envs.observations.zerg_observation_wrappers \
23 | import ZergObservationWrapper
24 | from sc2learner.utils.utils import print_arguments
25 |
26 |
27 | FLAGS = flags.FLAGS
28 | flags.DEFINE_enum("job_name", 'actor', ['actor', 'learner'], "Job type.")
29 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Policy type.")
30 | flags.DEFINE_integer("unroll_length", 128, "Length of rollout steps.")
31 | flags.DEFINE_string("learner_ip", "localhost", "Learner IP address.")
32 | flags.DEFINE_string("port_A", "5700", "Port for transporting model.")
33 | flags.DEFINE_string("port_B", "5701", "Port for transporting data.")
34 | flags.DEFINE_string("game_version", '4.6', "Game core version.")
35 | flags.DEFINE_float("discount_gamma", 0.998, "Discount factor.")
36 | flags.DEFINE_float("lambda_return", 0.95, "Lambda return factor.")
37 | flags.DEFINE_float("clip_range", 0.1, "Clip range for PPO.")
38 | flags.DEFINE_float("ent_coef", 0.01, "Coefficient for the entropy term.")
39 | flags.DEFINE_float("vf_coef", 0.5, "Coefficient for the value loss.")
40 | flags.DEFINE_float("learn_act_speed_ratio", 0, "Maximum learner/actor ratio.")
41 | flags.DEFINE_integer("batch_size", 32, "Batch size.")
42 | flags.DEFINE_integer("game_steps_per_episode", 43200, "Maximum steps per episode.")
43 | flags.DEFINE_integer("learner_queue_size", 1024, "Size of learner's unroll queue.")
44 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.")
45 | flags.DEFINE_string("difficulties", '1,2,4,6,9,A', "Bot's strengths.")
46 | flags.DEFINE_float("learning_rate", 1e-5, "Learning rate.")
47 | flags.DEFINE_string("init_model_path", None, "Initial model path.")
48 | flags.DEFINE_string("save_dir", "./checkpoints/", "Dir to save models to")
49 | flags.DEFINE_integer("save_interval", 50000, "Model saving frequency.")
50 | flags.DEFINE_integer("print_interval", 1000, "Print train cost frequency.")
51 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.")
52 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.")
53 | flags.DEFINE_boolean("use_region_features", False, "Use region features")
54 | flags.DEFINE_boolean("use_action_mask", True, "Use action mask or not.")
55 | flags.DEFINE_boolean("use_reward_shaping", False, "Use reward shaping.")
56 | flags.FLAGS(sys.argv)
57 |
58 |
59 | def tf_config(ncpu=None):
60 | if ncpu is None:
61 | ncpu = multiprocessing.cpu_count()
62 | if sys.platform == 'darwin': ncpu //= 2
63 | config = tf.ConfigProto(allow_soft_placement=True,
64 | intra_op_parallelism_threads=ncpu,
65 | inter_op_parallelism_threads=ncpu)
66 | config.gpu_options.allow_growth = True
67 | tf.Session(config=config).__enter__()
68 |
69 |
70 | def create_env(difficulty, random_seed=None):
71 | env = SC2RawEnv(map_name='AbyssalReef',
72 | step_mul=FLAGS.step_mul,
73 | resolution=16,
74 | agent_race='zerg',
75 | bot_race='zerg',
76 | difficulty=difficulty,
77 | disable_fog=FLAGS.disable_fog,
78 | tie_to_lose=False,
79 | game_steps_per_episode=FLAGS.game_steps_per_episode,
80 | random_seed=random_seed)
81 | if FLAGS.use_reward_shaping: env = KillingRewardWrapper(env)
82 | env = ZergActionWrapper(env,
83 | game_version=FLAGS.game_version,
84 | mask=FLAGS.use_action_mask,
85 | use_all_combat_actions=FLAGS.use_all_combat_actions)
86 | env = ZergObservationWrapper(env,
87 | use_spatial_features=False,
88 | use_game_progress=(not FLAGS.policy == 'lstm'),
89 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8,
90 | use_regions=FLAGS.use_region_features)
91 | print(env.observation_space, env.action_space)
92 | return env
93 |
94 |
95 | def start_actor():
96 | tf_config(ncpu=2)
97 | random.seed(time.time())
98 | difficulty = random.choice(FLAGS.difficulties.split(','))
99 | game_seed = random.randint(0, 2**32 - 1)
100 | print("Game Seed: %d Difficulty: %s" % (game_seed, difficulty))
101 | env = create_env(difficulty, game_seed)
102 | policy = {'lstm': LstmPolicy,
103 | 'mlp': MlpPolicy}[FLAGS.policy]
104 | actor = PPOActor(env=env,
105 | policy=policy,
106 | unroll_length=FLAGS.unroll_length,
107 | gamma=FLAGS.discount_gamma,
108 | lam=FLAGS.lambda_return,
109 | learner_ip=FLAGS.learner_ip,
110 | port_A=FLAGS.port_A,
111 | port_B=FLAGS.port_B)
112 | actor.run()
113 | env.close()
114 |
115 |
116 | def start_learner():
117 | tf_config()
118 | env = create_env('1', 0)
119 | policy = {'lstm': LstmPolicy,
120 | 'mlp': MlpPolicy}[FLAGS.policy]
121 | learner = PPOLearner(env=env,
122 | policy=policy,
123 | unroll_length=FLAGS.unroll_length,
124 | lr=FLAGS.learning_rate,
125 | clip_range=FLAGS.clip_range,
126 | batch_size=FLAGS.batch_size,
127 | ent_coef=FLAGS.ent_coef,
128 | vf_coef=FLAGS.vf_coef,
129 | max_grad_norm=0.5,
130 | queue_size=FLAGS.learner_queue_size,
131 | print_interval=FLAGS.print_interval,
132 | save_interval=FLAGS.save_interval,
133 | learn_act_speed_ratio=FLAGS.learn_act_speed_ratio,
134 | save_dir=FLAGS.save_dir,
135 | init_model_path=FLAGS.init_model_path,
136 | port_A=FLAGS.port_A,
137 | port_B=FLAGS.port_B)
138 | learner.run()
139 | env.close()
140 |
141 |
142 | def main(argv):
143 | logging.set_verbosity(logging.ERROR)
144 | print_arguments(FLAGS)
145 | if FLAGS.job_name == 'actor': start_actor()
146 | else: start_learner()
147 |
148 |
149 | if __name__ == '__main__':
150 | app.run(main)
151 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/placer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 | import math
7 |
8 | import numpy as np
9 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
10 |
11 | import sc2learner.envs.common.utils as utils
12 | from sc2learner.envs.common.const import PLACE_COLLISION_BUILDINGS
13 |
14 |
15 | class Placer(object):
16 |
17 | def get_building_position(self, type_id, dc):
18 | if type_id == UNIT_TYPE.ZERG_HATCHERY.value:
19 | return self._next_base_place(dc)
20 | elif type_id == UNIT_TYPE.ZERG_EXTRACTOR.value:
21 | gas = dc.exploitable_gas
22 | return random.choice(gas) if len(gas) > 0 else None
23 | else:
24 | place = self._constructable_place(1.5, dc)
25 | return random.choice(place) if len(place) > 0 else None
26 |
27 | def can_build(self, type_id, dc):
28 | if type_id == UNIT_TYPE.ZERG_HATCHERY.value:
29 | return self._next_base_place(dc) is not None
30 | elif type_id == UNIT_TYPE.ZERG_EXTRACTOR.value:
31 | return len(dc.exploitable_gas) > 0
32 | else:
33 | place = self._constructable_place(1.5, dc)
34 | return len(place) > 0
35 |
36 | def _constructable_place(self, margin, dc):
37 | place = []
38 | bases = dc.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value,
39 | UNIT_TYPE.ZERG_LAIR.value,
40 | UNIT_TYPE.ZERG_HIVE.value])
41 | for base in bases:
42 | search_region = (base.float_attr.pos_x - 10.5,
43 | base.float_attr.pos_y - 10.5,
44 | 10.5 * 2,
45 | 10.5 * 2)
46 | place.extend(self._search_place(search_region, dc, margin=margin,
47 | remove_corner=True, expand_mineral=True))
48 | return place
49 |
50 | def _next_base_place(self, dc):
51 | unexploited_minerals = dc.unexploited_minerals
52 | if len(unexploited_minerals) == 0: return None
53 | mineral_to_exploit = utils.closest_unit(dc.init_base_pos,
54 | unexploited_minerals)
55 | resources_nearby = utils.units_nearby(mineral_to_exploit,
56 | dc.minerals + dc.gas,
57 | max_distance=14)
58 | x_list = [u.float_attr.pos_x for u in resources_nearby]
59 | y_list = [u.float_attr.pos_y for u in resources_nearby]
60 | x_mean = sum(x_list) / len(x_list)
61 | y_mean = sum(y_list) / len(y_list)
62 | left = int(math.floor(min(x_list)))
63 | right = int(math.ceil(max(x_list)))
64 | bottom = int(math.floor(min(y_list)))
65 | top = int(math.ceil(max(y_list)))
66 | width = right - left + 1
67 | height = top - bottom + 1
68 | x_offset, y_offset = 0, 0
69 | if height - width >= 5:
70 | left_mid = (left, (bottom + top) / 2)
71 | right_mid = (right, (bottom + top) / 2)
72 | if utils.closest_distance(left_mid, resources_nearby) > \
73 | utils.closest_distance(right_mid, resources_nearby):
74 | x_offset = width - height + 1
75 | width = height - 1
76 | elif height - width <= -5:
77 | top_mid = ((left + right) / 2, top)
78 | bottom_mid = ((left + right) / 2, bottom)
79 | if utils.closest_distance(top_mid, resources_nearby) < \
80 | utils.closest_distance(bottom_mid, resources_nearby):
81 | y_offset = height - width + 1
82 | height = width - 1
83 | region = [left + x_offset, bottom + y_offset, width, height]
84 | place = self._search_place(region, dc, margin=5.5, shrink_mineral=True)
85 | return utils.closest_unit((x_mean, y_mean), place) \
86 | if len(place) > 0 else None
87 |
88 | def _search_place(self, search_region, dc, margin=0, remove_corner=False,
89 | expand_mineral=False, shrink_mineral=False):
90 | bottomleft = tuple(map(int, search_region[:2]))
91 | size = tuple(map(int, search_region[2:]))
92 |     grids = np.zeros(size).astype(int)
93 | if remove_corner:
94 | cx, cy = size[0] / 2.0, size[1] / 2.0
95 | r = max(size[0] / 2.0, size[1] / 2.0)
96 | r_sqrt = (r - 0.5) ** 2
97 | for x in range(size[0]):
98 | x_sqrt = (x + 0.5 - cx) ** 2
99 | for y in range(size[1]):
100 | y_sqrt = (y + 0.5 - cy) ** 2
101 | if x_sqrt + y_sqrt > r_sqrt:
102 | grids[x, y] = 1
103 | filter_range = (bottomleft[0] - 10, bottomleft[0] + size[0] + 10,
104 | bottomleft[1] - 10, bottomleft[1] + size[1] + 10)
105 | sitted_units = [u for u in dc.units
106 | if (u.unit_type in PLACE_COLLISION_BUILDINGS and
107 | u.float_attr.pos_x >= filter_range[0] and
108 | u.float_attr.pos_x <= filter_range[1] and
109 | u.float_attr.pos_y >= filter_range[2] and
110 | u.float_attr.pos_y <= filter_range[3])]
111 | sitted_units = [u for u in dc.units
112 | if u.unit_type in PLACE_COLLISION_BUILDINGS]
113 | margin_int = math.floor(margin)
114 | for u in sitted_units:
115 | if u.float_attr.radius <= 1.0:
116 | r_x = r_y = 1.0
117 | if u.float_attr.pos_x % 1 != 0: r_x -= 0.5
118 | if u.float_attr.pos_y % 1 != 0: r_y -= 0.5
119 | else:
120 | r_x = r_y = float(int(u.float_attr.radius))
121 | if u.float_attr.pos_x % 1 != 0: r_x += 0.5
122 | if u.float_attr.pos_y % 1 != 0: r_y += 0.5
123 | if (shrink_mineral and
124 | u.unit_type in {UNIT_TYPE.NEUTRAL_MINERALFIELD.value,
125 | UNIT_TYPE.NEUTRAL_MINERALFIELD750.value}):
126 | if r_x == 1.5 and r_y == 1: r_x = 0.5
127 | elif r_x == 1 and r_y == 1.5: r_y = 0.5
128 | if (expand_mineral and
129 | u.unit_type in {UNIT_TYPE.NEUTRAL_MINERALFIELD.value,
130 | UNIT_TYPE.NEUTRAL_MINERALFIELD750.value}):
131 | r_x += 1
132 | r_y += 1
133 | r_x += margin_int
134 | r_y += margin_int
135 | xl = int(u.float_attr.pos_x - r_x - bottomleft[0])
136 | xr = int(u.float_attr.pos_x + r_x - bottomleft[0])
137 | yu = int(u.float_attr.pos_y - r_y - bottomleft[1])
138 | yd = int(u.float_attr.pos_y + r_y - bottomleft[1])
139 | grids[max(xl, 0) : max(min(xr, size[0]), 0),
140 | max(yu, 0) : max(min(yd, size[1]), 0)] = 1
141 |
142 | slopes = [(76.5, 90.5), (73.5, 86.5), (123.5, 52.5), (126.5, 56.5),
143 | (131.5, 36.5), (68.5, 106.5)]
144 | holes = [(124.5, 34.5), (76.5, 108.5), (154.5, 60.5), (45.5, 82.5)]
145 | r = 2.5
146 | for x, y in slopes + holes:
147 | xl = int(x - r - bottomleft[0])
148 | xr = int(x + r - bottomleft[0])
149 | yu = int(y - r - bottomleft[1])
150 | yd = int(y + r - bottomleft[1])
151 | grids[max(xl, 0) : max(min(xr, size[0]), 0),
152 | max(yu, 0) : max(min(yd, size[1]), 0)] = 1
153 | x, y = np.nonzero(1 - grids)
154 | #if remove_corner == True:
155 | #np.set_printoptions(threshold=np.nan, linewidth=300)
156 | #print(grids)
157 | return list(zip(x + bottomleft[0] + 0.5, y + bottomleft[1] + 0.5))
158 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/resource.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 |
7 | from s2clientprotocol import sc2api_pb2 as sc_pb
8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
9 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY
10 |
11 | from sc2learner.envs.actions.function import Function
12 | import sc2learner.envs.common.utils as utils
13 |
14 |
15 | class ResourceActions(object):
16 |
17 | @property
18 | def action_queens_inject_larva(self):
19 | return Function(name="queens_inject_larva",
20 | function=self._all_idle_queens_inject_larva,
21 | is_valid=self._is_valid_all_idle_queens_inject_larva)
22 |
23 | @property
24 | def action_idle_workers_gather_minerals(self):
25 | return Function(name="idle_workers_gather_minerals",
26 | function=self._all_idle_workers_gather_minerals,
27 | is_valid=self._is_valid_all_idle_workers_gather_minerals)
28 |
29 | @property
30 | def action_assign_workers_gather_gas(self):
31 | return Function(name="assign_workers_gather_gas",
32 | function=self._assign_workers_gather_gas,
33 | is_valid=self._is_valid_assign_workers_gather_gas)
34 |
35 | @property
36 | def action_assign_workers_gather_minerals(self):
37 | return Function(name="assign_workers_gather_minerals",
38 | function=self._assign_workers_gather_minerals,
39 | is_valid=self._is_valid_assign_workers_gather_minerals)
40 |
41 | def _all_idle_queens_inject_larva(self, dc):
42 | injectable_queens = [
43 | # TODO: -->idle_units_of_type
44 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_QUEEN.value)
45 | if u.float_attr.energy >= 25
46 | ]
47 | bases = dc.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value,
48 | UNIT_TYPE.ZERG_LAIR.value,
49 | UNIT_TYPE.ZERG_HIVE.value])
50 | actions = []
51 | for queen in injectable_queens:
52 | action = sc_pb.Action()
53 | action.action_raw.unit_command.unit_tags.append(queen.tag)
54 | action.action_raw.unit_command.ability_id = \
55 | ABILITY.EFFECT_INJECTLARVA.value
56 | base = utils.closest_unit(queen, bases)
57 | action.action_raw.unit_command.target_unit_tag = base.tag
58 | actions.append(action)
59 | return actions
60 |
61 | def _is_valid_all_idle_queens_inject_larva(self, dc):
62 | injectable_queens = [
63 | # TODO: -->idle_units_of_type
64 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_QUEEN.value)
65 | if u.float_attr.energy >= 25
66 | ]
67 | bases = dc.mature_units_of_types([UNIT_TYPE.ZERG_HATCHERY.value,
68 | UNIT_TYPE.ZERG_LAIR.value,
69 | UNIT_TYPE.ZERG_HIVE.value])
70 | if len(bases) > 0 and len(injectable_queens) > 0: return True
71 | else: return False
72 |
73 | def _all_idle_workers_gather_minerals(self, dc):
74 | idle_workers = dc.idle_units_of_type(UNIT_TYPE.ZERG_DRONE.value)
75 | actions = []
76 | for worker in idle_workers:
77 | mineral = utils.closest_unit(worker, dc.minerals)
78 | action = sc_pb.Action()
79 | action.action_raw.unit_command.unit_tags.append(worker.tag)
80 | action.action_raw.unit_command.ability_id = \
81 | ABILITY.HARVEST_GATHER_DRONE.value
82 | action.action_raw.unit_command.target_unit_tag = mineral.tag
83 | actions.append(action)
84 | return actions
85 |
86 | def _is_valid_all_idle_workers_gather_minerals(self, dc):
87 | if (len(dc.idle_units_of_type(UNIT_TYPE.ZERG_DRONE.value)) > 0 and
88 | len(dc.minerals) > 0):
89 | return True
90 | else:
91 | return False
92 |
93 | def _assign_workers_gather_gas(self, dc):
94 | idle_extractors = [
95 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value)
96 | if u.int_attr.ideal_harvesters - u.int_attr.assigned_harvesters > 0
97 | ]
98 | if len(idle_extractors) == 0: return []
99 | extractor = random.choice(idle_extractors)
100 | num_workers_need = extractor.int_attr.ideal_harvesters - \
101 | extractor.int_attr.assigned_harvesters
102 | extractor_tags = set(u.tag for u in dc.units_of_type(
103 | UNIT_TYPE.ZERG_EXTRACTOR.value))
104 | workers = [
105 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value)
106 | if (len(u.orders) == 0 or
107 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and
108 | u.orders[0].target_tag not in extractor_tags))
109 | ]
110 | if len(workers) == 0: return []
111 | assigned_workers = utils.closest_units(extractor, workers, num_workers_need)
112 | action = sc_pb.Action()
113 | action.action_raw.unit_command.unit_tags.extend(
114 | [u.tag for u in assigned_workers])
115 | action.action_raw.unit_command.ability_id = \
116 | ABILITY.HARVEST_GATHER_DRONE.value
117 | action.action_raw.unit_command.target_unit_tag = extractor.tag
118 | return [action]
119 |
120 | def _is_valid_assign_workers_gather_gas(self, dc):
121 | idle_extractors = [
122 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_EXTRACTOR.value)
123 | if u.int_attr.ideal_harvesters - u.int_attr.assigned_harvesters > 0
124 | ]
125 | extractor_tags = set(u.tag for u in dc.units_of_type(
126 | UNIT_TYPE.ZERG_EXTRACTOR.value))
127 | workers = [
128 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value)
129 | if (len(u.orders) == 0 or
130 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and
131 | u.orders[0].target_tag not in extractor_tags))
132 | ]
133 | if len(idle_extractors) > 0 and len(workers) > 0: return True
134 | else: return False
135 |
136 | def _assign_workers_gather_minerals(self, dc):
137 | extractor_tags = set(u.tag for u in dc.units_of_type(
138 | UNIT_TYPE.ZERG_EXTRACTOR.value))
139 | workers = [
140 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value)
141 | if (len(u.orders) == 0 or
142 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and
143 | u.orders[0].target_tag in extractor_tags))
144 | ]
145 | actions = []
146 | for worker in random.sample(workers, min(3, len(workers))):
147 | target_mineral = utils.closest_unit(worker, dc.minerals)
148 | action = sc_pb.Action()
149 | action.action_raw.unit_command.unit_tags.append(worker.tag)
150 | action.action_raw.unit_command.ability_id = \
151 | ABILITY.HARVEST_GATHER_DRONE.value
152 | action.action_raw.unit_command.target_unit_tag = target_mineral.tag
153 | actions.append(action)
154 | return actions
155 |
156 | def _is_valid_assign_workers_gather_minerals(self, dc):
157 | extractor_tags = set(u.tag for u in dc.units_of_type(
158 | UNIT_TYPE.ZERG_EXTRACTOR.value))
159 | workers = [
160 | u for u in dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value)
161 | if (len(u.orders) == 0 or
162 | (u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value and
163 | u.orders[0].target_tag in extractor_tags))
164 | ]
165 | return len(workers) > 0
166 |
--------------------------------------------------------------------------------
/sc2learner/envs/observations/nonspatial_features.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
7 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY
8 |
9 | from sc2learner.envs.common.const import ALLY_TYPE
10 | from sc2learner.envs.common.const import COMBAT_TYPES
11 |
12 |
13 | class PlayerFeature(object):
14 |
15 | def features(self, observation):
16 | player_features = observation["player"][1:-1].astype(np.float32)
17 | food_unused = player_features[3] - player_features[2]
18 | player_features[-1] = food_unused if food_unused >= 0 else 0
19 | scale = np.array([2000, 2000, 20, 20, 20, 20, 20, 20, 20], np.float32)
20 | scaled_features = (player_features / scale).astype(np.float32)
21 | log_features = np.log10(player_features + 1).astype(np.float32)
22 |
23 | bins_food_unused = np.zeros(10, dtype=np.float32)
24 | bin_id = int((max(food_unused, 0) - 1) // 3 + 1) if food_unused <= 27 else 9
25 | bins_food_unused[bin_id] = 1
26 | return np.concatenate((scaled_features, log_features, bins_food_unused))
27 |
28 | @property
29 | def num_dims(self):
30 | return 9 * 2 + 10
31 |
32 |
33 | class ScoreFeature(object):
34 |
35 | def features(self, observation):
36 | score_features = observation.score_cumulative[3:].astype(np.float32)
37 | score_features /= 3000.0
38 | log_features = np.log10(score_features + 1).astype(np.float32)
39 | return np.concatenate((score_features, log_features))
40 |
41 | @property
42 | def num_dims(self):
43 | return 10 * 2
44 |
45 |
46 | class UnitTypeCountFeature(object):
47 |
48 | def __init__(self, type_list, use_regions=False):
49 | self._type_list = type_list
50 | if use_regions:
51 | self._regions = [(0, 0, 200, 176),
52 | (0, 88, 80, 176),
53 | (80, 88, 120, 176),
54 | (120, 88, 200, 176),
55 | (0, 55, 80, 88),
56 | (80, 55, 120, 88),
57 | (120, 55, 200, 88),
58 | (0, 0, 80, 55),
59 | (80, 0, 120, 55),
60 | (120, 0, 200, 55)]
61 | else:
62 | self._regions = [(0, 0, 200, 176)]
63 | self._regions_flipped = [self._regions[0]] + [
64 | self._regions[10 - i] for i in range(1, len(self._regions))]
65 |
66 | def features(self, observation, need_flip=False):
67 | feature_list = []
68 | for region in (self._regions if not need_flip else self._regions_flipped):
69 | units_in_region = [u for u in observation['units']
70 | if self._is_in_region(u, region)]
71 | feature_list.append(self._generate_features(units_in_region))
72 | return np.concatenate(feature_list)
73 |
74 | @property
75 | def num_dims(self):
76 | return len(self._type_list) * len(self._regions) * 2 * 2
77 |
78 | def _generate_features(self, units):
79 | self_units = [u for u in units
80 | if u.int_attr.alliance == ALLY_TYPE.SELF.value]
81 | enemy_units = [u for u in units
82 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value]
83 | self_features = self._get_counts(self_units)
84 | enemy_features = self._get_counts(enemy_units)
85 | features = np.concatenate((self_features, enemy_features))
86 |
87 | scaled_features = features / 20
88 | log_features = np.log10(features + 1)
89 |
90 | return np.concatenate((scaled_features, log_features))
91 |
92 | def _get_counts(self, units):
93 | count = {t: 0 for t in self._type_list}
94 | for u in units:
95 | if u.unit_type in count:
96 | count[u.unit_type] += 1
97 | return np.array([count[t] for t in self._type_list], dtype=np.float32)
98 |
99 | def _is_in_region(self, unit, region):
100 | return (unit.float_attr.pos_x >= region[0] and
101 | unit.float_attr.pos_x < region[2] and
102 | unit.float_attr.pos_y >= region[1] and
103 | unit.float_attr.pos_y < region[3])
104 |
105 |
106 | class UnitStatCountFeature(object):
107 |
108 | def __init__(self, use_regions=False):
109 | if use_regions:
110 | self._regions = [(0, 0, 200, 176),
111 | (0, 88, 80, 176),
112 | (80, 88, 120, 176),
113 | (120, 88, 200, 176),
114 | (0, 55, 80, 88),
115 | (80, 55, 120, 88),
116 | (120, 55, 200, 88),
117 | (0, 0, 80, 55),
118 | (80, 0, 120, 55),
119 | (120, 0, 200, 55)]
120 | else:
121 | self._regions = [(0, 0, 200, 176)]
122 | self._regions_flipped = [self._regions[0]] + [
123 | self._regions[10 - i] for i in range(1, len(self._regions))]
124 |
125 | def features(self, observation, need_flip=False):
126 | feature_list = []
127 | for region in (self._regions if not need_flip else self._regions_flipped):
128 | units_in_region = [u for u in observation['units']
129 | if self._is_in_region(u, region)]
130 | feature_list.append(self._generate_features(units_in_region))
131 | return np.concatenate(feature_list)
132 |
133 | @property
134 | def num_dims(self):
135 | return len(self._regions) * 2 * 4 * 2
136 |
137 | def _generate_features(self, units):
138 | self_units = [u for u in units
139 | if u.int_attr.alliance == ALLY_TYPE.SELF.value]
140 | enemy_units = [u for u in units
141 | if u.int_attr.alliance == ALLY_TYPE.ENEMY.value]
142 | self_combats = [u for u in self_units if u.unit_type in COMBAT_TYPES]
143 | enemy_combats = [u for u in enemy_units if u.unit_type in COMBAT_TYPES]
144 | self_air_units = [u for u in self_units if u.bool_attr.is_flying]
145 | enemy_air_units = [u for u in enemy_units if u.bool_attr.is_flying]
146 | self_ground_units = [u for u in self_units if not u.bool_attr.is_flying]
147 | enemy_ground_units = [u for u in enemy_units if not u.bool_attr.is_flying]
148 |
149 | features = np.array([len(self_units),
150 | len(self_combats),
151 | len(self_ground_units),
152 | len(self_air_units),
153 | len(enemy_units),
154 | len(enemy_combats),
155 | len(enemy_ground_units),
156 | len(enemy_air_units)], dtype=np.float32)
157 |
158 | scaled_features = features / 20
159 | log_features = np.log10(features + 1)
160 | return np.concatenate((scaled_features, log_features))
161 |
162 | def _is_in_region(self, unit, region):
163 | return (unit.float_attr.pos_x >= region[0] and
164 | unit.float_attr.pos_x < region[2] and
165 | unit.float_attr.pos_y >= region[1] and
166 | unit.float_attr.pos_y < region[3])
167 |
168 |
169 | class GameProgressFeature(object):
170 |
171 | def features(self, observation):
172 | game_loop = observation["game_loop"][0]
173 | features_60 = self._onehot(game_loop, 60)
174 | features_20 = self._onehot(game_loop, 20)
175 | features_8 = self._onehot(game_loop, 8)
176 | features_4 = self._onehot(game_loop, 4)
177 |
178 | return np.concatenate([features_60,
179 | features_20,
180 | features_8,
181 | features_4])
182 |
183 | def _onehot(self, value, n_bins):
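    # One-hot of the current game loop over n_bins equal-width bins spanning
    # 24000 game loops; later loops all fall into the last bin.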
184 | bin_width = 24000 // n_bins
185 | features = np.zeros(n_bins, dtype=np.float32)
186 | idx = int(value // bin_width)
187 | idx = n_bins - 1 if idx >= n_bins else idx
188 | features[idx] = 1.0
189 | return features
190 |
191 | @property
192 | def num_dims(self):
193 | return 60 + 20 + 8 + 4
194 |
195 |
196 | class ActionSeqFeature(object):
197 |
198 | def __init__(self, n_dims_action_space, seq_len):
199 | self._action_seq = [-1] * seq_len
200 | self._n_dims_action_space = n_dims_action_space
201 |
202 | def reset(self):
203 | self._action_seq = [-1] * len(self._action_seq)
204 |
205 | def push_action(self, action):
206 | self._action_seq.pop(0)
207 | self._action_seq.append(action)
208 |
209 | def features(self):
210 | features = np.zeros(self._n_dims_action_space * len(self._action_seq),
211 | dtype=np.float32)
212 | for i, action in enumerate(self._action_seq):
213 | assert action < self._n_dims_action_space
214 | if action >= 0:
215 | features[i * self._n_dims_action_space + action] = 1.0
216 | return features
217 |
218 | @property
219 | def num_dims(self):
220 | return self._n_dims_action_space * len(self._action_seq)
221 |
222 |
223 | class WorkerFeature(object):
224 |
225 | def features(self, dc):
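    # Split drones into gas harvesters, mineral harvesters and the remainder
    # (idle or otherwise occupied), scaled down by 20.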
226 | extractor_tags = set(u.tag for u in dc.units_of_type(
227 | UNIT_TYPE.ZERG_EXTRACTOR.value))
228 | workers = dc.units_of_type(UNIT_TYPE.ZERG_DRONE.value)
229 | harvest_workers = [
230 | u for u in workers
231 | if (len(u.orders) > 0 and
232 | u.orders[0].ability_id == ABILITY.HARVEST_GATHER_DRONE.value)
233 | ]
234 | gas_workers = [u for u in harvest_workers
235 | if u.orders[0].target_tag in extractor_tags]
236 | mineral_workers = [u for u in harvest_workers
237 | if u.orders[0].target_tag not in extractor_tags]
238 | return np.array([len(gas_workers),
239 | len(mineral_workers),
240 | len(workers) - len(gas_workers) - len(mineral_workers)],
241 | dtype=np.float32) / 20.0
242 |
243 | @property
244 | def num_dims(self):
245 | return 3
246 |
--------------------------------------------------------------------------------
/sc2learner/bin/train_ppo_selfplay.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import sys
6 | from threading import Thread
7 | import os
8 | import multiprocessing
9 | import random
10 | import time
11 |
12 | from absl import app
13 | from absl import flags
14 | from absl import logging
15 | import tensorflow as tf
16 |
17 | from sc2learner.agents.ppo_policies import LstmPolicy, MlpPolicy
18 | from sc2learner.agents.ppo_agent import PPOActor, PPOLearner, PPOSelfplayActor
19 | from sc2learner.envs.raw_env import SC2RawEnv
20 | from sc2learner.envs.selfplay_raw_env import SC2SelfplayRawEnv
21 | from sc2learner.envs.actions.zerg_action_wrappers import ZergActionWrapper
22 | from sc2learner.envs.actions.zerg_action_wrappers import ZergPlayerActionWrapper
23 | from sc2learner.envs.observations.zerg_observation_wrappers \
24 | import ZergObservationWrapper
25 | from sc2learner.envs.observations.zerg_observation_wrappers \
26 | import ZergPlayerObservationWrapper
27 | from sc2learner.utils.utils import print_arguments
28 |
29 |
30 | FLAGS = flags.FLAGS
31 | flags.DEFINE_enum("job_name", 'actor', ['actor', 'learner', 'eval', 'eval_model'],
32 | "Job type.")
33 | flags.DEFINE_enum("policy", 'mlp', ['mlp', 'lstm'], "Job type.")
34 | flags.DEFINE_integer("unroll_length", 128, "Length of rollout steps.")
35 | flags.DEFINE_integer("model_cache_size", 300, "Opponent model cache size.")
36 | flags.DEFINE_float("model_cache_prob", 0.05, "Opponent model cache probability.")
37 | flags.DEFINE_string("learner_ip", "localhost", "Learner IP address.")
38 | flags.DEFINE_string("port_A", "5700", "Port for transporting model.")
39 | flags.DEFINE_string("port_B", "5701", "Port for transporting data.")
40 | flags.DEFINE_string("game_version", '4.6', "Game core version.")
41 | flags.DEFINE_float("discount_gamma", 0.998, "Discount factor.")
42 | flags.DEFINE_float("lambda_return", 0.95, "Lambda return factor.")
43 | flags.DEFINE_float("clip_range", 0.1, "Clip range for PPO.")
44 | flags.DEFINE_float("ent_coef", 0.01, "Coefficient for the entropy term.")
45 | flags.DEFINE_float("vf_coef", 0.5, "Coefficient for the value loss.")
46 | flags.DEFINE_float("learn_act_speed_ratio", 0, "Maximum learner/actor ratio.")
47 | flags.DEFINE_integer("game_steps_per_episode", 43200, "Maximum steps per episode.")
48 | flags.DEFINE_integer("batch_size", 32, "Batch size.")
49 | flags.DEFINE_integer("learner_queue_size", 1024, "Size of learner's unroll queue.")
50 | flags.DEFINE_integer("step_mul", 32, "Game steps per agent step.")
51 | flags.DEFINE_string("difficulties", '1,2,4,6,9,A', "Bot's strengths.")
52 | flags.DEFINE_float("learning_rate", 1e-5, "Learning rate.")
53 | flags.DEFINE_string("init_model_path", None, "Initial model path.")
54 | flags.DEFINE_string("init_oppo_pool_filelist", None, "Initial opponent model path.")
55 | flags.DEFINE_string("save_dir", "./checkpoints/", "Dir to save models to")
56 | flags.DEFINE_integer("save_interval", 50000, "Model saving frequency.")
57 | flags.DEFINE_integer("print_interval", 1000, "Print train cost frequency.")
58 | flags.DEFINE_boolean("disable_fog", False, "Disable fog-of-war.")
59 | flags.DEFINE_boolean("use_all_combat_actions", False, "Use all combat actions.")
60 | flags.DEFINE_boolean("use_region_features", False, "Use region features")
61 | flags.DEFINE_boolean("use_action_mask", True, "Use region-wise combat.")
62 | flags.FLAGS(sys.argv)
63 |
64 |
65 | def tf_config(ncpu=None):
66 | if ncpu is None:
67 | ncpu = multiprocessing.cpu_count()
68 | if sys.platform == 'darwin': ncpu //= 2
69 | config = tf.ConfigProto(allow_soft_placement=True,
70 | intra_op_parallelism_threads=ncpu,
71 | inter_op_parallelism_threads=ncpu)
72 | config.gpu_options.allow_growth = True
73 | tf.Session(config=config).__enter__()
74 |
75 |
76 | def create_env(difficulty, random_seed=None):
77 | env = SC2RawEnv(map_name='AbyssalReef',
78 | step_mul=FLAGS.step_mul,
79 | resolution=16,
80 | agent_race='zerg',
81 | bot_race='zerg',
82 | difficulty=difficulty,
83 | disable_fog=FLAGS.disable_fog,
84 | tie_to_lose=False,
85 | game_steps_per_episode=FLAGS.game_steps_per_episode,
86 | random_seed=random_seed)
87 | env = ZergActionWrapper(env,
88 | game_version=FLAGS.game_version,
89 | mask=FLAGS.use_action_mask,
90 | use_all_combat_actions=FLAGS.use_all_combat_actions)
91 | env = ZergObservationWrapper(env,
92 | use_spatial_features=False,
93 | use_game_progress=(not FLAGS.policy == 'lstm'),
94 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8,
95 | use_regions=FLAGS.use_region_features)
96 | print(env.observation_space, env.action_space)
97 | return env
98 |
99 |
100 | def create_selfplay_env(random_seed=None):
101 | env = SC2SelfplayRawEnv(map_name='AbyssalReef',
102 | step_mul=FLAGS.step_mul,
103 | resolution=16,
104 | agent_race='zerg',
105 | opponent_race='zerg',
106 | tie_to_lose=False,
107 | disable_fog=FLAGS.disable_fog,
108 | game_steps_per_episode=FLAGS.game_steps_per_episode,
109 | random_seed=random_seed)
110 | env = ZergPlayerActionWrapper(
111 | player=0,
112 | env=env,
113 | game_version=FLAGS.game_version,
114 | mask=FLAGS.use_action_mask,
115 | use_all_combat_actions=FLAGS.use_all_combat_actions)
116 | env = ZergPlayerObservationWrapper(
117 | player=0,
118 | env=env,
119 | use_spatial_features=False,
120 | use_game_progress=(not FLAGS.policy == 'lstm'),
121 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8,
122 | use_regions=FLAGS.use_region_features)
123 |
124 | env = ZergPlayerActionWrapper(
125 | player=1,
126 | env=env,
127 | game_version=FLAGS.game_version,
128 | mask=FLAGS.use_action_mask,
129 | use_all_combat_actions=FLAGS.use_all_combat_actions)
130 | env = ZergPlayerObservationWrapper(
131 | player=1,
132 | env=env,
133 | use_spatial_features=False,
134 | use_game_progress=(not FLAGS.policy == 'lstm'),
135 | action_seq_len=1 if FLAGS.policy == 'lstm' else 8,
136 | use_regions=FLAGS.use_region_features)
137 | print(env.observation_space, env.action_space)
138 | return env
139 |
140 |
141 | def start_actor():
142 | tf_config(ncpu=2)
143 | random.seed(time.time())
144 | game_seed = random.randint(0, 2**32 - 1)
145 | print("Game Seed: %d" % game_seed)
146 | env = create_selfplay_env(game_seed)
147 | policy = {'lstm': LstmPolicy,
148 | 'mlp': MlpPolicy}[FLAGS.policy]
149 | actor = PPOSelfplayActor(
150 | env=env,
151 | policy=policy,
152 | unroll_length=FLAGS.unroll_length,
153 | gamma=FLAGS.discount_gamma,
154 | lam=FLAGS.lambda_return,
155 | model_cache_size=FLAGS.model_cache_size,
156 | model_cache_prob=FLAGS.model_cache_prob,
157 | prob_latest_opponent=0.0,
158 | init_opponent_pool_filelist=FLAGS.init_oppo_pool_filelist,
159 | freeze_opponent_pool=False,
160 | learner_ip=FLAGS.learner_ip,
161 | port_A=FLAGS.port_A,
162 | port_B=FLAGS.port_B)
163 | actor.run()
164 | env.close()
165 |
166 |
167 | def start_learner():
168 | tf_config()
169 | env = create_env('1', 0)
170 | policy = {'lstm': LstmPolicy,
171 | 'mlp': MlpPolicy}[FLAGS.policy]
172 | learner = PPOLearner(env=env,
173 | policy=policy,
174 | unroll_length=FLAGS.unroll_length,
175 | lr=FLAGS.learning_rate,
176 | clip_range=FLAGS.clip_range,
177 | batch_size=FLAGS.batch_size,
178 | ent_coef=FLAGS.ent_coef,
179 | vf_coef=FLAGS.vf_coef,
180 | max_grad_norm=0.5,
181 | queue_size=FLAGS.learner_queue_size,
182 | print_interval=FLAGS.print_interval,
183 | save_interval=FLAGS.save_interval,
184 | learn_act_speed_ratio=FLAGS.learn_act_speed_ratio,
185 | save_dir=FLAGS.save_dir,
186 | init_model_path=FLAGS.init_model_path,
187 | port_A=FLAGS.port_A,
188 | port_B=FLAGS.port_B)
189 | learner.run()
190 | env.close()
191 |
192 |
193 | def start_evaluator_against_builtin():
194 | tf_config(ncpu=2)
195 | random.seed(time.time())
196 | difficulty = random.choice(FLAGS.difficulties.split(','))
197 | game_seed = random.randint(0, 2**32 - 1)
198 | print("Game Seed: %d" % game_seed)
199 | env = create_env(difficulty, game_seed)
200 | policy = {'lstm': LstmPolicy,
201 | 'mlp': MlpPolicy}[FLAGS.policy]
202 | actor = PPOActor(env=env,
203 | policy=policy,
204 | unroll_length=FLAGS.unroll_length,
205 | gamma=FLAGS.discount_gamma,
206 | lam=FLAGS.lambda_return,
207 | enable_push=False,
208 | learner_ip=FLAGS.learner_ip,
209 | port_A=FLAGS.port_A,
210 | port_B=FLAGS.port_B)
211 | actor.run()
212 | env.close()
213 |
214 |
215 | def start_evaluator_against_model():
216 | tf_config(ncpu=2)
217 | random.seed(time.time())
218 | game_seed = random.randint(0, 2**32 - 1)
219 | print("Game Seed: %d" % game_seed)
220 | env = create_selfplay_env(game_seed)
221 | policy = {'lstm': LstmPolicy,
222 | 'mlp': MlpPolicy}[FLAGS.policy]
223 | actor = PPOSelfplayActor(
224 | env=env,
225 | policy=policy,
226 | unroll_length=FLAGS.unroll_length,
227 | gamma=FLAGS.discount_gamma,
228 | lam=FLAGS.lambda_return,
229 | model_cache_size=1,
230 | model_cache_prob=FLAGS.model_cache_prob,
231 | enable_push=False,
232 | prob_latest_opponent=0.0,
233 | init_opponent_pool_filelist=FLAGS.init_oppo_pool_filelist,
234 | freeze_opponent_pool=True,
235 | learner_ip=FLAGS.learner_ip,
236 | port_A=FLAGS.port_A,
237 | port_B=FLAGS.port_B)
238 | actor.run()
239 | env.close()
240 |
241 |
242 | def main(argv):
243 | logging.set_verbosity(logging.ERROR)
244 | print_arguments(FLAGS)
245 | if FLAGS.job_name == 'actor': start_actor()
246 | elif FLAGS.job_name == 'learner': start_learner()
247 | elif FLAGS.job_name == 'eval': start_evaluator_against_builtin()
248 | else: start_evaluator_against_model()
249 |
250 |
251 | if __name__ == '__main__':
252 | app.run(main)
253 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/zerg_action_wrappers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import platform
6 |
7 | import numpy as np
8 | import gym
9 | from gym.spaces.discrete import Discrete
10 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
11 | from pysc2.lib.typeenums import UPGRADE_ID as UPGRADE
12 | from pysc2.lib import point
13 | from s2clientprotocol import sc2api_pb2 as sc_pb
14 |
15 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation
16 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
17 | from sc2learner.envs.common.data_context import DataContext
18 | from sc2learner.envs.actions.function import Function
19 | from sc2learner.envs.actions.produce import ProduceActions
20 | from sc2learner.envs.actions.build import BuildActions
21 | from sc2learner.envs.actions.upgrade import UpgradeActions
22 | from sc2learner.envs.actions.resource import ResourceActions
23 | from sc2learner.envs.actions.combat import CombatActions
24 |
25 |
26 | class ZergActionWrapper(gym.Wrapper):
27 |
28 | def __init__(self, env, game_version='4.1.2', mask=False,
29 | use_all_combat_actions=False):
30 | super(ZergActionWrapper, self).__init__(env)
31 | # TODO: multiple observation space
32 | #assert isinstance(env.observation_space, PySC2RawObservation)
33 |
34 | self._dc = DataContext()
35 | self._build_mgr = BuildActions(game_version)
36 | self._produce_mgr = ProduceActions(game_version)
37 | self._upgrade_mgr = UpgradeActions(game_version)
38 | self._resource_mgr = ResourceActions()
39 | self._combat_mgr = CombatActions()
40 |
41 | self._actions = [
42 | self._action_do_nothing(),
43 | self._build_mgr.action("build_extractor", UNIT_TYPE.ZERG_EXTRACTOR.value),
44 | self._build_mgr.action("build_spawning_pool", UNIT_TYPE.ZERG_SPAWNINGPOOL.value),
45 | self._build_mgr.action("build_roach_warren", UNIT_TYPE.ZERG_ROACHWARREN.value),
46 | self._build_mgr.action("build_hydraliskden", UNIT_TYPE.ZERG_HYDRALISKDEN.value),
47 | self._build_mgr.action("build_hatchery", UNIT_TYPE.ZERG_HATCHERY.value),
48 | self._build_mgr.action("build_evolution_chamber", UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value),
49 | self._build_mgr.action("build_baneling_nest", UNIT_TYPE.ZERG_BANELINGNEST.value),
50 | self._build_mgr.action("build_infestation_pit", UNIT_TYPE.ZERG_INFESTATIONPIT.value),
51 | self._build_mgr.action("build_spire", UNIT_TYPE.ZERG_SPIRE.value),
52 | self._build_mgr.action("build_ultralisk_cavern", UNIT_TYPE.ZERG_ULTRALISKCAVERN.value),
53 | #self._build_mgr.action("build_nydus_network", UNIT_TYPE.ZERG_NYDUSNETWORK.value),
54 | self._build_mgr.action("build_spine_crawler", UNIT_TYPE.ZERG_SPINECRAWLER.value),
55 | self._build_mgr.action("build_spore_crawler", UNIT_TYPE.ZERG_SPORECRAWLER.value),
56 | self._produce_mgr.action("morph_lurker_den", UNIT_TYPE.ZERG_LURKERDENMP.value),
57 | self._produce_mgr.action("morph_lair", UNIT_TYPE.ZERG_LAIR.value),
58 | self._produce_mgr.action("morph_hive", UNIT_TYPE.ZERG_HIVE.value),
59 | self._produce_mgr.action("morph_greater_spire", UNIT_TYPE.ZERG_GREATERSPIRE.value),
60 | self._produce_mgr.action("produce_drone", UNIT_TYPE.ZERG_DRONE.value),
61 | self._produce_mgr.action("produce_zergling", UNIT_TYPE.ZERG_ZERGLING.value),
62 | self._produce_mgr.action("morph_baneling", UNIT_TYPE.ZERG_BANELING.value),
63 | self._produce_mgr.action("produce_roach", UNIT_TYPE.ZERG_ROACH.value),
64 | self._produce_mgr.action("morph_ravager", UNIT_TYPE.ZERG_RAVAGER.value),
65 | self._produce_mgr.action("produce_hydralisk", UNIT_TYPE.ZERG_HYDRALISK.value),
66 | self._produce_mgr.action("morph_lurker", UNIT_TYPE.ZERG_LURKERMP.value),
67 | #self._produce_mgr.action("produce_viper", UNIT_TYPE.ZERG_VIPER.value),
68 | self._produce_mgr.action("produce_mutalisk", UNIT_TYPE.ZERG_MUTALISK.value),
69 | self._produce_mgr.action("produce_corruptor", UNIT_TYPE.ZERG_CORRUPTOR.value),
70 | self._produce_mgr.action("morph_broodlord", UNIT_TYPE.ZERG_BROODLORD.value),
71 | #self._produce_mgr.action("produce_swarmhost", UNIT_TYPE.ZERG_SWARMHOSTMP.value),
72 | #self._produce_mgr.action("produce_infestor", UNIT_TYPE.ZERG_INFESTOR.value),
73 | self._produce_mgr.action("produce_ultralisk", UNIT_TYPE.ZERG_ULTRALISK.value),
74 | self._produce_mgr.action("produce_overlord", UNIT_TYPE.ZERG_OVERLORD.value),
75 | self._produce_mgr.action("morph_overseer", UNIT_TYPE.ZERG_OVERSEER.value),
76 | self._produce_mgr.action("produce_queen", UNIT_TYPE.ZERG_QUEEN.value),
77 | #self._produce_mgr.action("produce_nydus_worm", UNIT_TYPE.ZERG_NYDUSCANAL.value),
78 | self._upgrade_mgr.action("upgrade_burrow", UPGRADE.BURROW.value),
79 | self._upgrade_mgr.action("upgrade_centrifical_hooks", UPGRADE.CENTRIFICALHOOKS.value),
80 |         self._upgrade_mgr.action("upgrade_chitinous_plating", UPGRADE.CHITINOUSPLATING.value),
81 | self._upgrade_mgr.action("upgrade_evolve_grooved_spines", UPGRADE.EVOLVEGROOVEDSPINES.value),
82 | self._upgrade_mgr.action("upgrade_evolve_muscular_augments", UPGRADE.EVOLVEMUSCULARAUGMENTS.value),
83 |         self._upgrade_mgr.action("upgrade_glial_reconstitution", UPGRADE.GLIALRECONSTITUTION.value),
84 | #self._upgrade_mgr.action("upgrade_infestor_energy_upgrade", UPGRADE.INFESTORENERGYUPGRADE.value),
85 | self._upgrade_mgr.action("upgrade_neural_parasite", UPGRADE.NEURALPARASITE.value),
86 | self._upgrade_mgr.action("upgrade_overlord_speed", UPGRADE.OVERLORDSPEED.value),
87 | self._upgrade_mgr.action("upgrade_tunneling_claws", UPGRADE.TUNNELINGCLAWS.value),
88 | self._upgrade_mgr.action("upgrade_flyer_armors_level1", UPGRADE.ZERGFLYERARMORSLEVEL1.value),
89 | self._upgrade_mgr.action("upgrade_flyer_armors_level2", UPGRADE.ZERGFLYERARMORSLEVEL2.value),
90 | self._upgrade_mgr.action("upgrade_flyer_armors_level3", UPGRADE.ZERGFLYERARMORSLEVEL3.value),
91 | self._upgrade_mgr.action("upgrade_flyer_weapons_level1", UPGRADE.ZERGFLYERWEAPONSLEVEL1.value),
92 | self._upgrade_mgr.action("upgrade_flyer_weapons_level2", UPGRADE.ZERGFLYERWEAPONSLEVEL2.value),
93 | self._upgrade_mgr.action("upgrade_flyer_weapons_level3", UPGRADE.ZERGFLYERWEAPONSLEVEL3.value),
94 | self._upgrade_mgr.action("upgrade_ground_armors_level1", UPGRADE.ZERGGROUNDARMORSLEVEL1.value),
95 | self._upgrade_mgr.action("upgrade_ground_armors_level2", UPGRADE.ZERGGROUNDARMORSLEVEL2.value),
96 | self._upgrade_mgr.action("upgrade_ground_armors_level3", UPGRADE.ZERGGROUNDARMORSLEVEL3.value),
97 | self._upgrade_mgr.action("upgrade_zergling_attack_speed", UPGRADE.ZERGLINGATTACKSPEED.value),
98 | self._upgrade_mgr.action("upgrade_zergling_moving_speed", UPGRADE.ZERGLINGMOVEMENTSPEED.value),
99 | self._upgrade_mgr.action("upgrade_melee_weapons_level1", UPGRADE.ZERGMELEEWEAPONSLEVEL1.value),
100 | self._upgrade_mgr.action("upgrade_melee_weapons_level2", UPGRADE.ZERGMELEEWEAPONSLEVEL2.value),
101 | self._upgrade_mgr.action("upgrade_melee_weapons_level3", UPGRADE.ZERGMELEEWEAPONSLEVEL3.value),
102 | self._upgrade_mgr.action("upgrade_missile_weapons_level1", UPGRADE.ZERGMISSILEWEAPONSLEVEL1.value),
103 | self._upgrade_mgr.action("upgrade_missile_weapons_level2", UPGRADE.ZERGMISSILEWEAPONSLEVEL2.value),
104 | self._upgrade_mgr.action("upgrade_missile_weapons_level3", UPGRADE.ZERGMISSILEWEAPONSLEVEL3.value),
105 | self._resource_mgr.action_assign_workers_gather_gas,
106 | self._resource_mgr.action_assign_workers_gather_minerals,
107 | # ZERG_LOCUST, ZERG_CHANGELING not included
108 | ] + ([
109 | self._combat_mgr.action(0, 0),
110 | self._combat_mgr.action(9, 4),
111 | self._combat_mgr.action(4, 1)
112 | ] if not use_all_combat_actions else [
113 | self._combat_mgr.action(0, target_region_id)
114 | for target_region_id in range(self._combat_mgr.num_regions)
115 | #self._combat_mgr.action(source_region_id, target_region_id)
116 | #for source_region_id in range(self._combat_mgr.num_regions)
117 | #for target_region_id in range(self._combat_mgr.num_regions)
118 | ])
119 |
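    # Scripted actions issued automatically around every agent action: idle workers
    # gather minerals and queens inject larva beforehand; new combat units are
    # rallied and per-frame rally/attack micro runs afterwards.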
120 | self._required_pre_actions = [
121 | self._resource_mgr.action_idle_workers_gather_minerals,
122 | self._resource_mgr.action_queens_inject_larva
123 | ]
124 | self._required_post_actions = [
125 | self._combat_mgr.action_rally_new_combat_units,
126 | self._combat_mgr.action_framewise_rally_and_attack
127 | ]
128 |
129 | if mask: self.action_space = MaskDiscrete(len(self._actions))
130 | else: self.action_space = Discrete(len(self._actions))
131 |
132 | def step(self, action):
133 | actions = self._actions[action].function(self._dc)
134 | pre_actions, post_actions = self._required_actions()
135 | observation, reward, done, info = self.env.step(
136 | pre_actions + actions + post_actions)
137 | self._dc.update(observation)
138 | if isinstance(self.action_space, MaskDiscrete):
139 | observation['action_mask'] = self._get_valid_action_mask()
140 | return observation, reward, done, info
141 |
142 | def reset(self, **kwargs):
143 | self._combat_mgr.reset()
144 | observation = self.env.reset()
145 | self._dc.reset(observation)
146 | if isinstance(self.action_space, MaskDiscrete):
147 | observation['action_mask'] = self._get_valid_action_mask()
148 | return observation
149 |
150 | @property
151 | def action_names(self):
152 | return [action.name for action in self._actions]
153 |
154 | @property
155 | def player_position(self):
156 | if self._dc.init_base_pos[0] < 100: return 0
157 | else: return 1
158 |
159 | def _required_actions(self):
160 | pre_actions = []
161 | for fn in self._required_pre_actions:
162 | if fn.is_valid(self._dc):
163 | pre_actions.extend(fn.function(self._dc))
164 |
165 | post_actions = []
166 | for fn in self._required_post_actions:
167 | if fn.is_valid(self._dc):
168 | post_actions.extend(fn.function(self._dc))
169 |
170 | return pre_actions, post_actions
171 |
172 | def _get_valid_action_mask(self):
173 | ids = [i for i, action in enumerate(self._actions)
174 | if action.is_valid(self._dc)]
175 | mask = np.zeros(self.action_space.n)
176 | mask[ids] = 1
177 | return mask
178 |
179 | def _action_do_nothing(self):
180 | return Function(name='do_nothing',
181 | function=lambda dc: [],
182 | is_valid=lambda dc: True)
183 |
184 |
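# Per-player variant for the self-play environment: each player gets its own
# wrapper instance and only reads and writes its own slot of the per-player
# action/observation lists.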
185 | class ZergPlayerActionWrapper(ZergActionWrapper):
186 |
187 | def __init__(self, player, **kwargs):
188 | self._warn_double_wrap = lambda *args: None
189 | self._player = player
190 | super(ZergPlayerActionWrapper, self).__init__(**kwargs)
191 |
192 | def step(self, action):
193 | actions = self._actions[action[self._player]].function(self._dc)
194 | pre_actions, post_actions = self._required_actions()
195 | action[self._player] = pre_actions + actions + post_actions
196 | observation, reward, done, info = self.env.step(action)
197 | self._dc.update(observation[self._player])
198 | if isinstance(self.action_space, MaskDiscrete):
199 | observation[self._player]['action_mask'] = self._get_valid_action_mask()
200 | return observation, reward, done, info
201 |
202 | def reset(self, **kwargs):
203 | self._combat_mgr.reset()
204 | observation = self.env.reset()
205 | self._dc.reset(observation[self._player])
206 | if isinstance(self.action_space, MaskDiscrete):
207 | observation[self._player]['action_mask'] = self._get_valid_action_mask()
208 | return observation
209 |
--------------------------------------------------------------------------------
/sc2learner/envs/observations/zerg_observation_wrappers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import gym
7 | from gym import spaces
8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
9 |
10 | from sc2learner.envs.spaces.pysc2_raw import PySC2RawObservation
11 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
12 | from sc2learner.envs.common.data_context import DataContext
13 | from sc2learner.envs.observations.spatial_features import UnitTypeCountMapFeature
14 | from sc2learner.envs.observations.spatial_features import AllianceCountMapFeature
15 | from sc2learner.envs.observations.nonspatial_features import PlayerFeature
16 | from sc2learner.envs.observations.nonspatial_features import ScoreFeature
17 | from sc2learner.envs.observations.nonspatial_features import WorkerFeature
18 | from sc2learner.envs.observations.nonspatial_features import UnitTypeCountFeature
19 | from sc2learner.envs.observations.nonspatial_features import UnitStatCountFeature
20 | from sc2learner.envs.observations.nonspatial_features import GameProgressFeature
21 | from sc2learner.envs.observations.nonspatial_features import ActionSeqFeature
22 |
23 |
24 | class ZergObservationWrapper(gym.Wrapper):
25 |
26 | def __init__(self, env, use_spatial_features=False, use_game_progress=True,
27 | action_seq_len=8, use_regions=False):
28 | super(ZergObservationWrapper, self).__init__(env)
29 | # TODO: multiple observation space
30 | #assert isinstance(env.observation_space, PySC2RawObservation)
31 | self._use_spatial_features = use_spatial_features
32 | self._use_game_progress = use_game_progress
33 | self._dc = DataContext()
34 |
35 | # nonspatial features
36 | self._unit_count_feature = UnitTypeCountFeature(
37 | type_list=[UNIT_TYPE.ZERG_LARVA.value,
38 | UNIT_TYPE.ZERG_DRONE.value,
39 | UNIT_TYPE.ZERG_ZERGLING.value,
40 | UNIT_TYPE.ZERG_BANELING.value,
41 | UNIT_TYPE.ZERG_ROACH.value,
42 | UNIT_TYPE.ZERG_ROACHBURROWED.value,
43 | UNIT_TYPE.ZERG_RAVAGER.value,
44 | UNIT_TYPE.ZERG_HYDRALISK.value,
45 | UNIT_TYPE.ZERG_LURKERMP.value,
46 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value,
47 | #UNIT_TYPE.ZERG_VIPER.value,
48 | UNIT_TYPE.ZERG_MUTALISK.value,
49 | UNIT_TYPE.ZERG_CORRUPTOR.value,
50 | UNIT_TYPE.ZERG_BROODLORD.value,
51 | #UNIT_TYPE.ZERG_SWARMHOSTMP.value,
52 | UNIT_TYPE.ZERG_LOCUSTMP.value,
53 | #UNIT_TYPE.ZERG_INFESTOR.value,
54 | UNIT_TYPE.ZERG_ULTRALISK.value,
55 | UNIT_TYPE.ZERG_BROODLING.value,
56 | UNIT_TYPE.ZERG_OVERLORD.value,
57 | UNIT_TYPE.ZERG_OVERSEER.value,
58 | #UNIT_TYPE.ZERG_CHANGELING.value,
59 | UNIT_TYPE.ZERG_QUEEN.value],
60 | use_regions=use_regions
61 | )
62 | self._building_count_feature = UnitTypeCountFeature(
63 | type_list=[UNIT_TYPE.ZERG_SPINECRAWLER.value,
64 | UNIT_TYPE.ZERG_SPORECRAWLER.value,
65 | #UNIT_TYPE.ZERG_NYDUSCANAL.value,
66 | UNIT_TYPE.ZERG_EXTRACTOR.value,
67 | UNIT_TYPE.ZERG_SPAWNINGPOOL.value,
68 | UNIT_TYPE.ZERG_ROACHWARREN.value,
69 | UNIT_TYPE.ZERG_HYDRALISKDEN.value,
70 | UNIT_TYPE.ZERG_HATCHERY.value,
71 | UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value,
72 | UNIT_TYPE.ZERG_BANELINGNEST.value,
73 | UNIT_TYPE.ZERG_INFESTATIONPIT.value,
74 | UNIT_TYPE.ZERG_SPIRE.value,
75 | UNIT_TYPE.ZERG_ULTRALISKCAVERN.value,
76 | #UNIT_TYPE.ZERG_NYDUSNETWORK.value,
77 | UNIT_TYPE.ZERG_LURKERDENMP.value,
78 | UNIT_TYPE.ZERG_LAIR.value,
79 | UNIT_TYPE.ZERG_HIVE.value,
80 | UNIT_TYPE.ZERG_GREATERSPIRE.value],
81 | use_regions=False
82 | )
83 | self._unit_stat_count_feature = UnitStatCountFeature(
84 | use_regions=use_regions)
85 | self._player_feature = PlayerFeature()
86 | self._score_feature = ScoreFeature()
87 | self._worker_feature = WorkerFeature()
88 | if use_game_progress:
89 | self._game_progress_feature = GameProgressFeature()
90 | self._action_seq_feature = ActionSeqFeature(self.action_space.n,
91 | action_seq_len)
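    # Total non-spatial feature dimension; the action mask, when present, is
    # appended to the non-spatial vector as extra dimensions.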
92 | n_dims = sum([
93 | self._unit_stat_count_feature.num_dims,
94 | self._unit_count_feature.num_dims,
95 | self._building_count_feature.num_dims,
96 | self._player_feature.num_dims,
97 | self._score_feature.num_dims,
98 | self._worker_feature.num_dims,
99 | self._action_seq_feature.num_dims,
100 | self._game_progress_feature.num_dims if use_game_progress else 0,
101 | self.env.action_space.n if isinstance(self.env.action_space,
102 | MaskDiscrete) else 0
103 | ])
104 |
105 | # spatial features
106 | if use_spatial_features:
107 | resolution = self.env.observation_space.space_attr["minimap"][1]
108 | self._unit_type_count_map_feature = UnitTypeCountMapFeature(
109 | type_map={UNIT_TYPE.ZERG_DRONE.value: 0,
110 | UNIT_TYPE.ZERG_ZERGLING.value: 1,
111 | UNIT_TYPE.ZERG_ROACH.value: 2,
112 | UNIT_TYPE.ZERG_ROACHBURROWED.value: 2,
113 | UNIT_TYPE.ZERG_HYDRALISK.value: 3,
114 | UNIT_TYPE.ZERG_OVERLORD.value: 4,
115 | UNIT_TYPE.ZERG_OVERSEER.value: 4,
116 | UNIT_TYPE.ZERG_HATCHERY.value: 5,
117 | UNIT_TYPE.ZERG_LAIR.value: 5,
118 | UNIT_TYPE.ZERG_HIVE.value: 5,
119 | UNIT_TYPE.ZERG_EXTRACTOR.value: 6,
120 | UNIT_TYPE.ZERG_QUEEN.value: 7,
121 | UNIT_TYPE.ZERG_RAVAGER.value: 8,
122 | UNIT_TYPE.ZERG_BANELING.value: 9,
123 | UNIT_TYPE.ZERG_LURKERMP.value: 10,
124 | UNIT_TYPE.ZERG_LURKERMPBURROWED.value: 10,
125 | UNIT_TYPE.ZERG_VIPER.value: 11,
126 | UNIT_TYPE.ZERG_MUTALISK.value: 12,
127 | UNIT_TYPE.ZERG_CORRUPTOR.value: 13,
128 | UNIT_TYPE.ZERG_BROODLORD.value: 14,
129 | UNIT_TYPE.ZERG_SWARMHOSTMP.value: 15,
130 | UNIT_TYPE.ZERG_INFESTOR.value: 16,
131 | UNIT_TYPE.ZERG_ULTRALISK.value: 17,
132 | UNIT_TYPE.ZERG_CHANGELING.value: 18,
133 | UNIT_TYPE.ZERG_SPINECRAWLER.value: 19,
134 | UNIT_TYPE.ZERG_SPORECRAWLER.value: 20},
135 | resolution=resolution,
136 | )
137 | self._alliance_count_map_feature = AllianceCountMapFeature(resolution)
138 | n_channels = sum([self._unit_type_count_map_feature.num_channels,
139 | self._alliance_count_map_feature.num_channels])
140 |
141 | if use_spatial_features:
142 | if isinstance(self.env.action_space, MaskDiscrete):
143 | self.observation_space = spaces.Tuple([
144 | spaces.Box(0.0, float('inf'), [n_channels, resolution, resolution],
145 | dtype=np.float32),
146 | spaces.Box(0.0, float('inf'), [n_dims], dtype=np.float32),
147 | spaces.Box(0.0, 1.0, [self.env.action_space.n], dtype=np.float32)
148 | ])
149 | else:
150 | self.observation_space = spaces.Tuple([
151 | spaces.Box(0.0, float('inf'), [n_channels, resolution, resolution],
152 | dtype=np.float32),
153 | spaces.Box(0.0, float('inf'), [n_dims], dtype=np.float32)
154 | ])
155 | else:
156 | if isinstance(self.env.action_space, MaskDiscrete):
157 | self.observation_space = spaces.Tuple([
158 | spaces.Box(0.0, float('inf'), [n_dims], dtype=np.float32),
159 | spaces.Box(0.0, 1.0, [self.env.action_space.n], dtype=np.float32)
160 | ])
161 | else:
162 | self.observation_space = spaces.Box(0.0, float('inf'), [n_dims],
163 | dtype=np.float32)
164 |
165 | def step(self, action):
166 | self._action_seq_feature.push_action(action)
167 | observation, reward, done, info = self.env.step(action)
168 | self._dc.update(observation)
169 | return self._observation(observation), reward, done, info
170 |
171 | def reset(self, **kwargs):
172 | observation = self.env.reset()
173 | self._dc.reset(observation)
174 | self._action_seq_feature.reset()
175 | return self._observation(observation)
176 |
177 | @property
178 | def action_names(self):
179 | if not hasattr(self.env, 'action_names'):
180 | raise NotImplementedError
181 | return self.env.action_names
182 |
183 | @property
184 | def player_position(self):
185 | if not hasattr(self.env, 'player_position'):
186 | raise NotImplementedError
187 | return self.env.player_position
188 |
189 | def _observation(self, observation):
190 |     need_flip = (self.env.player_position == 0)  # mirror region-based features for start position 0
191 |
192 | # nonspatial features
193 | unit_type_feat = self._unit_count_feature.features(observation, need_flip)
194 | building_type_feat = self._building_count_feature.features(observation,
195 | need_flip)
196 | unit_stat_feat = self._unit_stat_count_feature.features(observation,
197 | need_flip)
198 | player_feat = self._player_feature.features(observation)
199 | score_feat = self._score_feature.features(observation)
200 | worker_feat = self._worker_feature.features(self._dc)
201 | if self._use_game_progress:
202 | game_progress_feat = self._game_progress_feature.features(observation)
203 | action_seq_feat = self._action_seq_feature.features()
204 | nonspatial_feat = np.concatenate([
205 | unit_type_feat,
206 | building_type_feat,
207 | unit_stat_feat,
208 | player_feat,
209 | score_feat,
210 | worker_feat,
211 | action_seq_feat,
212 | game_progress_feat if self._use_game_progress \
213 | else np.array([], dtype=np.float32),
214 | np.array(observation['action_mask'], dtype=np.float32) \
215 | if isinstance(self.env.action_space, MaskDiscrete) \
216 | else np.array([], dtype=np.float32)
217 | ])
218 |
219 | # spatial features
220 | if self._use_spatial_features:
221 | ally_map_feat = self._alliance_count_map_feature.features(
222 | observation, need_flip)
223 | type_map_feat = self._unit_type_count_map_feature.features(
224 | observation, need_flip)
225 | spatial_feat = np.concatenate([ally_map_feat, type_map_feat])
226 |
227 | # return features
228 | if self._use_spatial_features:
229 | if isinstance(self.env.action_space, MaskDiscrete):
230 | return (spatial_feat, nonspatial_feat, observation['action_mask'])
231 | else:
232 | return (spatial_feat, nonspatial_feat)
233 | else:
234 | if isinstance(self.env.action_space, MaskDiscrete):
235 | return (nonspatial_feat, observation['action_mask'])
236 | else:
237 | return nonspatial_feat
238 |
239 |
240 | class ZergPlayerObservationWrapper(ZergObservationWrapper):
241 |
242 | def __init__(self, player, **kwargs):
243 | self._warn_double_wrap = lambda *args: None
244 | self._player = player
245 | super(ZergPlayerObservationWrapper, self).__init__(**kwargs)
246 |
247 | def step(self, action):
248 | self._action_seq_feature.push_action(action[self._player])
249 | observation, reward, done, info = self.env.step(action)
250 | self._dc.update(observation[self._player])
251 | observation[self._player] = self._observation(observation[self._player])
252 | return observation, reward, done, info
253 |
254 | def reset(self, **kwargs):
255 | observation = self.env.reset()
256 | self._dc.reset(observation[self._player])
257 | self._action_seq_feature.reset()
258 | observation[self._player] = self._observation(observation[self._player])
259 | return observation
260 |
--------------------------------------------------------------------------------
/sc2learner/agents/dqn_agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 | import time
8 | import struct
9 | import random
10 | import math
11 | from copy import deepcopy
12 | import queue
13 | from threading import Thread
14 | from collections import deque
15 | import io
16 | import zmq
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from torch.autograd import Variable
23 | import torch.optim as optim
24 | from gym.spaces import prng
25 | from gym.spaces.discrete import Discrete
26 | from gym import spaces
27 |
28 | from sc2learner.agents.replay_memory import Transition
29 | from sc2learner.agents.replay_memory import RemoteReplayMemory
30 | from sc2learner.utils.utils import tprint
31 |
32 |
33 | class DQNAgent(object):
34 |
35 | def __init__(self, network, action_space, init_model_path=None):
36 | assert type(action_space) == spaces.Discrete
37 | self._action_space = action_space
38 | self._network = network
39 | if init_model_path is not None:
40 | self.load_params(torch.load(init_model_path,
41 | map_location=lambda storage, loc: storage))
42 | if torch.cuda.device_count() > 1:
43 | self._network = nn.DataParallel(self._network)
44 | if torch.cuda.is_available(): self._network.cuda()
45 | self._optimizer = None
46 | self._target_network = None
47 | self._num_optim_steps = 0
48 |
49 | def act(self, observation, eps=0):
50 | self._network.eval()
51 | if random.uniform(0, 1) >= eps:
52 | observation = torch.from_numpy(np.expand_dims(observation, 0))
53 | if torch.cuda.is_available():
54 | observation = observation.pin_memory().cuda(non_blocking=True)
55 | with torch.no_grad():
56 | q = self._network(observation)
57 | action = q.data.max(1)[1].item()
58 | return action
59 | else:
60 | return self._action_space.sample()
61 |
62 | def optimize_step(self,
63 | obs_batch,
64 | next_obs_batch,
65 | action_batch,
66 | reward_batch,
67 | done_batch,
68 | mc_return_batch,
69 | discount,
70 | mmc_beta,
71 | gradient_clipping,
72 | adam_eps,
73 | learning_rate,
74 | target_update_interval):
75 | # create optimizer
76 | if self._optimizer is None:
77 | self._optimizer = optim.Adam(self._network.parameters(),
78 | eps=adam_eps,
79 | lr=learning_rate)
80 | # create target network
81 | if self._target_network is None:
82 | self._target_network = deepcopy(self._network)
83 | if torch.cuda.is_available(): self._target_network.cuda()
84 | self._target_network.eval()
85 |
86 | # update target network
87 | if self._num_optim_steps % target_update_interval == 0:
88 | self._target_network.load_state_dict(self._network.state_dict())
89 |
90 | # move to gpu
91 | if torch.cuda.is_available():
92 | obs_batch = obs_batch.cuda(non_blocking=True)
93 | next_obs_batch = next_obs_batch.cuda(non_blocking=True)
94 | action_batch = action_batch.cuda(non_blocking=True)
95 | reward_batch = reward_batch.cuda(non_blocking=True)
96 | mc_return_batch = mc_return_batch.cuda(non_blocking=True)
97 | done_batch = done_batch.cuda(non_blocking=True)
98 |
99 | # compute max-q target
100 | self._network.eval()
101 | with torch.no_grad():
102 | q_next_target = self._target_network(next_obs_batch)
103 | q_next = self._network(next_obs_batch)
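      # Double Q-learning target: the online network selects the greedy action,
      # the target network evaluates it.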
104 | futures = q_next_target.gather(
105 | 1, q_next.max(dim=1)[1].view(-1, 1)).squeeze()
106 | futures = futures * (1 - done_batch)
107 | target_q = reward_batch + discount * futures
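      # Mixed Monte-Carlo: blend the bootstrapped target with the empirical return.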
108 | target_q = target_q * mmc_beta + (1.0 - mmc_beta) * mc_return_batch
109 |
110 | # define loss
111 | self._network.train()
112 | q = self._network(obs_batch).gather(1, action_batch.view(-1, 1)).squeeze()
113 | loss = F.mse_loss(q, target_q.detach())
114 |
115 | # compute gradient and update parameters
116 | self._optimizer.zero_grad()
117 | loss.backward()
118 | for param in self._network.parameters():
119 | param.grad.data.clamp_(-gradient_clipping, gradient_clipping)
120 | self._optimizer.step()
121 | self._num_optim_steps += 1
122 | return loss.data.item()
123 |
124 | def reset(self):
125 | pass
126 |
127 | def load_params(self, state_dict):
128 | self._network.load_state_dict(state_dict)
129 |
130 | def read_params(self):
131 | if torch.cuda.device_count() > 1:
132 | return self._network.module.state_dict()
133 | else:
134 | return self._network.state_dict()
135 |
136 |
137 | class DQNActor(object):
138 |
139 | def __init__(self,
140 | memory_size,
141 | memory_warmup_size,
142 | env,
143 | network,
144 | discount,
145 | send_freq=4.0,
146 | ports=("5700", "5701", "5702"),
147 | learner_ip="localhost"):
148 | assert type(env.action_space) == spaces.Discrete
149 | assert len(ports) == 3
150 | self._env = env
151 | self._discount = discount
152 | self._epsilon = 1.0
153 |
154 | self._agent = DQNAgent(network, env.action_space)
155 | self._replay_memory = RemoteReplayMemory(
156 | is_server=False,
157 | memory_size=memory_size,
158 | memory_warmup_size=memory_warmup_size,
159 | send_freq=send_freq,
160 | ports=ports[:2],
161 | server_ip=learner_ip)
162 |
163 | self._zmq_context = zmq.Context()
164 | self._model_requestor = self._zmq_context.socket(zmq.REQ)
165 | self._model_requestor.connect("tcp://%s:%s" % (learner_ip, ports[2]))
166 |
167 | def run(self):
168 | while True:
169 | # fetch model
170 | t = time.time()
171 | self._update_model()
172 | tprint("Update model time: %f eps: %f" % (time.time() - t, self._epsilon))
173 | # rollout
174 | t = time.time()
175 | self._rollout()
176 | tprint("Rollout time: %f" % (time.time() - t))
177 |
178 | def _rollout(self):
179 | rollout, done = [], False
180 | observation = self._env.reset()
181 | while not done:
182 | action = self._agent.act(observation, eps=self._epsilon)
183 | next_observation, reward, done, info = self._env.step(action)
184 | rollout.append(
185 | (observation, action, reward, next_observation, done))
186 | observation = next_observation
187 |
188 | discounted_return = 0
189 | for transition in reversed(rollout):
190 | reward = transition[2]
191 | discounted_return = discounted_return * self._discount + reward
192 | self._replay_memory.push(*transition, discounted_return)
193 |
194 | def _update_model(self):
195 | self._model_requestor.send_string("request model")
196 | file_object = io.BytesIO(self._model_requestor.recv_pyobj())
197 | self._agent.load_params(
198 | torch.load(file_object, map_location=lambda storage, loc: storage))
199 | self._epsilon = self._model_requestor.recv_pyobj()
200 |
201 |
202 | class DQNLearner(object):
203 |
204 | def __init__(self,
205 | network,
206 | action_space,
207 | memory_size,
208 | memory_warmup_size,
209 | discount,
210 | eps_start,
211 | eps_end,
212 | eps_decay_steps,
213 | eps_decay_steps2,
214 | batch_size,
215 | mmc_beta,
216 | gradient_clipping,
217 | adam_eps,
218 | learning_rate,
219 | target_update_interval,
220 | checkpoint_dir,
221 | checkpoint_interval,
222 | print_interval,
223 | ports=("5700", "5701", "5702"),
224 | init_model_path=None):
225 | assert type(action_space) == spaces.Discrete
226 | self._agent = DQNAgent(network, action_space)
227 | self._replay_memory = RemoteReplayMemory(
228 | is_server=True,
229 | memory_size=memory_size,
230 | memory_warmup_size=memory_warmup_size,
231 | ports=ports[:2])
232 | if init_model_path is not None:
233 | self._agent.load_params(
234 | torch.load(init_model_path,
235 | map_location=lambda storage, loc: storage))
236 | self._model_params = self._agent.read_params()
237 |
238 | self._batch_size = batch_size
239 | self._mmc_beta = mmc_beta
240 | self._gradient_clipping = gradient_clipping
241 | self._adam_eps = adam_eps
242 | self._learning_rate = learning_rate
243 | self._target_update_interval = target_update_interval
244 | self._checkpoint_dir = checkpoint_dir
245 | self._checkpoint_interval = checkpoint_interval
246 | self._print_interval = print_interval
247 | self._discount = discount
248 | self._eps_start = eps_start
249 | self._eps_end = eps_end
250 | self._eps_decay_steps = eps_decay_steps
251 | self._eps_decay_steps2 = eps_decay_steps2
252 | self._epsilon = eps_start
253 |
254 | self._zmq_context = zmq.Context()
255 | self._reply_model_thread = Thread(
256 | target=self._reply_model, args=(self._zmq_context, ports[2]))
257 | self._reply_model_thread.start()
258 |
259 | def run(self):
260 | batch_queue = queue.Queue(8)
261 | batch_thread = Thread(target=self._prepare_batch,
262 | args=(batch_queue, self._batch_size,))
263 | batch_thread.start()
264 |
265 | updates, loss, total_rollout_frames = 0, [], 0
266 | time_start = time.time()
267 | while True:
268 | updates += 1
269 | observation, next_observation, action, reward, done, mc_return = \
270 | batch_queue.get()
271 | self._epsilon = self._schedule_epsilon(updates)
272 | loss.append(self._agent.optimize_step(
273 | obs_batch=observation,
274 | next_obs_batch=next_observation,
275 | action_batch=action,
276 | reward_batch=reward,
277 | done_batch=done,
278 | mc_return_batch=mc_return,
279 | discount=self._discount,
280 | mmc_beta=self._mmc_beta,
281 | gradient_clipping=self._gradient_clipping,
282 | adam_eps=self._adam_eps,
283 | learning_rate=self._learning_rate,
284 | target_update_interval=self._target_update_interval))
285 | self._model_params = self._agent.read_params()
286 | if updates % self._checkpoint_interval == 0:
287 | ckpt_path = os.path.join(self._checkpoint_dir,
288 | 'checkpoint-%d' % updates)
289 | self._save_checkpoint(ckpt_path)
290 | if updates % self._print_interval == 0:
291 | time_elapsed = time.time() - time_start
292 | train_fps = self._print_interval * self._batch_size / time_elapsed
293 | rollout_fps = (self._replay_memory.total - total_rollout_frames) \
294 | / time_elapsed
295 | loss_mean = np.mean(loss)
296 | tprint("Update: %d Train-fps: %.1f Rollout-fps: %.1f "
297 | "Loss: %.5f Epsilon: %.5f Time: %.1f" % (updates, train_fps,
298 | rollout_fps, loss_mean, self._epsilon, time_elapsed))
299 | time_start, loss = time.time(), []
300 | total_rollout_frames = self._replay_memory.total
301 |
302 | def _prepare_batch(self, batch_queue, batch_size):
303 | while True:
304 | transitions = self._replay_memory.sample(batch_size)
305 | batch = self._transitions_to_batch(transitions)
306 | batch_queue.put(batch)
307 |
308 | def _transitions_to_batch(self, transitions):
309 | batch = Transition(*zip(*transitions))
310 | observation = torch.from_numpy(np.stack(batch.observation))
311 | next_observation = torch.from_numpy(np.stack(batch.next_observation))
312 | reward = torch.FloatTensor(batch.reward)
313 | action = torch.LongTensor(batch.action)
314 | done = torch.Tensor(batch.done)
315 | mc_return = torch.FloatTensor(batch.mc_return)
316 |
317 | if torch.cuda.is_available():
318 | observation = observation.pin_memory()
319 | next_observation = next_observation.pin_memory()
320 | action = action.pin_memory()
321 | reward = reward.pin_memory()
322 | mc_return = mc_return.pin_memory()
323 | done = done.pin_memory()
324 |
325 | return observation, next_observation, action, reward, done, mc_return
326 |
327 | def _save_checkpoint(self, checkpoint_path):
328 | torch.save(self._model_params, checkpoint_path)
329 |
330 | def _schedule_epsilon(self, steps):
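    # Two-stage linear decay: eps_start -> eps_end over the first eps_decay_steps
    # updates, then a further linear decay, clamped at 0.01 once eps_decay_steps2
    # updates are reached.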
331 | if steps < self._eps_decay_steps:
332 | return self._eps_start - (self._eps_start - self._eps_end) * \
333 | steps / self._eps_decay_steps
334 | elif steps < self._eps_decay_steps2:
335 | return self._eps_end - (self._eps_end - 0.01) * \
336 | (steps - self._eps_decay_steps) / self._eps_decay_steps2
337 | else:
338 | return 0.01
339 |
340 | def _reply_model(self, zmq_context, port):
341 | receiver = zmq_context.socket(zmq.REP)
342 | receiver.bind("tcp://*:%s" % port)
343 | while True:
344 | assert receiver.recv_string() == "request model"
345 | f = io.BytesIO()
346 | torch.save(self._model_params, f)
347 | receiver.send_pyobj(f.getvalue(), zmq.SNDMORE)
348 | receiver.send_pyobj(self._epsilon)
349 |
--------------------------------------------------------------------------------
/sc2learner/envs/actions/combat.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from collections import namedtuple
6 |
7 | from s2clientprotocol import sc2api_pb2 as sc_pb
8 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE
9 | from pysc2.lib.typeenums import ABILITY_ID as ABILITY
10 | from pysc2.lib.typeenums import UPGRADE_ID as UPGRADE
11 |
12 | from sc2learner.envs.actions.function import Function
13 | import sc2learner.envs.common.utils as utils
14 | from sc2learner.envs.common.const import ATTACK_FORCE
15 | from sc2learner.envs.common.const import ALLY_TYPE
16 | from sc2learner.envs.common.const import PRIORITIZED_ATTACK
17 |
18 |
19 | Region = namedtuple('Region', ('ranges', 'rally_point_a', 'rally_point_b'))
20 |
21 |
22 | class CombatActions(object):
23 |
24 | def __init__(self):
25 | self._regions = [
26 | Region([(0, 0, 200, 176)], (161.5, 21.5), (38.5, 122.5)),
27 | Region([(0, 88, 80, 176)], (68, 108), (68, 108)),
28 | Region([(80, 88, 120, 176)], (100, 113.5), (100, 113.5)),
29 | Region([(120, 88, 200, 176)], (147.5, 113.5), (147.5, 113.5)),
30 | Region([(0, 55, 80, 88)], (36.5, 76.5), (36.5, 76.5)),
31 | Region([(80, 55, 120, 88)], (100, 71.5), (100, 71.5)),
32 | Region([(120, 55, 200, 88)], (163.5, 66.5), (163.5, 66.5)),
33 | Region([(0, 0, 80, 55)], (52.5, 30), (52.5, 30)),
34 | Region([(80, 0, 120, 55)], (100, 30), (100, 30)),
35 | Region([(120, 0, 200, 55)], (133, 36), (133, 36))
36 | ]
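    # Mirror region ids (i -> 10 - i, region 0 unchanged) so the hard-coded region
    # layout matches whichever start position the agent is assigned.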
37 | self._flip_region = lambda r_id: 10 - r_id if r_id > 0 else r_id
38 |
39 | self._attack_tasks = {}
40 |
41 | def reset(self):
42 | self._attack_tasks.clear()
43 |
44 | def action(self, source_region_id, target_region_id):
45 | assert source_region_id < len(self._regions)
46 | assert target_region_id < len(self._regions)
47 | return Function(
48 | name=("combats_in_region_%d_attack_region_%d" %
49 | (source_region_id, target_region_id)),
50 | function=self._attack_region(source_region_id, target_region_id),
51 | is_valid=self._is_valid_attack_region(source_region_id, target_region_id)
52 | )
53 |
54 | @property
55 | def num_regions(self):
56 | return len(self._regions)
57 |
58 | @property
59 | def action_rally_new_combat_units(self):
60 | return Function(name="rally_new_combat_units",
61 | function=self._rally_new_combat_units,
62 | is_valid=self._is_valid_rally_new_combat_units)
63 |
64 | @property
65 | def action_framewise_rally_and_attack(self):
66 | return Function(name="framewise_rally_and_attack",
67 | function=self._framewise_rally_and_attack,
68 | is_valid=lambda dc: True)
69 |
70 | def _attack_region(self, source_region_id, target_region_id):
71 |
72 | def act(dc):
73 |       flip = (self._player_position(dc) == 0)
74 | src_id = self._flip_region(source_region_id) if flip else source_region_id
75 | tgt_id = self._flip_region(target_region_id) if flip else target_region_id
76 | combat_unit = [u for u in dc.combat_units if self._is_in_region(u, src_id)]
77 | self._set_attack_task(combat_unit, tgt_id)
78 | return []
79 |
80 | return act
81 |
82 | def _is_valid_attack_region(self, source_region_id, target_region_id):
83 |
84 | def is_valid(dc):
85 |       flip = (self._player_position(dc) == 0)
86 | src_id = self._flip_region(source_region_id) if flip else source_region_id
87 | combat_unit = [u for u in dc.combat_units if self._is_in_region(u, src_id)]
88 | return len(combat_unit) >= 3
89 |
90 | return is_valid
91 |
92 | def _rally_new_combat_units(self, dc):
93 | new_combat_units = [u for u in dc.combat_units if dc.is_new_unit(u)]
94 | if self._player_position(dc) == 0:
95 | self._set_attack_task(new_combat_units, 1)
96 | else:
97 | self._set_attack_task(new_combat_units, 9)
98 | return []
99 |
100 | def _is_valid_rally_new_combat_units(self, dc):
101 | new_combat_units = [u for u in dc.combat_units if dc.is_new_unit(u)]
102 | if len(new_combat_units) > 0: return True
103 | else: return False
104 |
105 | def _framewise_rally_and_attack(self, dc):
106 | actions = []
107 | for region_id in range(len(self._regions)):
108 | units_with_task = [u for u in dc.combat_units
109 | if (u.tag in self._attack_tasks and
110 | self._attack_tasks[u.tag] == region_id)]
111 | if len(units_with_task) > 0:
112 | target_enemies = [
113 | u for u in dc.units_of_alliance(ALLY_TYPE.ENEMY.value)
114 | if self._is_in_region(u, region_id)
115 | ]
116 | if len(target_enemies) > 0:
117 | actions.extend(
118 | self._micro_attack(units_with_task, target_enemies, dc))
119 | else:
120 | if self._player_position(dc) == 0:
121 | rally_point = self._regions[region_id].rally_point_a
122 | else:
123 | rally_point = self._regions[region_id].rally_point_b
124 | actions.extend(self._micro_rally(units_with_task, rally_point, dc))
125 | return actions
126 |
127 | def _micro_attack(self, combat_units, enemy_units, dc):
128 |
129 | def prioritized_attack(unit, target_units):
130 | assert len(target_units) > 0
131 | prioritized_target_units = [u for u in target_units
132 | if u.unit_type in PRIORITIZED_ATTACK]
133 | if len(prioritized_target_units) > 0:
134 | closest_target = utils.closest_unit(unit, prioritized_target_units)
135 | else:
136 | closest_target = utils.closest_unit(unit, target_units)
137 | target_pos = (closest_target.float_attr.pos_x,
138 | closest_target.float_attr.pos_y)
139 | return self._unit_attack(unit, target_pos, dc)
140 |
141 | def flee_or_fight(unit, target_units):
142 | assert len(target_units) > 0
143 | closest_target = utils.closest_unit(unit, target_units)
144 | closest_dist = utils.closest_distance(unit, enemy_units)
145 | strongest_health = utils.strongest_health(combat_units)
146 | if (closest_dist < 5.0 and
147 | unit.float_attr.health / unit.float_attr.health_max < 0.3 and
148 | strongest_health > 0.9):
149 | x = unit.float_attr.pos_x + (unit.float_attr.pos_x - \
150 | closest_target.float_attr.pos_x) * 0.2
151 | y = unit.float_attr.pos_y + (unit.float_attr.pos_y - \
152 | closest_target.float_attr.pos_y) * 0.2
153 | target_pos = (x, y)
154 | return self._unit_move(unit, target_pos, dc)
155 | else:
156 | target_pos = (closest_target.float_attr.pos_x,
157 | closest_target.float_attr.pos_y)
158 | return self._unit_attack(unit, target_pos, dc)
159 |
160 | air_combat_units = [
161 | u for u in combat_units
162 | if (ATTACK_FORCE[u.unit_type].can_attack_air and
163 | not ATTACK_FORCE[u.unit_type].can_attack_ground)
164 | ]
165 | ground_combat_units = [
166 | u for u in combat_units
167 | if (not ATTACK_FORCE[u.unit_type].can_attack_air and
168 | ATTACK_FORCE[u.unit_type].can_attack_ground)
169 | ]
170 | air_ground_combat_units = [
171 | u for u in combat_units
172 | if (ATTACK_FORCE[u.unit_type].can_attack_air and
173 | ATTACK_FORCE[u.unit_type].can_attack_ground)
174 | ]
175 | air_enemy_units = [u for u in enemy_units if u.bool_attr.is_flying]
176 | ground_enemy_units = [u for u in enemy_units if not u.bool_attr.is_flying]
177 | actions = []
178 | for unit in air_combat_units:
179 | if len(air_enemy_units) > 0:
180 | actions.extend(prioritized_attack(unit, air_enemy_units))
181 | for unit in ground_combat_units:
182 | if len(ground_enemy_units) > 0:
183 | actions.extend(prioritized_attack(unit, ground_enemy_units))
184 | for unit in air_ground_combat_units:
185 | if len(enemy_units) > 0:
186 | actions.extend(prioritized_attack(unit, enemy_units))
187 | return actions
188 |
189 | def _micro_rally(self, units, rally_point, dc):
190 | actions = []
191 | for unit in units:
192 | actions.extend(self._unit_attack(unit, rally_point, dc))
193 | return actions
194 |
195 | def _unit_attack(self, unit, target_pos, dc):
196 | # move with attack
197 | if unit.unit_type == UNIT_TYPE.ZERG_RAVAGER.value:
198 | return self._ravager_unit_attack(unit, target_pos, dc)
199 | #elif (unit.unit_type == UNIT_TYPE.ZERG_ROACH.value or
200 | #unit.unit_type == UNIT_TYPE.ZERG_ROACHBURROWED.value):
201 | #return self._roach_unit_attack(unit, target_pos, dc)
202 | elif (unit.unit_type == UNIT_TYPE.ZERG_LURKERMP.value or
203 | unit.unit_type == UNIT_TYPE.ZERG_LURKERMPBURROWED.value):
204 | return self._lurker_unit_attack(unit, target_pos, dc)
205 | else:
206 | return self._normal_unit_attack(unit, target_pos)
207 |
208 | def _unit_move(self, unit, target_pos, dc):
209 | # move without attack
210 | if unit.unit_type == UNIT_TYPE.ZERG_LURKERMPBURROWED.value:
211 | return self._lurker_unit_move(unit, target_pos)
212 | #elif unit.unit_type == UNIT_TYPE.ZERG_ROACHBURROWED.value:
213 | #return self._roach_unit_move(unit, target_pos, dc)
214 | else:
215 | return self._normal_unit_move(unit, target_pos)
216 |
217 | def _normal_unit_attack(self, unit, target_pos):
218 | action = sc_pb.Action()
219 | action.action_raw.unit_command.unit_tags.append(unit.tag)
220 | action.action_raw.unit_command.ability_id = ABILITY.ATTACK_ATTACK.value
221 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0]
222 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1]
223 | return [action]
224 |
225 | def _normal_unit_move(self, unit, target_pos):
226 | action = sc_pb.Action()
227 | action.action_raw.unit_command.unit_tags.append(unit.tag)
228 | action.action_raw.unit_command.ability_id = ABILITY.MOVE.value
229 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0]
230 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1]
231 | return [action]
232 |
233 | def _roach_unit_attack(self, unit, target_pos, dc):
234 | actions = []
235 | ground_enemies = [u for u in dc.units_of_alliance(ALLY_TYPE.ENEMY.value)
236 | if not u.bool_attr.is_flying]
237 | if len(utils.units_nearby(unit, ground_enemies, max_distance=4)) > 0:
238 | if unit.unit_type == UNIT_TYPE.ZERG_ROACHBURROWED.value:
239 | action = sc_pb.Action()
240 | action.action_raw.unit_command.unit_tags.append(unit.tag)
241 | action.action_raw.unit_command.ability_id = ABILITY.BURROWUP_ROACH.value
242 | actions.append(action)
243 | actions.extend(self._normal_unit_attack(unit, target_pos))
244 | else:
245 | actions.extend(self._roach_unit_move(unit, target_pos, dc))
246 | return actions
247 |
248 | def _roach_unit_move(self, unit, target_pos, dc):
249 | actions = []
250 | if (UPGRADE.TUNNELINGCLAWS.value in dc.upgraded_techs and
251 | UPGRADE.BURROW.value in dc.upgraded_techs and
252 | unit.unit_type == UNIT_TYPE.ZERG_ROACH.value):
253 | action = sc_pb.Action()
254 | action.action_raw.unit_command.unit_tags.append(unit.tag)
255 | action.action_raw.unit_command.ability_id = ABILITY.BURROWDOWN_ROACH.value
256 | actions.append(action)
257 | actions.extend(self._normal_unit_move(unit, target_pos))
258 | return actions
259 |
260 | def _lurker_unit_attack(self, unit, target_pos, dc):
261 | actions = []
262 | ground_enemies = [u for u in dc.units_of_alliance(ALLY_TYPE.ENEMY.value)
263 | if not u.bool_attr.is_flying]
264 | if len(utils.units_nearby(unit, ground_enemies, max_distance=8)) > 0:
265 | if unit.unit_type == UNIT_TYPE.ZERG_LURKERMP.value:
266 | action = sc_pb.Action()
267 | action.action_raw.unit_command.unit_tags.append(unit.tag)
268 | action.action_raw.unit_command.ability_id = \
269 | ABILITY.BURROWDOWN_LURKER.value
270 | actions.append(action)
271 | else:
272 | actions.extend(self._lurker_unit_move(unit, target_pos))
273 | return actions
274 |
275 | def _lurker_unit_move(self, unit, target_pos):
276 | actions = []
277 | if unit.unit_type == UNIT_TYPE.ZERG_LURKERMPBURROWED.value:
278 | action = sc_pb.Action()
279 | action.action_raw.unit_command.unit_tags.append(unit.tag)
280 | action.action_raw.unit_command.ability_id = ABILITY.BURROWUP_LURKER.value
281 | actions.append(action)
282 | actions.extend(self._normal_unit_move(unit, target_pos))
283 | return actions
284 |
285 | def _ravager_unit_attack(self, unit, target_pos, dc):
286 | actions = []
287 | ground_units = [u for u in dc.units_of_alliance(ALLY_TYPE.SELF.value)
288 | if not u.bool_attr.is_flying]
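    # Cast Corrosive Bile at the target position only when none of our own
    # ground units are within 2 of it, to avoid friendly splash damage; a normal
    # attack-move is issued either way.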
289 | if len(utils.units_nearby(target_pos, ground_units, max_distance=2)) == 0:
290 | action = sc_pb.Action()
291 | action.action_raw.unit_command.unit_tags.append(unit.tag)
292 | action.action_raw.unit_command.ability_id = \
293 | ABILITY.EFFECT_CORROSIVEBILE.value
294 | action.action_raw.unit_command.target_world_space_pos.x = target_pos[0]
295 | action.action_raw.unit_command.target_world_space_pos.y = target_pos[1]
296 | actions.append(action)
297 | actions.extend(self._normal_unit_attack(unit, target_pos))
298 | return actions
299 |
300 | def _set_attack_task(self, units, target_region_id):
301 | for u in units:
302 | self._attack_tasks[u.tag] = target_region_id
303 |
304 | def _is_in_region(self, unit, region_id):
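    # True if the unit's position falls inside any of the axis-aligned
    # rectangles (x_min, y_min, x_max, y_max) that make up the region.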
305 | return any([(unit.float_attr.pos_x >= r[0] and
306 | unit.float_attr.pos_x < r[2] and
307 | unit.float_attr.pos_y >= r[1] and
308 | unit.float_attr.pos_y < r[3])
309 | for r in self._regions[region_id].ranges])
310 |
311 | def _player_position(self, dc):
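    # Returns 0 if our initial base lies on the low-x half of the map
    # (x < 100), otherwise 1.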
312 | if dc.init_base_pos[0] < 100: return 0
313 | else: return 1
314 |
--------------------------------------------------------------------------------
/sc2learner/agents/ppo_agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import time
7 | from collections import deque
8 | from queue import Queue
9 | import queue
10 | from threading import Thread
11 | import time
12 | import random
13 | import joblib
14 |
15 | import numpy as np
16 | import tensorflow as tf
17 | import zmq
18 | from gym import spaces
19 |
20 | from sc2learner.envs.spaces.mask_discrete import MaskDiscrete
21 | from sc2learner.agents.utils_tf import explained_variance
22 | from sc2learner.utils.utils import tprint
23 |
24 |
25 | class Model(object):
26 | def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
27 | unroll_length, ent_coef, vf_coef, max_grad_norm, scope_name,
28 | value_clip=False):
29 | sess = tf.get_default_session()
30 |
31 | act_model = policy(sess, scope_name, ob_space, ac_space, nbatch_act, 1,
32 | reuse=False)
33 | train_model = policy(sess, scope_name, ob_space, ac_space, nbatch_train,
34 | unroll_length, reuse=True)
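    # act_model and train_model share the same variables (reuse=True): the
    # former steps a single observation at a time for acting, the latter
    # consumes full unrolls of length unroll_length for training.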
35 |
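    # Training inputs: actions, advantages, returns, old negative log-probs and
    # old value predictions from the behaviour policy, plus the learning rate
    # and PPO clip range.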
36 | A = tf.placeholder(shape=(nbatch_train,), dtype=tf.int32)
37 | ADV = tf.placeholder(tf.float32, [None])
38 | R = tf.placeholder(tf.float32, [None])
39 | OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
40 | OLDVPRED = tf.placeholder(tf.float32, [None])
41 | LR = tf.placeholder(tf.float32, [])
42 | CLIPRANGE = tf.placeholder(tf.float32, [])
43 |
44 | neglogpac = train_model.pd.neglogp(A)
45 | entropy = tf.reduce_mean(train_model.pd.entropy())
46 |
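    # Value loss, optionally with PPO-style value clipping that penalizes value
    # predictions moving more than CLIPRANGE away from the rollout-time
    # prediction (OLDVPRED) within a single update.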
47 | vpred = train_model.vf
48 | vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED,
49 | -CLIPRANGE, CLIPRANGE)
50 | vf_losses1 = tf.square(vpred - R)
51 | if value_clip:
52 | vf_losses2 = tf.square(vpredclipped - R)
53 | vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
54 | else:
55 | vf_loss = .5 * tf.reduce_mean(vf_losses1)
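    # Clipped surrogate objective: ratio = pi_new(a|s) / pi_old(a|s); taking the
    # pessimistic maximum of the unclipped and clipped policy losses keeps each
    # update within the trust region defined by CLIPRANGE. approxkl and clipfrac
    # are diagnostics only.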
56 | ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
57 | pg_losses = -ADV * ratio
58 | pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE,
59 | 1.0 + CLIPRANGE)
60 | pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
61 | approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
62 | clipfrac = tf.reduce_mean(
63 | tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
64 | loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
65 | params = tf.trainable_variables(scope=scope_name)
66 | grads = tf.gradients(loss, params)
67 | if max_grad_norm is not None:
68 | grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
69 | grads = list(zip(grads, params))
70 | trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
71 | _train = trainer.apply_gradients(grads)
72 | new_params = [tf.placeholder(p.dtype, shape=p.get_shape()) for p in params]
73 | param_assign_ops = [p.assign(new_p) for p, new_p in zip(params, new_params)]
74 |
75 | def train(lr, cliprange, obs, returns, dones, actions, values, neglogpacs,
76 | states=None):
77 | advs = returns - values
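      # Standardize advantages within the batch; the 1e-8 guards against a zero
      # standard deviation.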
78 | advs = (advs - advs.mean()) / (advs.std() + 1e-8)
79 | if isinstance(ac_space, MaskDiscrete):
80 | td_map = {train_model.X:obs[0], train_model.MASK:obs[-1], A:actions,
81 | ADV:advs, R:returns, LR:lr, CLIPRANGE:cliprange,
82 | OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
83 | else:
84 | td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
85 | CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
86 | if states is not None:
87 | td_map[train_model.STATE] = states
88 | td_map[train_model.DONE] = dones
89 | return sess.run(
90 | [pg_loss, vf_loss, entropy, approxkl, clipfrac, _train],
91 | td_map
92 | )[:-1]
93 | self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy',
94 | 'approxkl', 'clipfrac']
95 |
96 | def save(save_path):
97 | joblib.dump(read_params(), save_path)
98 |
99 | def load(load_path):
100 | loaded_params = joblib.load(load_path)
101 | load_params(loaded_params)
102 |
103 | def read_params():
104 | return sess.run(params)
105 |
106 | def load_params(loaded_params):
107 | sess.run(param_assign_ops,
108 | feed_dict={p : v for p, v in zip(new_params, loaded_params)})
109 |
110 | self.train = train
111 | self.train_model = train_model
112 | self.act_model = act_model
113 | self.step = act_model.step
114 | self.value = act_model.value
115 | self.initial_state = act_model.initial_state
116 | self.save = save
117 | self.load = load
118 | self.read_params = read_params
119 | self.load_params = load_params
120 |
121 | tf.global_variables_initializer().run(session=sess)
122 |
123 |
124 | class PPOActor(object):
125 |
126 | def __init__(self, env, policy, unroll_length, gamma, lam, queue_size=1,
127 | enable_push=True, learner_ip="localhost", port_A="5700",
128 | port_B="5701"):
129 | self._env = env
130 | self._unroll_length = unroll_length
131 | self._lam = lam
132 | self._gamma = gamma
133 | self._enable_push = enable_push
134 |
135 | self._model = Model(policy=policy,
136 | scope_name="model",
137 | ob_space=env.observation_space,
138 | ac_space=env.action_space,
139 | nbatch_act=1,
140 | nbatch_train=unroll_length,
141 | unroll_length=unroll_length,
142 | ent_coef=0.01,
143 | vf_coef=0.5,
144 | max_grad_norm=0.5)
145 | self._obs = env.reset()
146 | self._state = self._model.initial_state
147 | self._done = False
148 | self._cum_reward = 0
149 |
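    # Actor-learner plumbing over ZeroMQ: a REQ socket on port_A requests the
    # latest model parameters from the learner, and (when enable_push is set) a
    # background thread PUSHes finished rollouts to the learner on port_B.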
150 | self._zmq_context = zmq.Context()
151 | self._model_requestor = self._zmq_context.socket(zmq.REQ)
152 | self._model_requestor.connect("tcp://%s:%s" % (learner_ip, port_A))
153 | if enable_push:
154 | self._data_queue = Queue(queue_size)
155 | self._push_thread = Thread(target=self._push_data, args=(
156 | self._zmq_context, learner_ip, port_B, self._data_queue))
157 | self._push_thread.start()
158 |
159 | def run(self):
160 | while True:
161 |       # Fetch the latest model parameters from the learner.
162 | t = time.time()
163 | self._update_model()
164 | tprint("Update model time: %f" % (time.time() - t))
165 | t = time.time()
166 |       # Roll out unroll_length environment steps with the current policy.
167 | unroll = self._nstep_rollout()
168 | if self._enable_push:
169 | if self._data_queue.full(): tprint("[WARN]: Actor's queue is full.")
170 | self._data_queue.put(unroll)
171 | tprint("Rollout time: %f" % (time.time() - t))
172 |
173 | def _nstep_rollout(self):
174 | mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = \
175 | [],[],[],[],[],[]
176 | mb_states, episode_infos = self._state, []
177 | for _ in range(self._unroll_length):
178 | action, value, self._state, neglogpac = self._model.step(
179 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)),
180 | self._state,
181 | np.expand_dims(self._done, 0))
182 | mb_obs.append(transform_tuple(self._obs, lambda x: x.copy()))
183 | mb_actions.append(action[0])
184 | mb_values.append(value[0])
185 | mb_neglogpacs.append(neglogpac[0])
186 | mb_dones.append(self._done)
187 | self._obs, reward, self._done, info = self._env.step(action[0])
188 | self._cum_reward += reward
189 | if self._done:
190 | self._obs = self._env.reset()
191 | self._state = self._model.initial_state
192 | episode_infos.append({'r': self._cum_reward})
193 | self._cum_reward = 0
194 | mb_rewards.append(reward)
195 | if isinstance(self._obs, tuple):
196 | mb_obs = tuple(np.asarray(obs, dtype=self._obs[0].dtype)
197 | for obs in zip(*mb_obs))
198 | else:
199 | mb_obs = np.asarray(mb_obs, dtype=self._obs.dtype)
200 | mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
201 | mb_actions = np.asarray(mb_actions)
202 | mb_values = np.asarray(mb_values, dtype=np.float32)
203 | mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
204 | mb_dones = np.asarray(mb_dones, dtype=np.bool)
205 | last_values = self._model.value(
206 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)),
207 | self._state,
208 | np.expand_dims(self._done, 0))
209 | mb_returns = np.zeros_like(mb_rewards)
210 | mb_advs = np.zeros_like(mb_rewards)
211 | last_gae_lam = 0
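    # GAE(lambda), computed backwards over the unroll:
    #   delta_t = r_t + gamma * V(s_{t+1}) * nonterminal - V(s_t)
    #   A_t     = delta_t + gamma * lam * nonterminal * A_{t+1}
    # The value targets are then mb_returns = A_t + V(s_t).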
212 | for t in reversed(range(self._unroll_length)):
213 | if t == self._unroll_length - 1:
214 | next_nonterminal = 1.0 - self._done
215 | next_values = last_values[0]
216 | else:
217 | next_nonterminal = 1.0 - mb_dones[t + 1]
218 | next_values = mb_values[t + 1]
219 | delta = mb_rewards[t] + self._gamma * next_values * next_nonterminal - \
220 | mb_values[t]
221 | mb_advs[t] = last_gae_lam = delta + self._gamma * self._lam * \
222 | next_nonterminal * last_gae_lam
223 | mb_returns = mb_advs + mb_values
224 | return (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs,
225 | mb_states, episode_infos)
226 |
227 | def _push_data(self, zmq_context, learner_ip, port_B, data_queue):
228 | sender = zmq_context.socket(zmq.PUSH)
229 | sender.setsockopt(zmq.SNDHWM, 1)
230 | sender.setsockopt(zmq.RCVHWM, 1)
231 | sender.connect("tcp://%s:%s" % (learner_ip, port_B))
232 | while True:
233 | data = data_queue.get()
234 | sender.send_pyobj(data)
235 |
236 | def _update_model(self):
237 | self._model_requestor.send_string("request model")
238 | self._model.load_params(self._model_requestor.recv_pyobj())
239 |
240 |
241 | class PPOLearner(object):
242 |
243 | def __init__(self, env, policy, unroll_length, lr, clip_range, batch_size,
244 | ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, queue_size=8,
245 | print_interval=100, save_interval=10000, learn_act_speed_ratio=0,
246 | unroll_split=8, save_dir=None, init_model_path=None,
247 | port_A="5700", port_B="5701"):
248 | assert isinstance(env.action_space, spaces.Discrete)
249 | if isinstance(lr, float): lr = constfn(lr)
250 | else: assert callable(lr)
251 | if isinstance(clip_range, float): clip_range = constfn(clip_range)
252 | else: assert callable(clip_range)
253 | self._lr = lr
254 |     self._clip_range = clip_range
255 | self._batch_size = batch_size
256 | self._unroll_length = unroll_length
257 | self._print_interval = print_interval
258 | self._save_interval = save_interval
259 | self._learn_act_speed_ratio = learn_act_speed_ratio
260 | self._save_dir = save_dir
261 |
262 | self._model = Model(policy=policy,
263 | scope_name="model",
264 | ob_space=env.observation_space,
265 | ac_space=env.action_space,
266 | nbatch_act=1,
267 | nbatch_train=batch_size * unroll_length,
268 | unroll_length=unroll_length,
269 | ent_coef=ent_coef,
270 | vf_coef=vf_coef,
271 | max_grad_norm=max_grad_norm)
272 | if init_model_path is not None: self._model.load(init_model_path)
273 | self._model_params = self._model.read_params()
274 | self._unroll_split = unroll_split if self._model.initial_state is None else 1
275 | assert self._unroll_length % self._unroll_split == 0
276 | self._data_queue = deque(maxlen=queue_size * self._unroll_split)
277 | self._data_timesteps = deque(maxlen=200)
278 | self._episode_infos = deque(maxlen=5000)
279 | self._num_unrolls = 0
280 |
281 | self._zmq_context = zmq.Context()
282 | self._pull_data_thread = Thread(
283 | target=self._pull_data,
284 | args=(self._zmq_context, self._data_queue, self._episode_infos,
285 | self._unroll_split, port_B)
286 | )
287 | self._pull_data_thread.start()
288 | self._reply_model_thread = Thread(
289 | target=self._reply_model, args=(self._zmq_context, port_A))
290 | self._reply_model_thread.start()
291 |
292 | def run(self):
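    # Warm-up: block until actors have reported enough finished episodes (half
    # of the episode-info buffer) before training and logging begin.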
293 | #while len(self._data_queue) < self._data_queue.maxlen: time.sleep(1)
294 | while len(self._episode_infos) < self._episode_infos.maxlen / 2:
295 | time.sleep(1)
296 |
297 | batch_queue = Queue(4)
298 | batch_threads = [
299 | Thread(target=self._prepare_batch,
300 | args=(self._data_queue, batch_queue,
301 | self._batch_size * self._unroll_split))
302 | for _ in range(8)
303 | ]
304 | for thread in batch_threads:
305 | thread.start()
306 |
307 | updates, loss = 0, []
308 | time_start = time.time()
309 | while True:
310 | while (self._learn_act_speed_ratio > 0 and
311 | updates * self._batch_size >= \
312 | self._num_unrolls * self._learn_act_speed_ratio):
313 | time.sleep(0.001)
314 | updates += 1
315 | lr_now = self._lr(updates)
316 | clip_range_now = self._clip_range(updates)
317 |
318 | batch = batch_queue.get()
319 | obs, returns, dones, actions, values, neglogpacs, states = batch
320 | loss.append(self._model.train(lr_now, clip_range_now, obs, returns, dones,
321 | actions, values, neglogpacs, states))
322 | self._model_params = self._model.read_params()
323 |
324 | if updates % self._print_interval == 0:
325 | loss_mean = np.mean(loss, axis=0)
326 | batch_steps = self._batch_size * self._unroll_length
327 | time_elapsed = time.time() - time_start
328 | train_fps = self._print_interval * batch_steps / time_elapsed
329 | rollout_fps = len(self._data_timesteps) * self._unroll_length / \
330 | (time.time() - self._data_timesteps[0])
331 | var = explained_variance(values, returns)
332 | avg_reward = safemean([info['r'] for info in self._episode_infos])
333 | tprint("Update: %d Train-fps: %.1f Rollout-fps: %.1f "
334 | "Explained-var: %.5f Avg-reward %.2f Policy-loss: %.5f "
335 | "Value-loss: %.5f Policy-entropy: %.5f Approx-KL: %.5f "
336 | "Clip-frac: %.3f Time: %.1f" % (updates, train_fps, rollout_fps,
337 | var, avg_reward, *loss_mean[:5], time_elapsed))
338 | time_start, loss = time.time(), []
339 |
340 | if self._save_dir is not None and updates % self._save_interval == 0:
341 | os.makedirs(self._save_dir, exist_ok=True)
342 | save_path = os.path.join(self._save_dir, 'checkpoint-%d' % updates)
343 | self._model.save(save_path)
344 | tprint('Saved to %s.' % save_path)
345 |
346 | def _prepare_batch(self, data_queue, batch_queue, batch_size):
347 | while True:
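      # Sample batch_size unroll segments uniformly from the shared deque and
      # concatenate them along the batch axis; several of these threads run in
      # parallel so a prepared batch is usually ready when the training loop
      # asks for one.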
348 | batch = random.sample(data_queue, batch_size)
349 | obs, returns, dones, actions, values, neglogpacs, states = zip(*batch)
350 | if isinstance(obs[0], tuple):
351 | obs = tuple(np.concatenate(ob) for ob in zip(*obs))
352 | else:
353 | obs = np.concatenate(obs)
354 | returns = np.concatenate(returns)
355 | dones = np.concatenate(dones)
356 | actions = np.concatenate(actions)
357 | values = np.concatenate(values)
358 | neglogpacs = np.concatenate(neglogpacs)
359 | states = np.concatenate(states) if states[0] is not None else None
360 | batch_queue.put((obs, returns, dones, actions, values, neglogpacs, states))
361 |
362 | def _pull_data(self, zmq_context, data_queue, episode_infos, unroll_split,
363 | port_B):
364 | receiver = zmq_context.socket(zmq.PULL)
365 | receiver.setsockopt(zmq.RCVHWM, 1)
366 | receiver.setsockopt(zmq.SNDHWM, 1)
367 | receiver.bind("tcp://*:%s" % port_B)
368 | while True:
369 | data = receiver.recv_pyobj()
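      # When unroll_split > 1 (only used for feed-forward policies), each
      # received unroll is cut into unroll_split shorter segments so a training
      # batch can mix segments from many actors; recurrent policies keep whole
      # unrolls so the stored initial state stays aligned with each unroll.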
370 | if unroll_split > 1:
371 | data_queue.extend(list(zip(*(
372 | [list(zip(*transform_tuple(
373 | data[0], lambda x: np.split(x, unroll_split))))] + \
374 | [np.split(arr, unroll_split) for arr in data[1:-2]] + \
375 | [[data[-2] for _ in range(unroll_split)]]
376 | ))))
377 | else:
378 | data_queue.append(data[:-1])
379 | episode_infos.extend(data[-1])
380 | self._data_timesteps.append(time.time())
381 | self._num_unrolls += 1
382 |
383 | def _reply_model(self, zmq_context, port_A):
384 | receiver = zmq_context.socket(zmq.REP)
385 | receiver.bind("tcp://*:%s" % port_A)
386 | while True:
387 | msg = receiver.recv_string()
388 | assert msg == "request model"
389 | receiver.send_pyobj(self._model_params)
390 |
391 |
392 | class PPOAgent(object):
393 |
394 | def __init__(self, env, policy, model_path=None):
395 | assert isinstance(env.action_space, spaces.Discrete)
396 | self._model = Model(policy=policy,
397 | scope_name="model",
398 | ob_space=env.observation_space,
399 | ac_space=env.action_space,
400 | nbatch_act=1,
401 | nbatch_train=1,
402 | unroll_length=1,
403 | ent_coef=0.01,
404 | vf_coef=0.5,
405 | max_grad_norm=0.5)
406 | if model_path is not None:
407 | self._model.load(model_path)
408 | self._state = self._model.initial_state
409 | self._done = False
410 |
411 | def act(self, observation):
412 | action, value, self._state, _ = self._model.step(
413 | transform_tuple(observation, lambda x: np.expand_dims(x, 0)),
414 | self._state,
415 | np.expand_dims(self._done, 0))
416 | return action[0]
417 |
418 | def reset(self):
419 | self._state = self._model.initial_state
420 |
421 |
422 | class PPOSelfplayActor(object):
423 |
424 | def __init__(self, env, policy, unroll_length, gamma, lam, model_cache_size,
425 | model_cache_prob, queue_size=1, prob_latest_opponent=0.0,
426 | init_opponent_pool_filelist=None, freeze_opponent_pool=False,
427 | enable_push=True, learner_ip="localhost", port_A="5700",
428 | port_B="5701"):
429 | assert isinstance(env.action_space, spaces.Discrete)
430 | self._env = env
431 | self._unroll_length = unroll_length
432 | self._lam = lam
433 | self._gamma = gamma
434 | self._prob_latest_opponent = prob_latest_opponent
435 | self._freeze_opponent_pool = freeze_opponent_pool
436 | self._enable_push = enable_push
437 | self._model_cache_prob = model_cache_prob
438 |
439 | self._model = Model(policy=policy,
440 | scope_name="model",
441 | ob_space=env.observation_space,
442 | ac_space=env.action_space,
443 | nbatch_act=1,
444 | nbatch_train=unroll_length,
445 | unroll_length=unroll_length,
446 | ent_coef=0.01,
447 | vf_coef=0.5,
448 | max_grad_norm=0.5)
449 | self._oppo_model = Model(policy=policy,
450 | scope_name="oppo_model",
451 | ob_space=env.observation_space,
452 | ac_space=env.action_space,
453 | nbatch_act=1,
454 | nbatch_train=unroll_length,
455 | unroll_length=unroll_length,
456 | ent_coef=0.01,
457 | vf_coef=0.5,
458 | max_grad_norm=0.5)
459 | self._obs, self._oppo_obs = env.reset()
460 | self._state = self._model.initial_state
461 | self._oppo_state = self._oppo_model.initial_state
462 | self._done = False
463 | self._cum_reward = 0
464 |
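    # Opponent pool: a bounded FIFO of past parameter snapshots, optionally
    # seeded from checkpoint paths listed in init_opponent_pool_filelist.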
465 | self._model_cache = deque(maxlen=model_cache_size)
466 | if init_opponent_pool_filelist is not None:
467 | with open(init_opponent_pool_filelist, 'r') as f:
468 | for model_path in f.readlines():
469 | print(model_path)
470 | self._model_cache.append(joblib.load(model_path.strip()))
471 | self._latest_model = self._oppo_model.read_params()
472 | if len(self._model_cache) == 0:
473 | self._model_cache.append(self._latest_model)
474 | self._update_opponent()
475 |
476 | self._zmq_context = zmq.Context()
477 | self._model_requestor = self._zmq_context.socket(zmq.REQ)
478 | self._model_requestor.connect("tcp://%s:%s" % (learner_ip, port_A))
479 | if enable_push:
480 | self._data_queue = Queue(queue_size)
481 | self._push_thread = Thread(target=self._push_data, args=(
482 | self._zmq_context, learner_ip, port_B, self._data_queue))
483 | self._push_thread.start()
484 |
485 | def run(self):
486 | while True:
487 | t = time.time()
488 | self._update_model()
489 | tprint("Time update model: %f" % (time.time() - t))
490 | t = time.time()
491 | unroll = self._nstep_rollout()
492 | if self._enable_push:
493 | if self._data_queue.full(): tprint("[WARN]: Actor's queue is full.")
494 | self._data_queue.put(unroll)
495 | tprint("Time rollout: %f" % (time.time() - t))
496 |
497 | def _nstep_rollout(self):
498 | mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = \
499 | [],[],[],[],[],[]
500 | mb_states, episode_infos = self._state, []
501 | for _ in range(self._unroll_length):
502 | action, value, self._state, neglogpac = self._model.step(
503 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)),
504 | self._state,
505 | np.expand_dims(self._done, 0))
506 | oppo_action, _, self._oppo_state, _ = self._oppo_model.step(
507 | transform_tuple(self._oppo_obs, lambda x: np.expand_dims(x, 0)),
508 | self._oppo_state,
509 | np.expand_dims(self._done, 0))
510 | mb_obs.append(transform_tuple(self._obs, lambda x: x.copy()))
511 | mb_actions.append(action[0])
512 | mb_values.append(value[0])
513 | mb_neglogpacs.append(neglogpac[0])
514 | mb_dones.append(self._done)
515 | (self._obs, self._oppo_obs), reward, self._done, info = self._env.step(
516 | [action[0], oppo_action[0]])
517 | self._cum_reward += reward
518 | if self._done:
519 | self._obs, self._oppo_obs = self._env.reset()
520 | self._state = self._model.initial_state
521 | self._oppo_state = self._oppo_model.initial_state
522 | self._update_opponent()
523 | episode_infos.append({'r': self._cum_reward})
524 | self._cum_reward = 0
525 | mb_rewards.append(reward)
526 | if isinstance(self._obs, tuple):
527 | mb_obs = tuple(np.asarray(obs, dtype=self._obs[0].dtype)
528 | for obs in zip(*mb_obs))
529 | else:
530 | mb_obs = np.asarray(mb_obs, dtype=self._obs.dtype)
531 | mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
532 | mb_actions = np.asarray(mb_actions)
533 | mb_values = np.asarray(mb_values, dtype=np.float32)
534 | mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
535 | mb_dones = np.asarray(mb_dones, dtype=np.bool)
536 | last_values = self._model.value(
537 | transform_tuple(self._obs, lambda x: np.expand_dims(x, 0)),
538 | self._state,
539 | np.expand_dims(self._done, 0))
540 | mb_returns = np.zeros_like(mb_rewards)
541 | mb_advs = np.zeros_like(mb_rewards)
542 | last_gae_lam = 0
543 | for t in reversed(range(self._unroll_length)):
544 | if t == self._unroll_length - 1:
545 | next_nonterminal = 1.0 - self._done
546 | next_values = last_values[0]
547 | else:
548 | next_nonterminal = 1.0 - mb_dones[t + 1]
549 | next_values = mb_values[t + 1]
550 | delta = mb_rewards[t] + self._gamma * next_values * next_nonterminal - \
551 | mb_values[t]
552 | mb_advs[t] = last_gae_lam = delta + self._gamma * self._lam * \
553 | next_nonterminal * last_gae_lam
554 | mb_returns = mb_advs + mb_values
555 | return (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs,
556 | mb_states, episode_infos)
557 |
558 | def _push_data(self, zmq_context, learner_ip, port_B, data_queue):
559 | sender = zmq_context.socket(zmq.PUSH)
560 | sender.setsockopt(zmq.SNDHWM, 1)
561 | sender.setsockopt(zmq.RCVHWM, 1)
562 | sender.connect("tcp://%s:%s" % (learner_ip, port_B))
563 | while True:
564 | data = data_queue.get()
565 | sender.send_pyobj(data)
566 |
567 | def _update_model(self):
568 | self._model_requestor.send_string("request model")
569 | model_params = self._model_requestor.recv_pyobj()
570 | self._model.load_params(model_params)
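    # With probability model_cache_prob, snapshot the newly fetched parameters
    # into the opponent pool (unless the pool is frozen).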
571 | if (not self._freeze_opponent_pool and
572 | random.uniform(0, 1.0) < self._model_cache_prob):
573 | self._model_cache.append(model_params)
574 | self._latest_model = model_params
575 |
576 | def _update_opponent(self):
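    # Pick the next opponent: with probability prob_latest_opponent mirror the
    # most recently fetched learner parameters (always, if the pool is empty),
    # otherwise sample uniformly from the cached pool.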
577 | if (random.uniform(0, 1.0) < self._prob_latest_opponent or
578 | len(self._model_cache) == 0):
579 | self._oppo_model.load_params(self._latest_model)
580 | tprint("Opponent updated with the current model.")
581 | else:
582 | model_params = random.choice(self._model_cache)
583 | self._oppo_model.load_params(model_params)
584 | tprint("Opponent updated with the previous model. %d models cached." %
585 | len(self._model_cache))
586 |
587 |
588 | def constfn(val):
589 | def f(_):
590 | return val
591 | return f
592 |
593 |
594 | def safemean(xs):
595 | return np.nan if len(xs) == 0 else np.mean(xs)
596 |
597 |
598 | def transform_tuple(x, transformer):
599 | if isinstance(x, tuple):
600 | return tuple(transformer(a) for a in x)
601 | else:
602 | return transformer(x)
603 |
--------------------------------------------------------------------------------
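A minimal sketch of how the classes in ppo_agent.py are typically wired together; the real entry points are sc2learner/bin/train_ppo.py and sc2learner/bin/train_ppo_selfplay.py. The environment and policy constructors, as well as the hyperparameter values below, are illustrative placeholders rather than repository APIs. Note that Model reads tf.get_default_session(), so a default TensorFlow (1.x) session must be active before any of these objects are built.

    import tensorflow as tf

    from sc2learner.agents.ppo_agent import PPOActor, PPOLearner

    UNROLL = 128  # illustrative unroll length; learner and actors must agree

    def run_learner(env, policy):
      # The learner owns the parameters: it replies to parameter requests on
      # port_A (default "5700") and consumes rollouts pushed to port_B ("5701").
      with tf.Session().as_default():
        learner = PPOLearner(env=env, policy=policy, unroll_length=UNROLL,
                             lr=1e-4, clip_range=0.1, batch_size=32,
                             save_dir="./checkpoints")
        learner.run()  # blocks forever

    def run_actor(env, policy, learner_ip):
      # Each actor repeatedly fetches the latest parameters, rolls out UNROLL
      # steps and pushes the unroll back to the learner; many of these run in
      # parallel, possibly on other machines.
      with tf.Session().as_default():
        actor = PPOActor(env=env, policy=policy, unroll_length=UNROLL,
                         gamma=0.999, lam=0.95, learner_ip=learner_ip)
        actor.run()  # blocks forever

    # env and policy stand in for whatever environment wrappers and policy
    # class (see sc2learner/agents/ppo_policies.py) the bin/ scripts construct.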