├── tests ├── __init__.py ├── test_dataset.py ├── test_behaviour_cloning_learner.py ├── test_common_layers.py ├── test_spatial_decoder.py ├── test_progress_logger.py ├── test_transformer.py ├── test_spatial_encoder.py ├── test_unit_group_encoder.py ├── test_policy.py └── test_sc2_environment.py ├── scripts ├── __init__.py ├── play_agent_vs_bot.py ├── evaluate.py ├── behaviour_cloning.py ├── play_agent_vs_human.py ├── download_replays.py └── build_dataset.py ├── sc2_imitation_learning ├── __init__.py ├── common │ ├── __init__.py │ ├── types.py │ ├── flags.py │ ├── progress_logger.py │ ├── mlp.py │ ├── layers.py │ ├── evaluator.py │ ├── replay_processor.py │ └── utils.py ├── dataset │ ├── __init__.py │ ├── sc2_dataset.py │ └── dataset.py ├── agents │ ├── common │ │ ├── __init__.py │ │ ├── scalar_encoder.py │ │ ├── unit_group_encoder.py │ │ ├── spatial_decoder.py │ │ └── spatial_encoder.py │ ├── __init__.py │ └── sc2_feature_layer_agent.py ├── environment │ ├── __init__.py │ └── environment.py └── behaviour_cloning │ └── __init__.py ├── docs └── sc2_feature_layer_agent_architecture.png ├── configs ├── mini_games │ ├── play.gin │ ├── evaluate.gin │ ├── build_dataset.gin │ ├── behaviour_cloning.gin │ ├── environment.gin │ └── agents │ │ └── sc2_feature_layer_agent.gin └── 1v1 │ ├── build_dataset.gin │ ├── play_agent_vs_bot.gin │ ├── play_agent_vs_human.gin │ ├── behaviour_cloning.gin │ ├── behaviour_cloning_small.gin │ ├── behaviour_cloning_single_gpu.gin │ ├── environment.gin │ └── evaluate.gin ├── requirements.txt ├── LICENSE ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2_imitation_learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2_imitation_learning/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2_imitation_learning/environment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sc2_imitation_learning/behaviour_cloning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/sc2_feature_layer_agent_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chscheller/sc2_imitation_learning/HEAD/docs/sc2_feature_layer_agent_architecture.png 
-------------------------------------------------------------------------------- /configs/mini_games/play.gin: -------------------------------------------------------------------------------- 1 | include 'configs/mini_games/environment.gin' 2 | 3 | 4 | # Play config 5 | # ---------------------------------------------------------------------------- 6 | 7 | play.env_fn=@CollectMineralsAndGas/SC2SingleAgentEnv 8 | play.num_episodes = 1 9 | 10 | CollectMineralsAndGas/SC2SingleAgentEnv.map_name = 'CollectMineralsAndGas' 11 | -------------------------------------------------------------------------------- /configs/mini_games/evaluate.gin: -------------------------------------------------------------------------------- 1 | include 'configs/mini_games/environment.gin' 2 | 3 | evaluate.envs = [ 4 | @CollectMineralsAndGas/SC2SingleAgentEnv 5 | ] 6 | evaluate.num_episodes = 100 7 | evaluate.random_seed = 42 8 | evaluate.num_evaluators = 2 9 | 10 | SC2SingleAgentEnv.save_replay_episodes = 1 11 | 12 | CollectMineralsAndGas/SC2SingleAgentEnv.map_name = 'CollectMineralsAndGas' -------------------------------------------------------------------------------- /configs/mini_games/build_dataset.gin: -------------------------------------------------------------------------------- 1 | include 'configs/mini_games/environment.gin' 2 | 3 | ProcessReplay.interface_config = %INTERFACE_CONFIG 4 | ProcessReplay.action_space = %ACTION_SPACE 5 | ProcessReplay.observation_space = %OBSERVATION_SPACE 6 | ProcessReplay.sc2_version = '4.7.1' 7 | 8 | StoreReplay.action_space = %ACTION_SPACE 9 | StoreReplay.observation_space = %OBSERVATION_SPACE 10 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/types.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Iterable, Mapping, Text 2 | 3 | from sonnet.src.types import ShapeLike 4 | from tensorflow import DType, TensorSpec 5 | 6 | ShapeNest = Union[ShapeLike, Iterable['ShapeNest'], Mapping[Text, 'ShapeNest'], ] 7 | DTypeNest = Union[DType, Iterable['DTypeNest'], Mapping[Text, 'DTypeNest'], ] 8 | TensorSpecNest = Union[TensorSpec, Iterable['TensorSpecNest'], Mapping[Text, 'TensorSpecNest'], ] 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | dm-sonnet==2.0.0 3 | dm-tree==0.1.5 4 | git+git://github.com/google/gin-config.git@master#egg=gin-config 5 | gym==0.17.3 6 | h5py==2.10.0 7 | mpyq==0.2.5 8 | numpy==1.18.5 9 | pypeln==0.4.7 10 | git+git://github.com/metataro/pysc2.git@master#egg=PySC2 11 | PyYAML==5.4 12 | s2clientprotocol==5.0.2.81102.0 13 | s2protocol==5.0.3.81433.0 14 | sc2reader==1.6.0 15 | scipy==1.5.3 16 | tensorboard==2.3.0 17 | tensorflow==2.3.1 18 | tensorflow-addons==0.11.2 19 | tensorflow-probability==0.11.1 20 | tqdm==4.51.0 21 | wandb==0.10.12 22 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/flags.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | from absl.flags import ListParser, CsvListSerializer 3 | 4 | 5 | class IntListParser(ListParser): 6 | def parse(self, argument): 7 | parsed_list = super().parse(argument) 8 | return [int(x) for x in parsed_list] 9 | 10 | 11 | # noinspection PyPep8Naming 12 | def DEFINE_int_list(name, default, help, 
flag_values=flags.FLAGS, **args): 13 | parser = IntListParser() 14 | serializer = CsvListSerializer(',') 15 | flags.DEFINE(parser, name, default, help, flag_values, serializer, **args) 16 | -------------------------------------------------------------------------------- /configs/1v1/build_dataset.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | 3 | FilterReplay.min_duration = 60. 4 | FilterReplay.min_mmr = 3500 5 | FilterReplay.min_apm = 60 6 | FilterReplay.observed_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 7 | FilterReplay.opponent_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 8 | FilterReplay.wins_only = True 9 | 10 | ProcessReplay.interface_config = %INTERFACE_CONFIG 11 | ProcessReplay.action_space = %ACTION_SPACE 12 | ProcessReplay.observation_space = %OBSERVATION_SPACE 13 | ProcessReplay.sc2_version = '4.7.1' 14 | 15 | StoreReplay.action_space = %ACTION_SPACE 16 | StoreReplay.observation_space = %OBSERVATION_SPACE 17 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from sc2_imitation_learning.dataset.dataset import EpisodeIterator, EpisodeSlice 4 | 5 | 6 | class EpisodeIteratorTest(TestCase): 7 | 8 | def test_iterate(self): 9 | it = EpisodeIterator(episode_id=1, episode_path='test', episode_length=3, sequence_length=2) 10 | slices = [next(it) for _ in range(4)] 11 | self.assertTrue(slices == [ 12 | EpisodeSlice(episode_id=1, episode_path='test', start=0, length=2, wrap_at_end=True), 13 | EpisodeSlice(episode_id=1, episode_path='test', start=2, length=2, wrap_at_end=True), 14 | EpisodeSlice(episode_id=1, episode_path='test', start=1, length=2, wrap_at_end=True), 15 | EpisodeSlice(episode_id=1, episode_path='test', start=0, length=2, wrap_at_end=True), 16 | ]) 17 | -------------------------------------------------------------------------------- /configs/1v1/play_agent_vs_bot.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | 3 | 4 | # Play config 5 | # ---------------------------------------------------------------------------- 6 | 7 | play.env_fn=@KairosJunction_tvt_very_easy/SC2SingleAgentEnv 8 | play.num_episodes = 1 9 | 10 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 11 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.agent_race = 'terran' 12 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_race = 'terran' 13 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 14 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_build = 'random' 15 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.visualize = False 16 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.save_replay_episodes = 0 17 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.replay_dir = None 18 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.replay_prefix = None 19 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.game_steps_per_episode = 0 20 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.random_seed = None 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Christian Scheller 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 
a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/1v1/play_agent_vs_human.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | 3 | 4 | # Host config 5 | # ---------------------------------------------------------------------------- 6 | 7 | 8 | # Human config 9 | # ---------------------------------------------------------------------------- 10 | human.map_name = 'KairosJunction' 11 | human.render = None 12 | human.remote = None 13 | human.host = '127.0.0.1' 14 | human.config_port = 14380 15 | human.realtime = False 16 | human.fps = 22.4 17 | human.rgb_screen_size = (64,64) 18 | human.rgb_minimap_size = (64,64) 19 | human.feature_screen_size = None 20 | human.feature_minimap_size = None 21 | human.race = 'terran' 22 | human.player_name = 'Human' 23 | 24 | # Agent config 25 | # ---------------------------------------------------------------------------- 26 | 27 | agent.env_fn=@SC2LanEnv 28 | 29 | SC2LanEnv.host = '127.0.0.1' 30 | SC2LanEnv.config_port = 14380 31 | SC2LanEnv.interface_config = %INTERFACE_CONFIG 32 | SC2LanEnv.observation_space = %OBSERVATION_SPACE 33 | SC2LanEnv.action_space = %ACTION_SPACE 34 | SC2LanEnv.agent_race = 'terran' 35 | SC2LanEnv.agent_name = 'Hambbe' 36 | SC2LanEnv.visualize = False 37 | SC2LanEnv.realtime = False 38 | SC2LanEnv.replay_dir = None 39 | SC2LanEnv.replay_prefix = None 40 | -------------------------------------------------------------------------------- /tests/test_behaviour_cloning_learner.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from sc2_imitation_learning.behaviour_cloning.learner import compute_correct_predictions, compute_neg_log_probs 5 | 6 | 7 | class Test(tf.test.TestCase): 8 | def test_compute_correct_predictions(self): 9 | targets = np.asarray([0, -1, 1, -1]) 10 | predictions = np.asarray([0, -1, 0, 0]) 11 | correct_predictions, total_predictions = compute_correct_predictions(targets, predictions) 12 | self.assertEqual(correct_predictions, 1) 13 | self.assertEqual(total_predictions, 2) 14 | 15 | def test_compute_neg_log_probs(self): 16 | # test without masked labels 17 | labels = np.asarray([0, 1]) 18 | logits = np.asarray([[0.5, 1.5], [-1.0, 2.0]]) 19 | log_probs = tf.math.log_softmax(logits, axis=-1) 20 | label_mask_value = -1 21 | neg_log_probs = compute_neg_log_probs(labels, logits, label_mask_value) 22 | 23 
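# compute_neg_log_probs is expected to return, per position, the negative
# log-probability of the target label under log_softmax(logits); the assertion
# below checks exactly [-log_probs[0, labels[0]], -log_probs[1, labels[1]]]
# for this unmasked case.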
| self.assertAllClose(neg_log_probs, [-log_probs[0, labels[0]], -log_probs[1, labels[1]]]) 24 | 25 | # test with masked labels 26 | labels = np.asarray([0, -1]) 27 | logits = np.asarray([[0.5, 1.5], [-1.0, 2.0]]) 28 | log_probs = tf.math.log_softmax(logits, axis=-1) 29 | label_mask_value = -1 30 | neg_log_probs = compute_neg_log_probs(labels, logits, label_mask_value) 31 | 32 | self.assertAllClose(neg_log_probs, [-log_probs[0, labels[0]], 0.]) 33 | -------------------------------------------------------------------------------- /configs/mini_games/behaviour_cloning.gin: -------------------------------------------------------------------------------- 1 | include 'configs/mini_games/environment.gin' 2 | include 'configs/mini_games/agents/sc2_feature_layer_agent.gin' 3 | 4 | import sc2_imitation_learning.dataset.sc2_dataset 5 | 6 | 7 | # Train config 8 | # ---------------------------------------------------------------------------- 9 | 10 | train.action_space = %ACTION_SPACE 11 | train.observation_space = %OBSERVATION_SPACE 12 | train.data_loader = @SC2DataLoader() 13 | train.batch_size = 8 14 | train.sequence_length = 64 15 | train.total_train_samples = 2e6 16 | train.l2_regularization = 1e-5 17 | train.update_frequency = 1 18 | train.agent_fn = @sc2_feature_layer_agent.SC2FeatureLayerAgent 19 | train.optimizer_fn = @tf.keras.optimizers.Adam 20 | train.eval_interval = 5e5 21 | train.max_to_keep_checkpoints = 10 22 | train.save_checkpoint_interval = 600 # 10 minutes 23 | train.tensorboard_log_interval = 10 24 | train.console_log_interval = 60 25 | 26 | tf.keras.optimizers.Adam.learning_rate = 1e-4 27 | 28 | SC2DataLoader.path = './data/datasets/mini_games/CollectMineralsAndGas_v2' 29 | SC2DataLoader.action_space = %ACTION_SPACE 30 | SC2DataLoader.observation_space = %OBSERVATION_SPACE 31 | 32 | 33 | 34 | # Eval config 35 | # ---------------------------------------------------------------------------- 36 | 37 | evaluate.envs=[@SC2SingleAgentEnv] 38 | evaluate.num_episodes = 20 39 | evaluate.random_seed = 21 40 | evaluate.num_evaluators = 2 41 | 42 | SC2SingleAgentEnv.map_name = 'CollectMineralsAndGas' 43 | -------------------------------------------------------------------------------- /tests/test_common_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from sc2_imitation_learning.common.layers import SparseEmbed, MaskedGlobalAveragePooling1D 4 | 5 | 6 | class TestMaskedGlobalAveragePooling2D(tf.test.TestCase): 7 | def test_masked_global_average_pooling_2d(self): 8 | pooling = MaskedGlobalAveragePooling1D(mask_value=0) 9 | 10 | inputs = tf.constant([[ 11 | [0, 0, 0], 12 | [1, 1, 1], 13 | [2, 2, 2], 14 | [3, 3, 3], 15 | ], [ 16 | [0, 0, 0], 17 | [1, 1, 1], 18 | [0, 0, 0], 19 | [0, 0, 0], 20 | ], [ 21 | [0, 0, 0], 22 | [1, 1, 1], 23 | [-1, -1, 0], 24 | [0, 0, 0], 25 | ]], dtype=tf.float32) 26 | 27 | output = pooling(inputs) 28 | 29 | self.assertAllClose(output, [[2, 2, 2], [1, 1, 1], [0, 0, 0.5]]) 30 | 31 | 32 | class TestSparseEmbed(tf.test.TestCase): 33 | def test_sparse_embed(self): 34 | sparse_embed = SparseEmbed([0, 2, 5], embed_dim=3) 35 | 36 | output_1 = sparse_embed([0, 0, 2]) # valid ids 37 | output_2 = sparse_embed([5, 5, 2]) # valid ids 38 | output_3 = sparse_embed([1, 3, 4]) # invalid ids (within bounds) 39 | output_4 = sparse_embed([6, 7, 8]) # invalid ids (out of bounds) 40 | 41 | self.assertEqual(output_1.shape, (3, 3)) 42 | self.assertNotAllEqual(output_1, output_2) 43 | self.assertNotAllEqual(output_1, 
output_3) 44 | self.assertNotAllEqual(output_2, output_3) 45 | self.assertAllEqual(output_3, output_4) 46 | 47 | with self.assertRaises(ValueError): 48 | sparse_embed2 = SparseEmbed([-1, 0, 1], embed_dim=3) 49 | -------------------------------------------------------------------------------- /sc2_imitation_learning/environment/environment.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from abc import ABC, abstractmethod 3 | from typing import Tuple, Dict, Any 4 | 5 | import tree 6 | 7 | StepOutput = collections.namedtuple('StepOutput', ['reward', 'info', 'done', 'observation']) 8 | 9 | 10 | class Space(ABC): 11 | @property 12 | @abstractmethod 13 | def specs(self) -> Dict: 14 | pass 15 | 16 | def dtypes(self) -> Dict: 17 | return tree.map_structure(lambda s: s.dtype, self.specs) 18 | 19 | def shapes(self, as_tensor_shapes: bool = False) -> Dict: 20 | if as_tensor_shapes: 21 | import tensorflow as tf # load tensorflow lazily 22 | return tree.map_structure(lambda s: tf.TensorShape(list(s.shape)), self.specs) 23 | else: 24 | return tree.map_structure(lambda s: s.shape, self.specs) 25 | 26 | 27 | class ActionSpace(Space): 28 | @abstractmethod 29 | def no_op(self) -> Dict: 30 | pass 31 | 32 | @abstractmethod 33 | def transform(self, action: Dict) -> Tuple[Any, int]: 34 | pass 35 | 36 | @abstractmethod 37 | def transform_back(self, action: Any, step_mul: int) -> Dict: 38 | pass 39 | 40 | 41 | class ObservationSpace(Space): 42 | @abstractmethod 43 | def transform(self, observation: Dict) -> Dict: 44 | pass 45 | 46 | @abstractmethod 47 | def transform_back(self, observation: Dict) -> Dict: 48 | pass 49 | 50 | 51 | class EnvMeta(ABC): 52 | @abstractmethod 53 | def launch(self) -> None: 54 | pass 55 | 56 | @property 57 | @abstractmethod 58 | def level_name(self) -> str: 59 | pass 60 | 61 | @property 62 | @abstractmethod 63 | def action_space(self) -> ActionSpace: 64 | pass 65 | 66 | @property 67 | @abstractmethod 68 | def observation_space(self) -> ObservationSpace: 69 | pass 70 | 71 | -------------------------------------------------------------------------------- /tests/test_spatial_decoder.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from sc2_imitation_learning.agents.common.spatial_decoder import ResSpatialDecoder, FiLMedSpatialDecoder 6 | 7 | 8 | class Test(tf.test.TestCase): 9 | def test_res_spatial_decoder(self): 10 | dec = ResSpatialDecoder(out_channels=64, num_blocks=4) 11 | autoregressive_embedding = tf.constant(np.random.randn(1, 64), dtype=tf.float32) 12 | map_skip = [tf.constant(np.random.randn(1, 8, 8, 64), dtype=tf.float32)] 13 | conv_out = dec(autoregressive_embedding=autoregressive_embedding, map_skip=map_skip) 14 | 15 | self.assertEqual(conv_out.dtype, tf.float32) 16 | self.assertEqual(conv_out.shape.as_list(), [1, 8, 8, 64]) 17 | self.assertEqual(tf.reduce_any(tf.math.is_inf(conv_out)), False) 18 | self.assertEqual(tf.reduce_any(tf.math.is_nan(conv_out)), False) 19 | 20 | def test_filmed_spatial_decoder(self): 21 | dec = FiLMedSpatialDecoder(out_channels=64, num_blocks=4) 22 | autoregressive_embedding = tf.constant(np.random.randn(1, 512), dtype=tf.float32) 23 | map_skip = [ 24 | tf.constant(np.random.randn(1, 8, 8, 64), dtype=tf.float32), 25 | tf.constant(np.random.randn(1, 8, 8, 64), dtype=tf.float32), 26 | tf.constant(np.random.randn(1, 8, 8, 64), dtype=tf.float32), 27 | 
tf.constant(np.random.randn(1, 8, 8, 64), dtype=tf.float32), 28 | tf.constant(np.random.randn(1, 8, 8, 64), dtype=tf.float32) 29 | ] 30 | conv_out = dec(autoregressive_embedding=autoregressive_embedding, map_skip=map_skip) 31 | 32 | self.assertEqual(conv_out.dtype, tf.float32) 33 | self.assertEqual(conv_out.shape.as_list(), [1, 8, 8, 64]) 34 | self.assertEqual(tf.reduce_any(tf.math.is_inf(conv_out)), False) 35 | self.assertEqual(tf.reduce_any(tf.math.is_nan(conv_out)), False) 36 | -------------------------------------------------------------------------------- /tests/test_progress_logger.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import tempfile 5 | import time 6 | from unittest import TestCase 7 | import tensorflow as tf 8 | 9 | from sc2_imitation_learning.common.progress_logger import ConsoleProgressLogger, TensorboardProgressLogger 10 | 11 | 12 | class Test(TestCase): 13 | def test_tensorboard_progress_logger(self): 14 | final_step = 100 15 | 16 | with tempfile.TemporaryDirectory() as log_dir: 17 | summary_writer = tf.summary.create_file_writer(log_dir) 18 | 19 | initial_size = sum(os.path.getsize(f) for f in glob.glob(os.path.join(log_dir, '*')) if os.path.isfile(f)) 20 | 21 | progress_logger = TensorboardProgressLogger( 22 | summary_writer=summary_writer, 23 | logging_interval=1.) 24 | progress_logger.start() 25 | 26 | for i in range(final_step): 27 | progress_logger.log_dict({ 28 | 'loss/loss': 10 * (final_step-i) / final_step, 29 | 'samples_per_second': 5.0 + 10.0 * random.random(), 30 | 'learning_rate': 1e-4 31 | }, tf.constant(i, dtype=tf.int32)) 32 | time.sleep(0.1 * random.random()) 33 | 34 | progress_logger.shutdown() 35 | 36 | final_size = sum(os.path.getsize(f) for f in glob.glob(os.path.join(log_dir, '*')) if os.path.isfile(f)) 37 | 38 | self.assertGreater(final_size, initial_size) 39 | 40 | def test_console_progress_logger(self): 41 | final_step = 200 42 | 43 | progress_logger = ConsoleProgressLogger( 44 | final_step=final_step, 45 | batch_samples=10, 46 | logging_interval=1.) 
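# As in the tensorboard test above: start() launches the logger, log_dict()
# feeds it one metrics dict per step, and shutdown() stops it after the loop.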
47 | progress_logger.start() 48 | 49 | for i in range(final_step): 50 | progress_logger.log_dict({ 51 | 'loss/loss': 10 * (final_step-i) / final_step, 52 | 'samples_per_second': 5.0 + 10.0 * random.random(), 53 | 'learning_rate': 1e-4 54 | }, i) 55 | time.sleep(0.1 * random.random()) 56 | 57 | progress_logger.shutdown() 58 | -------------------------------------------------------------------------------- /configs/mini_games/environment.gin: -------------------------------------------------------------------------------- 1 | import sc2_imitation_learning.environment.sc2_environment 2 | 3 | INTERFACE_CONFIG = @interface_config/singleton() 4 | interface_config/singleton.constructor = @SC2InterfaceConfig 5 | SC2InterfaceConfig.dimension_screen = (32, 32) 6 | SC2InterfaceConfig.dimension_minimap = (32, 32) 7 | SC2InterfaceConfig.screen_features = ('visibility_map', 'player_relative', 'unit_type', 'selected', 8 | 'unit_hit_points_ratio', 'unit_energy_ratio', 'unit_density_aa') 9 | SC2InterfaceConfig.minimap_features = ('camera', 'player_relative', 'alerts') 10 | SC2InterfaceConfig.scalar_features = ('player', 'available_actions') 11 | SC2InterfaceConfig.available_actions = None 12 | SC2InterfaceConfig.upgrade_set = None 13 | SC2InterfaceConfig.max_step_mul = 16 14 | SC2InterfaceConfig.max_multi_select = 64 15 | SC2InterfaceConfig.max_cargo = 8 16 | SC2InterfaceConfig.max_build_queue = 8 17 | SC2InterfaceConfig.max_production_queue = 16 18 | 19 | OBSERVATION_SPACE = @observation_space/singleton() 20 | observation_space/singleton.constructor = @SC2ObservationSpace 21 | SC2ObservationSpace.config = %INTERFACE_CONFIG 22 | 23 | ACTION_SPACE = @action_space/singleton() 24 | action_space/singleton.constructor = @SC2ActionSpace 25 | SC2ActionSpace.config = %INTERFACE_CONFIG 26 | 27 | SC2SingleAgentEnv.interface_config = %INTERFACE_CONFIG 28 | SC2SingleAgentEnv.observation_space = %OBSERVATION_SPACE 29 | SC2SingleAgentEnv.action_space = %ACTION_SPACE 30 | SC2SingleAgentEnv.map_name = 'CollectMineralsAndGas' 31 | SC2SingleAgentEnv.battle_net_map = False 32 | SC2SingleAgentEnv.agent_race = 'terran' 33 | SC2SingleAgentEnv.agent_name = 'Hambbe' 34 | SC2SingleAgentEnv.bot_race = 'zerg' 35 | SC2SingleAgentEnv.bot_difficulty = 'easy' 36 | SC2SingleAgentEnv.bot_build = 'random' 37 | SC2SingleAgentEnv.visualize = False 38 | SC2SingleAgentEnv.realtime = False 39 | SC2SingleAgentEnv.save_replay_episodes = 0 40 | SC2SingleAgentEnv.replay_dir = None 41 | SC2SingleAgentEnv.replay_prefix = None 42 | SC2SingleAgentEnv.game_steps_per_episode = None 43 | SC2SingleAgentEnv.score_index = None 44 | SC2SingleAgentEnv.score_multiplier = None 45 | SC2SingleAgentEnv.disable_fog = False 46 | SC2SingleAgentEnv.ensure_available_actions = True 47 | SC2SingleAgentEnv.version = '4.7.1' 48 | SC2SingleAgentEnv.random_seed = None 49 | -------------------------------------------------------------------------------- /configs/1v1/behaviour_cloning.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | include 'configs/1v1/agents/sc2_feature_layer_agent.gin' 3 | 4 | import sc2_imitation_learning.dataset.sc2_dataset 5 | 6 | 7 | # Train config 8 | # ---------------------------------------------------------------------------- 9 | 10 | train.action_space = %ACTION_SPACE 11 | train.observation_space = %OBSERVATION_SPACE 12 | train.data_loader = @SC2DataLoader() 13 | train.batch_size = 24 14 | train.sequence_length = 64 15 | train.total_train_samples = 1e9 16 | 
train.l2_regularization = 1e-5 17 | train.update_frequency = 1 18 | train.agent_fn = @sc2_feature_layer_agent.SC2FeatureLayerAgent 19 | train.optimizer_fn = @tf.keras.optimizers.Adam 20 | train.eval_interval = 1e8 21 | train.max_to_keep_checkpoints = 10 22 | train.save_checkpoint_interval = 14400 # 4 hours 23 | train.tensorboard_log_interval = 10 24 | train.console_log_interval = 60 25 | 26 | tf.keras.optimizers.Adam.learning_rate = 1e-4 27 | 28 | SC2DataLoader.path = './data/datasets/v1' 29 | SC2DataLoader.action_space = %ACTION_SPACE 30 | SC2DataLoader.observation_space = %OBSERVATION_SPACE 31 | SC2DataLoader.min_duration = 0. 32 | SC2DataLoader.min_mmr = 0 33 | SC2DataLoader.min_apm = 0 34 | SC2DataLoader.observed_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 35 | SC2DataLoader.opponent_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 36 | SC2DataLoader.map_names = None # ['Kairos Junction LE'] 37 | 38 | 39 | 40 | # Eval config 41 | # ---------------------------------------------------------------------------- 42 | 43 | evaluate.envs=[ 44 | @KairosJunction_tvt_very_easy/SC2SingleAgentEnv, 45 | @KairosJunction_tvt_medium/SC2SingleAgentEnv, 46 | ] 47 | evaluate.num_episodes = 20 48 | evaluate.random_seed = 21 49 | evaluate.num_evaluators = 4 50 | 51 | SC2SingleAgentEnv.bot_race = 'terran' 52 | SC2SingleAgentEnv.save_replay_episodes = 1 53 | 54 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 55 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 56 | 57 | KairosJunction_tvt_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 58 | KairosJunction_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 59 | 60 | KairosJunction_tvt_medium/SC2SingleAgentEnv.map_name = 'KairosJunction' 61 | KairosJunction_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 62 | 63 | KairosJunction_tvt_hard/SC2SingleAgentEnv.map_name = 'KairosJunction' 64 | KairosJunction_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 65 | -------------------------------------------------------------------------------- /configs/1v1/behaviour_cloning_small.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | include 'configs/1v1/agents/sc2_feature_layer_agent.gin' 3 | 4 | import sc2_imitation_learning.dataset.sc2_dataset 5 | 6 | 7 | # Train config 8 | # ---------------------------------------------------------------------------- 9 | 10 | train.action_space = %ACTION_SPACE 11 | train.observation_space = %OBSERVATION_SPACE 12 | train.data_loader = @SC2DataLoader() 13 | train.batch_size = 28 14 | train.sequence_length = 64 15 | train.total_train_samples = 1e9 16 | train.l2_regularization = 1e-5 17 | train.update_frequency = 1 18 | train.agent_fn = @sc2_feature_layer_agent.SC2FeatureLayerAgent 19 | train.optimizer_fn = @tf.keras.optimizers.Adam 20 | train.eval_interval = 1e8 21 | train.max_to_keep_checkpoints = 10 22 | train.save_checkpoint_interval = 14400 # 4 hours 23 | train.tensorboard_log_interval = 10 24 | train.console_log_interval = 60 25 | 26 | tf.keras.optimizers.Adam.learning_rate = 1e-4 27 | 28 | SC2DataLoader.path = './data/datasets/v1' 29 | SC2DataLoader.action_space = %ACTION_SPACE 30 | SC2DataLoader.observation_space = %OBSERVATION_SPACE 31 | SC2DataLoader.min_duration = 0. 
32 | SC2DataLoader.min_mmr = 0 33 | SC2DataLoader.min_apm = 0 34 | SC2DataLoader.observed_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 35 | SC2DataLoader.opponent_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 36 | SC2DataLoader.map_names = None # ['Kairos Junction LE'] 37 | 38 | 39 | 40 | # Eval config 41 | # ---------------------------------------------------------------------------- 42 | 43 | evaluate.envs=[ 44 | @KairosJunction_tvt_very_easy/SC2SingleAgentEnv, 45 | @KairosJunction_tvt_medium/SC2SingleAgentEnv, 46 | ] 47 | evaluate.num_episodes = 20 48 | evaluate.random_seed = 21 49 | evaluate.num_evaluators = 1 50 | 51 | SC2SingleAgentEnv.bot_race = 'terran' 52 | SC2SingleAgentEnv.save_replay_episodes = 1 53 | 54 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 55 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 56 | 57 | KairosJunction_tvt_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 58 | KairosJunction_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 59 | 60 | KairosJunction_tvt_medium/SC2SingleAgentEnv.map_name = 'KairosJunction' 61 | KairosJunction_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 62 | 63 | KairosJunction_tvt_hard/SC2SingleAgentEnv.map_name = 'KairosJunction' 64 | KairosJunction_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 65 | -------------------------------------------------------------------------------- /configs/1v1/behaviour_cloning_single_gpu.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | include 'configs/1v1/agents/sc2_feature_layer_agent.gin' 3 | 4 | import sc2_imitation_learning.dataset.sc2_dataset 5 | 6 | 7 | # Train config 8 | # ---------------------------------------------------------------------------- 9 | 10 | train.action_space = %ACTION_SPACE 11 | train.observation_space = %OBSERVATION_SPACE 12 | train.data_loader = @SC2DataLoader() 13 | train.batch_size = 8 14 | train.sequence_length = 64 15 | train.total_train_samples = 1e9 16 | train.l2_regularization = 1e-5 17 | train.update_frequency = 4 18 | train.agent_fn = @sc2_feature_layer_agent.SC2FeatureLayerAgent 19 | train.optimizer_fn = @tf.keras.optimizers.Adam 20 | train.eval_interval = 1e8 21 | train.max_to_keep_checkpoints = 10 22 | train.save_checkpoint_interval = 14400 # 4 hours 23 | train.tensorboard_log_interval = 10 24 | train.console_log_interval = 60 25 | 26 | tf.keras.optimizers.Adam.learning_rate = 1e-4 27 | 28 | SC2DataLoader.path = './data/datasets/v1' 29 | SC2DataLoader.action_space = %ACTION_SPACE 30 | SC2DataLoader.observation_space = %OBSERVATION_SPACE 31 | SC2DataLoader.min_duration = 0. 
32 | SC2DataLoader.min_mmr = 0 33 | SC2DataLoader.min_apm = 0 34 | SC2DataLoader.observed_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 35 | SC2DataLoader.opponent_player_races = [1] # 1=Terran, 2=Zerg, 3=Protoss 36 | SC2DataLoader.map_names = None # ['Kairos Junction LE'] 37 | 38 | 39 | 40 | # Eval config 41 | # ---------------------------------------------------------------------------- 42 | 43 | evaluate.envs=[ 44 | @KairosJunction_tvt_very_easy/SC2SingleAgentEnv, 45 | @KairosJunction_tvt_medium/SC2SingleAgentEnv, 46 | ] 47 | evaluate.num_episodes = 20 48 | evaluate.random_seed = 21 49 | evaluate.num_evaluators = 1 50 | 51 | SC2SingleAgentEnv.bot_race = 'terran' 52 | SC2SingleAgentEnv.save_replay_episodes = 1 53 | 54 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 55 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 56 | 57 | KairosJunction_tvt_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 58 | KairosJunction_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 59 | 60 | KairosJunction_tvt_medium/SC2SingleAgentEnv.map_name = 'KairosJunction' 61 | KairosJunction_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 62 | 63 | KairosJunction_tvt_hard/SC2SingleAgentEnv.map_name = 'KairosJunction' 64 | KairosJunction_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 65 | -------------------------------------------------------------------------------- /configs/1v1/environment.gin: -------------------------------------------------------------------------------- 1 | import sc2_imitation_learning.environment.sc2_environment 2 | 3 | INTERFACE_CONFIG = @interface_config/singleton() 4 | interface_config/singleton.constructor = @SC2InterfaceConfig 5 | SC2InterfaceConfig.dimension_screen = (64, 64) 6 | SC2InterfaceConfig.dimension_minimap = (64, 64) 7 | SC2InterfaceConfig.screen_features = ('visibility_map', 'player_relative', 'unit_type', 'selected', 8 | 'unit_hit_points_ratio', 'unit_energy_ratio', 'unit_density_aa') 9 | SC2InterfaceConfig.minimap_features = ('camera', 'player_relative', 'alerts') 10 | SC2InterfaceConfig.scalar_features = ('player', 'home_race_requested', 'away_race_requested', 'upgrades', 11 | 'game_loop', 'available_actions', 'unit_counts', 'build_queue', 12 | 'build_queue_length', 'cargo', 'cargo_length', 'cargo_slots_available', 13 | 'control_groups', 'multi_select', 'multi_select_length', 14 | 'production_queue_length', 'production_queue') 15 | SC2InterfaceConfig.available_actions = None 16 | SC2InterfaceConfig.upgrade_set = None 17 | SC2InterfaceConfig.max_step_mul = 16 18 | SC2InterfaceConfig.max_multi_select = 64 19 | SC2InterfaceConfig.max_cargo = 8 20 | SC2InterfaceConfig.max_build_queue = 8 21 | SC2InterfaceConfig.max_production_queue = 16 22 | 23 | OBSERVATION_SPACE = @observation_space/singleton() 24 | observation_space/singleton.constructor = @SC2ObservationSpace 25 | SC2ObservationSpace.config = %INTERFACE_CONFIG 26 | 27 | ACTION_SPACE = @action_space/singleton() 28 | action_space/singleton.constructor = @SC2ActionSpace 29 | SC2ActionSpace.config = %INTERFACE_CONFIG 30 | 31 | SC2SingleAgentEnv.interface_config = %INTERFACE_CONFIG 32 | SC2SingleAgentEnv.observation_space = %OBSERVATION_SPACE 33 | SC2SingleAgentEnv.action_space = %ACTION_SPACE 34 | SC2SingleAgentEnv.map_name = 'KairosJunction' 35 | SC2SingleAgentEnv.battle_net_map = False 36 | SC2SingleAgentEnv.agent_race = 'terran' 37 | SC2SingleAgentEnv.agent_name = 'Hambbe' 38 | SC2SingleAgentEnv.bot_race = 'zerg' 39 | SC2SingleAgentEnv.bot_difficulty 
= 'easy' 40 | SC2SingleAgentEnv.bot_build = 'random' 41 | SC2SingleAgentEnv.visualize = False 42 | SC2SingleAgentEnv.realtime = False 43 | SC2SingleAgentEnv.save_replay_episodes = 0 44 | SC2SingleAgentEnv.replay_dir = None 45 | SC2SingleAgentEnv.replay_prefix = None 46 | SC2SingleAgentEnv.game_steps_per_episode = None 47 | SC2SingleAgentEnv.score_index = None 48 | SC2SingleAgentEnv.score_multiplier = None 49 | SC2SingleAgentEnv.disable_fog = False 50 | SC2SingleAgentEnv.ensure_available_actions = True 51 | SC2SingleAgentEnv.version = '4.7.1' 52 | SC2SingleAgentEnv.random_seed = None 53 | -------------------------------------------------------------------------------- /tests/test_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from sc2_imitation_learning.common.transformer import SC2EntityTransformerEncoder 5 | 6 | 7 | class TestSC2EntityTransformerEncoder(tf.test.TestCase): 8 | def test_forward(self): 9 | transformer = SC2EntityTransformerEncoder(num_layers=2, model_dim=2, num_heads=2, dff=4) 10 | 11 | entities = tf.constant(np.random.randn(2, 3, 4), dtype=tf.float32) 12 | 13 | embedded_entities = transformer(entities) 14 | 15 | self.assertEqual(embedded_entities.dtype, tf.float32) 16 | self.assertEqual(embedded_entities.shape.as_list(), [2, 3, 2]) 17 | self.assertFalse(tf.reduce_any(tf.math.is_inf(embedded_entities))) 18 | self.assertFalse(tf.reduce_any(tf.math.is_nan(embedded_entities))) 19 | self.assertNotAllClose(embedded_entities, tf.zeros_like(embedded_entities)) 20 | 21 | def test_mask(self): 22 | transformer = SC2EntityTransformerEncoder(num_layers=2, model_dim=2, num_heads=2, dff=4, mask_value=0) 23 | 24 | # no entity masked 25 | entities = tf.constant(np.random.randn(2, 3, 4), dtype=tf.float32) 26 | embedded_entities = transformer(entities) 27 | 28 | self.assertEqual(embedded_entities.dtype, tf.float32) 29 | self.assertEqual(embedded_entities.shape.as_list(), [2, 3, 2]) 30 | self.assertFalse(tf.reduce_any(tf.math.is_inf(embedded_entities))) 31 | self.assertFalse(tf.reduce_any(tf.math.is_nan(embedded_entities))) 32 | self.assertFalse(tf.reduce_any(embedded_entities == 0.)) 33 | 34 | # some entities masked 35 | entities = tf.constant(np.concatenate([np.random.randn(2, 3, 2), np.zeros((2, 3, 2))], axis=-1), dtype=tf.float32) 36 | embedded_entities = transformer(entities) 37 | 38 | self.assertEqual(embedded_entities.dtype, tf.float32) 39 | self.assertEqual(embedded_entities.shape.as_list(), [2, 3, 2]) 40 | self.assertFalse(tf.reduce_any(tf.math.is_inf(embedded_entities))) 41 | self.assertFalse(tf.reduce_any(tf.math.is_nan(embedded_entities))) 42 | self.assertFalse(tf.reduce_any(embedded_entities[:, :, :2] == 0.)) 43 | self.assertTrue(tf.reduce_all(embedded_entities[:, :, 2:] == 0.)) 44 | 45 | # all entities masked 46 | entities = tf.constant(np.zeros((2, 3, 4)), dtype=tf.float32) 47 | embedded_entities = transformer(entities) 48 | 49 | self.assertEqual(embedded_entities.dtype, tf.float32) 50 | self.assertEqual(embedded_entities.shape.as_list(), [2, 3, 2]) 51 | self.assertFalse(tf.reduce_any(tf.math.is_inf(embedded_entities))) 52 | self.assertFalse(tf.reduce_any(tf.math.is_nan(embedded_entities))) 53 | self.assertTrue(tf.reduce_all(embedded_entities == 0.)) 54 | -------------------------------------------------------------------------------- /scripts/play_agent_vs_bot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 
import os 3 | import traceback 4 | from typing import Type 5 | 6 | import gin 7 | import tensorflow as tf 8 | from absl import app 9 | from absl import flags 10 | 11 | from sc2_imitation_learning.common.utils import gin_register_external_configurables, make_dummy_action 12 | from sc2_imitation_learning.environment.sc2_environment import SC2SingleAgentEnv 13 | 14 | logging.basicConfig(level=logging.WARNING) 15 | logger = logging.getLogger(__name__) 16 | 17 | flags.DEFINE_string('agent_dir', default=None, help='Path to the directory where the agent is stored.') 18 | flags.DEFINE_multi_string('gin_file', ['configs/1v1/play_agent_vs_bot.gin'], 'List of paths to Gin config files.') 19 | flags.DEFINE_multi_string('gin_param', None, 'List of Gin parameter bindings.') 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | 24 | gin_register_external_configurables() 25 | 26 | 27 | @gin.configurable 28 | def play(env_fn: Type[SC2SingleAgentEnv] = gin.REQUIRED, num_episodes: int = gin.REQUIRED) -> None: 29 | 30 | agent = tf.saved_model.load(FLAGS.agent_dir) 31 | agent_state = agent.initial_state(1) 32 | 33 | env = env_fn() 34 | env.launch() 35 | for episode in range(num_episodes): 36 | episode_reward = 0. 37 | episode_frames = 0 38 | episode_steps = 0 39 | try: 40 | reward, done, observation = 0., False, env.reset() 41 | action = make_dummy_action(env.action_space, num_batch_dims=1) 42 | while not done: 43 | env_outputs = ( 44 | tf.constant([reward], dtype=tf.float32), 45 | tf.constant([episode_steps == 0], dtype=tf.bool), 46 | tf.nest.map_structure(lambda o: tf.constant([o], dtype=tf.dtypes.as_dtype(o.dtype)), observation)) 47 | agent_output, agent_state = agent(action, env_outputs, agent_state) 48 | action = tf.nest.map_structure(lambda t: t.numpy(), agent_output.actions) 49 | reward, _, done, observation = env.step(action) 50 | episode_reward += reward 51 | episode_frames += action['step_mul'] + 1 52 | episode_steps += 1 53 | except Exception as e: 54 | logger.error(f"Failed to play episode {episode} (stacktrace below).") 55 | traceback.print_exc() 56 | finally: 57 | logger.info(f"Episode completed: total reward={episode_reward}, frames={episode_frames}, " 58 | f"steps={episode_steps}") 59 | env.close() 60 | 61 | 62 | def main(_): 63 | gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) 64 | play() 65 | 66 | 67 | if __name__ == '__main__': 68 | os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' 69 | app.run(main) 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /data 3 | /experiments 4 | /venv 5 | /plots 6 | /wandb 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/__init__.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from abc import ABC, abstractmethod 3 | from typing import Tuple, Optional, Text 4 | 5 | import sonnet as snt 6 | import tensorflow as tf 7 | import tree 8 | from sonnet.src import types 9 | 10 | from sc2_imitation_learning.environment.environment import ActionSpace, ObservationSpace 11 | 12 | AgentOutput = collections.namedtuple('AgentOutput', ['logits', 'actions', 'values']) 13 | 14 | 15 | class Agent(snt.RNNCore, ABC): 16 | def __call__(self, prev_actions, env_outputs, core_state, unroll=False, teacher_actions=None) -> Tuple[AgentOutput, Tuple]: 17 | if not unroll: 18 | # Add time dimension. 19 | prev_actions, env_outputs = tf.nest.map_structure( 20 | lambda t: tf.expand_dims(t, 0), (prev_actions, env_outputs)) 21 | 22 | outputs, core_state = self._unroll(prev_actions, env_outputs, core_state, teacher_actions) 23 | 24 | if not unroll: 25 | # Remove time dimension. 
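# (single-step calls were given a leading time axis of length 1 above, so
# _unroll always sees time-major [T, B, ...] inputs; that axis is squeezed
# away again here)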
26 | outputs = tf.nest.map_structure(lambda t: None if t is None else tf.squeeze(t, 0), outputs) 27 | 28 | return outputs, core_state 29 | 30 | @abstractmethod 31 | def _unroll(self, prev_actions, env_outputs, core_state, teacher_actions=None) -> Tuple[AgentOutput, Tuple]: 32 | pass 33 | 34 | 35 | def build_saved_agent(agent: Agent, observation_space: ObservationSpace, action_space: ActionSpace) -> tf.Module: 36 | call_input_signature = [ 37 | tree.map_structure_with_path( 38 | lambda path, s: tf.TensorSpec((None,) + s.shape, s.dtype, name='action/' + '/'.join(path)), 39 | action_space.specs), 40 | ( 41 | tf.TensorSpec((None,), dtype=tf.float32, name='reward'), 42 | tf.TensorSpec((None,), dtype=tf.bool, name='done'), 43 | tree.map_structure_with_path( 44 | lambda path, s: tf.TensorSpec((None,) + s.shape, s.dtype, name='observation/' + '/'.join(path)), 45 | observation_space.specs) 46 | ), 47 | tree.map_structure( 48 | lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype, name='agent_state'), agent.initial_state(1)) 49 | ] 50 | 51 | initial_state_input_signature = [ 52 | tf.TensorSpec(shape=(), dtype=tf.int32, name='batch_size'), 53 | ] 54 | 55 | class SavedAgent(tf.Module): 56 | def __init__(self, agent: Agent, name=None): 57 | super().__init__(name) 58 | self._agent = agent 59 | 60 | @tf.function(input_signature=call_input_signature) 61 | def __call__(self, prev_action, env_outputs, agent_state): 62 | return self._agent(prev_action, env_outputs, agent_state) 63 | 64 | @tf.function(input_signature=initial_state_input_signature) 65 | def initial_state(self, batch_size): 66 | return self._agent.initial_state(batch_size) 67 | 68 | return SavedAgent(agent) 69 | -------------------------------------------------------------------------------- /tests/test_spatial_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from sc2_imitation_learning.agents.common.feature_encoder import OneHotEncoder 5 | from sc2_imitation_learning.agents.common.spatial_encoder import ImpalaCNNSpatialEncoder, AlphaStarSpatialEncoder 6 | from sc2_imitation_learning.common.conv import ConvNet2D 7 | 8 | 9 | class Test(tf.test.TestCase): 10 | def test_impala_cnnspatial_encoder(self): 11 | enc = ImpalaCNNSpatialEncoder( 12 | feature_layer_encoders={ 13 | 'player_relative': lambda: OneHotEncoder(depth=5) 14 | }, 15 | input_projection_dim=32, 16 | num_blocks=[2, 2, 2], 17 | output_channels=[32, 64, 64], 18 | max_pool_padding='SAME', 19 | spatial_embedding_size=256 20 | ) 21 | 22 | raw_features = { 23 | 'player_relative': tf.constant(np.random.randint(0, 5, (1, 64, 64), dtype=np.uint16), dtype=tf.uint16) 24 | } 25 | 26 | embedded_spatial, map_skip = enc(raw_features) 27 | 28 | self.assertEqual(embedded_spatial.dtype, tf.float32) 29 | self.assertEqual(embedded_spatial.shape.as_list(), [1, 256]) 30 | self.assertEqual(tf.reduce_any(tf.math.is_inf(embedded_spatial)), False) 31 | self.assertEqual(tf.reduce_any(tf.math.is_nan(embedded_spatial)), False) 32 | 33 | self.assertEqual(len(map_skip), 1) 34 | self.assertEqual(map_skip[0].dtype, tf.float32) 35 | self.assertEqual(map_skip[0].shape.as_list(), [1, 8, 8, 64]) 36 | self.assertEqual(tf.reduce_any(tf.math.is_inf(map_skip[0])), False) 37 | self.assertEqual(tf.reduce_any(tf.math.is_nan(map_skip[0])), False) 38 | 39 | 40 | def test_alpha_star_spatial_encoder(self): 41 | 42 | enc = AlphaStarSpatialEncoder( 43 | feature_layer_encoders={ 44 | 'player_relative': lambda: 
OneHotEncoder(depth=5) 45 | }, 46 | input_projection_dim=16, 47 | downscale_conv_net=ConvNet2D( 48 | output_channels=[16, 32], 49 | kernel_shapes=[4, 4], 50 | strides=[2, 2], 51 | paddings=['SAME', 'SAME'], 52 | activate_final=True, 53 | ), 54 | res_out_channels=32, 55 | res_num_blocks=4, 56 | res_stride=1, 57 | spatial_embedding_size=256 58 | ) 59 | 60 | raw_features = { 61 | 'player_relative': tf.constant(np.random.randint(0, 5, (1, 64, 64), dtype=np.uint16), dtype=tf.uint16) 62 | } 63 | 64 | embedded_spatial, map_skip = enc(raw_features) 65 | 66 | self.assertEqual(embedded_spatial.dtype, tf.float32) 67 | self.assertEqual(embedded_spatial.shape.as_list(), [1, 256]) 68 | self.assertEqual(tf.reduce_any(tf.math.is_inf(embedded_spatial)), False) 69 | self.assertEqual(tf.reduce_any(tf.math.is_nan(embedded_spatial)), False) 70 | 71 | self.assertEqual(len(map_skip), 5) 72 | for i in range(len(map_skip)): 73 | self.assertEqual(map_skip[i].dtype, tf.float32) 74 | self.assertEqual(map_skip[i].shape.as_list(), [1, 16, 16, 32]) 75 | self.assertEqual(tf.reduce_any(tf.math.is_inf(map_skip[i])), False) 76 | self.assertEqual(tf.reduce_any(tf.math.is_nan(map_skip[i])), False) 77 | -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/common/scalar_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | from typing import Dict, Sequence, NamedTuple, Type 3 | 4 | import gin 5 | import sonnet as snt 6 | import tensorflow as tf 7 | 8 | from sc2_imitation_learning.agents.common.feature_encoder import ActionEncoder, FeatureEncoder 9 | from sc2_imitation_learning.environment.sc2_environment import SC2ActionSpace 10 | 11 | 12 | class ScalarEncoderOutputs(NamedTuple): 13 | embedded_scalar: tf.Tensor 14 | scalar_context: tf.Tensor 15 | 16 | 17 | @gin.register 18 | class ScalarEncoder(snt.Module, ABC): 19 | """ Encoder module for scalar features. """ 20 | 21 | @abstractmethod 22 | def __call__(self, features: Dict[str, tf.Tensor], prev_actions: Dict[str, tf.Tensor]) -> ScalarEncoderOutputs: 23 | """ Applies the specified encodings on features and prev_actions, constructs scalar embedding and 24 | scalar context vectors. 25 | 26 | Args: 27 | features: A Dict with raw scalar features. 28 | prev_actions: A Dict containing the actions of the previous time step. 29 | 30 | Returns: 31 | A namedtuple with: 32 | - embedded_scalar: A scalar embedding vector. 33 | - scalar_context: A scalar context vector. 34 | """ 35 | pass 36 | 37 | 38 | @gin.register 39 | class ConcatScalarEncoder(ScalarEncoder): 40 | """ Concat Encoder module for scalar features. Produces an encoding and a context vector through concatenation. """ 41 | 42 | def __init__(self, 43 | action_space: SC2ActionSpace = gin.REQUIRED, 44 | feature_encoders: Dict[str, Type[FeatureEncoder]] = gin.REQUIRED, 45 | prev_action_encoders: Dict[str, Type[ActionEncoder]] = gin.REQUIRED, 46 | context_feature_names: Sequence[str] = ('home_race_requested', 'away_race_requested', 47 | 'available_actions')): 48 | """ Constructs the encoder module 49 | 50 | Args: 51 | action_space: The action space with the environment. 52 | feature_encoders: A Dict with feature encoders. Keys must correspond to inputs. 53 | prev_action_encoders: A Dict with action encoders. Keys must correspond to action names. 54 | context_feature_names: A List with feature names that should be included in the context vector. 
55 | """ 56 | super().__init__() 57 | self._action_space = action_space 58 | self._context_feature_names = context_feature_names 59 | self._embed_features = {key: enc() for key, enc in feature_encoders.items()} 60 | self._embed_actions = {key: enc(action_space.specs[key].n) for key, enc in prev_action_encoders.items()} 61 | 62 | def __call__(self, features: Dict[str, tf.Tensor], prev_actions: Dict[str, tf.Tensor]) -> ScalarEncoderOutputs: 63 | embedded_features = {key: embed(features[key]) for key, embed in self._embed_features.items()} 64 | embedded_actions = {key: embed(prev_actions[key]) for key, embed in self._embed_actions.items()} 65 | scalar_context = tf.concat([embedded_features[key] for key in self._context_feature_names], axis=-1) 66 | embedded_scalar = tf.concat(tf.nest.flatten(embedded_features) + tf.nest.flatten(embedded_actions), axis=-1) 67 | return ScalarEncoderOutputs(embedded_scalar, scalar_context) 68 | -------------------------------------------------------------------------------- /tests/test_unit_group_encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from pysc2.lib.static_data import UNIT_TYPES 3 | 4 | from sc2_imitation_learning.agents.common.feature_encoder import UnitSelectionEncoder 5 | from sc2_imitation_learning.agents.common.unit_group_encoder import mask_unit_group, ConcatAverageUnitGroupEncoder 6 | from sc2_imitation_learning.common.conv import ConvNet1D 7 | 8 | 9 | class UnitGroupEncoderTest(tf.test.TestCase): 10 | def test_mask_unit_group(self): 11 | unit_group = tf.constant([[[0], [1], [2]]]) 12 | unit_group_length = tf.constant([2]) 13 | unit_group_masked = mask_unit_group(unit_group, unit_group_length) 14 | self.assertAllClose(unit_group_masked, [[[0], [1], [0]]]) 15 | 16 | def test_concat_average_unit_group_encoder(self): 17 | enc = ConcatAverageUnitGroupEncoder(embedding_size=16, feature_encoders={ 18 | 'multi_select': lambda: UnitSelectionEncoder(encoder=ConvNet1D( 19 | output_channels=[16], kernel_shapes=[1], strides=[1], paddings=['SAME'], activate_final=True)) 20 | }) 21 | 22 | raw_multi_select = tf.constant([[[ 23 | UNIT_TYPES[i], # unit_type 24 | 0, # player_relative 25 | 100, # health 26 | 0, # shields 27 | 0, # energy 28 | 0, # transport_slots_taken 29 | 0, # build_progress 30 | ] for i in range(3)]], dtype=tf.uint16) 31 | 32 | raw_multi_select_length = tf.constant([2], dtype=tf.uint16) 33 | 34 | embedded_unit_group, unit_group_embeddings = enc({ 35 | 'multi_select': raw_multi_select, 36 | 'multi_select_length': raw_multi_select_length, 37 | }) 38 | 39 | self.assertEqual(embedded_unit_group.dtype, tf.float32) 40 | self.assertEqual(embedded_unit_group.shape.as_list(), [1, 16]) 41 | self.assertEqual(tf.reduce_any(tf.math.is_inf(embedded_unit_group)), False) 42 | self.assertEqual(tf.reduce_any(tf.math.is_nan(embedded_unit_group)), False) 43 | 44 | self.assertEqual(unit_group_embeddings['multi_select'].dtype, tf.float32) 45 | self.assertEqual(unit_group_embeddings['multi_select'].shape.as_list(), [1, 3, 16]) 46 | self.assertEqual(tf.reduce_any(tf.math.is_inf(unit_group_embeddings['multi_select'])), False) 47 | self.assertEqual(tf.reduce_any(tf.math.is_nan(unit_group_embeddings['multi_select'])), False) 48 | 49 | raw_multi_select = tf.constant([[[ 50 | UNIT_TYPES[i], # unit_type 51 | 0, # player_relative 52 | 100, # health 53 | 0, # shields 54 | 0, # energy 55 | 0, # transport_slots_taken 56 | 0, # build_progress 57 | ] for i in range(3)]], dtype=tf.uint16) 58 | 59 | 
raw_multi_select_length = tf.constant([0], dtype=tf.uint16) 60 | 61 | embedded_unit_group, unit_group_embeddings = enc({ 62 | 'multi_select': raw_multi_select, 63 | 'multi_select_length': raw_multi_select_length, 64 | }) 65 | 66 | self.assertEqual(embedded_unit_group.dtype, tf.float32) 67 | self.assertEqual(embedded_unit_group.shape.as_list(), [1, 16]) 68 | self.assertEqual(tf.reduce_any(tf.math.is_inf(embedded_unit_group)), False) 69 | self.assertEqual(tf.reduce_any(tf.math.is_nan(embedded_unit_group)), False) 70 | self.assertAllClose(embedded_unit_group, tf.zeros_like(embedded_unit_group)) 71 | 72 | self.assertEqual(unit_group_embeddings['multi_select'].dtype, tf.float32) 73 | self.assertEqual(unit_group_embeddings['multi_select'].shape.as_list(), [1, 3, 16]) 74 | self.assertEqual(tf.reduce_any(tf.math.is_inf(unit_group_embeddings['multi_select'])), False) 75 | self.assertEqual(tf.reduce_any(tf.math.is_nan(unit_group_embeddings['multi_select'])), False) -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/common/unit_group_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, NamedTuple, Type 3 | 4 | import gin 5 | import sonnet as snt 6 | import tensorflow as tf 7 | 8 | from sc2_imitation_learning.agents.common.feature_encoder import FeatureEncoder 9 | from sc2_imitation_learning.common.layers import MaskedGlobalAveragePooling1D 10 | 11 | 12 | class UnitGroupsEncoderOutputs(NamedTuple): 13 | embedded_unit_group: tf.Tensor 14 | unit_group_embeddings: Dict[str, tf.Tensor] 15 | 16 | 17 | def mask_unit_group(unit_group: tf.Tensor, unit_group_length: tf.Tensor, mask_value=0) -> tf.Tensor: 18 | """ Masks unit groups according to their length. 19 | 20 | Args: 21 | unit_group: A tensor of rank 3 with a sequence of unit feature vectors. 22 | unit_group_length: The length of the unit group (assumes all unit feature vectors upfront). 23 | mask_value: The mask value. 24 | 25 | Returns: 26 | A tensor of rank 3 where indices beyond unit_group_length are zero-masked. 27 | 28 | """ 29 | if unit_group_length is not None: 30 | # get rid of last dimensions with size 1 31 | if unit_group.shape.rank - unit_group_length.shape.rank < 2: 32 | unit_group_length = tf.squeeze(unit_group_length, axis=-1) # B 33 | 34 | # mask with mask_value 35 | unit_group_mask = tf.sequence_mask( 36 | tf.cast(unit_group_length, tf.int32), maxlen=unit_group.shape[1], dtype=unit_group.dtype) # B x T 37 | unit_group_mask = tf.expand_dims(unit_group_mask, axis=-1) 38 | unit_group *= unit_group_mask 39 | if mask_value != 0: 40 | mask_value = tf.convert_to_tensor(mask_value) 41 | unit_group = tf.cast(unit_group, mask_value.dtype) 42 | unit_group_mask = tf.cast(unit_group_mask, mask_value.dtype) 43 | unit_group += (1 - unit_group_mask) * mask_value 44 | return unit_group 45 | 46 | 47 | class UnitGroupEncoder(snt.Module, ABC): 48 | """ Encoder module for unit group features. """ 49 | 50 | @abstractmethod 51 | def __call__(self, features: Dict[str, tf.Tensor]) -> UnitGroupsEncoderOutputs: 52 | """ Encodes the unit group features 53 | 54 | Args: 55 | features: A Dict with raw scalar features. 56 | 57 | Returns: 58 | A namedtuple with: 59 | - embedded_unit_group: An embedded unit group vector 60 | - unit_group_embeddings: A Dict of unit group embeddings. 
61 | """ 62 | pass 63 | 64 | 65 | @gin.register 66 | class ConcatAverageUnitGroupEncoder(UnitGroupEncoder): 67 | """ Unit group encoder module that encodes unit groups by concatenating their average embedding vectors """ 68 | def __init__(self, 69 | embedding_size: int = gin.REQUIRED, 70 | feature_encoders: Dict[str, Type[FeatureEncoder]] = gin.REQUIRED): 71 | super().__init__() 72 | self._feature_encoders = {key: enc() for key, enc in feature_encoders.items()} 73 | self._unit_group_embed = { 74 | key: snt.Sequential([ 75 | 76 | MaskedGlobalAveragePooling1D(mask_value=0), # assume encoded unit group are zero masked before. 77 | snt.Linear(output_size=embedding_size), 78 | tf.nn.relu 79 | ]) 80 | for key in self._feature_encoders.keys() 81 | } 82 | 83 | def __call__(self, features: Dict[str, tf.Tensor]) -> UnitGroupsEncoderOutputs: 84 | unit_group_embeddings = { 85 | key: enc(mask_unit_group(features[key], features.get(f'{key}_length', None), -1)) 86 | for key, enc in self._feature_encoders.items()} 87 | 88 | embedded_unit_groups = { 89 | key: emb(unit_group_embeddings[key]) 90 | for key, emb in self._unit_group_embed.items() 91 | } 92 | embedded_unit_groups = tf.concat(tf.nest.flatten(embedded_unit_groups), axis=-1) 93 | 94 | return UnitGroupsEncoderOutputs( 95 | embedded_unit_group=embedded_unit_groups, unit_group_embeddings=unit_group_embeddings) 96 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import datetime 3 | import functools 4 | import json 5 | import logging 6 | import multiprocessing 7 | import os 8 | from typing import Type, List 9 | 10 | import gin 11 | import tensorflow as tf 12 | import wandb 13 | from absl import app, flags 14 | 15 | from sc2_imitation_learning.agents import Agent 16 | from sc2_imitation_learning.common.evaluator import evaluate_on_multiple_envs 17 | from sc2_imitation_learning.common.utils import compute_stats_dict, gin_register_external_configurables 18 | from sc2_imitation_learning.environment.sc2_environment import SC2SingleAgentEnv 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | logger = logging.getLogger(__file__) 22 | 23 | flags.DEFINE_string('logdir', './experiments/2021-03-21_03-00-19', 'Experiment directory.') 24 | flags.DEFINE_multi_string('gin_file', ['./configs/1v1/evaluate.gin'], 'List of paths to Gin config files.') 25 | flags.DEFINE_multi_string('gin_param', None, 'List of Gin parameter bindings.') 26 | 27 | # logger config 28 | flags.DEFINE_bool('wandb_logging_enabled', False, 'If wandb logging should be enabled.') 29 | flags.DEFINE_string('wandb_project', 'sc2-il', 'Name of the wandb project.') 30 | flags.DEFINE_string('wandb_entity', None, 'Name of the wandb entity.') 31 | flags.DEFINE_list('wandb_tags', ['behaviour_cloning'], 'List of wandb tags.') 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | gin_register_external_configurables() 37 | 38 | 39 | def agent_fn(experiment_path, *args, **kwargs) -> Agent: 40 | return tf.saved_model.load(os.path.join(experiment_path, 'saved_model')) 41 | 42 | 43 | @gin.configurable 44 | def evaluate(experiment_path: str, 45 | envs: List[Type[SC2SingleAgentEnv]] = gin.REQUIRED, 46 | num_episodes: int = gin.REQUIRED, 47 | random_seed: int = gin.REQUIRED, 48 | num_evaluators: int = gin.REQUIRED): 49 | 50 | with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: 51 | available_gpus = 
executor.submit(tf.config.list_physical_devices, 'GPU').result() 52 | 53 | episode_stats = evaluate_on_multiple_envs( 54 | agent_fn=functools.partial(agent_fn, experiment_path), 55 | envs=envs, 56 | num_episodes=num_episodes, 57 | num_evaluators=num_evaluators, 58 | random_seed=random_seed, 59 | replay_dir=os.path.abspath(os.path.join(FLAGS.logdir, 'replays', 'eval')), 60 | available_gpus=available_gpus) 61 | 62 | return { 63 | matchup: { 64 | 'num_episodes': len(stats), 65 | 'episode_frames': compute_stats_dict([s.num_frames for s in stats]), 66 | 'episode_steps': compute_stats_dict([s.num_steps for s in stats]), 67 | 'episode_reward': compute_stats_dict([s.reward for s in stats]), 68 | } 69 | for matchup, stats in episode_stats.items() 70 | } 71 | 72 | 73 | def main(argv): 74 | gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) 75 | 76 | if not os.path.exists(FLAGS.logdir): 77 | raise ValueError(f"Logdir '{FLAGS.logdir}' does not exist.") 78 | 79 | if FLAGS.wandb_logging_enabled: 80 | experiment_name = os.path.basename(FLAGS.logdir.rstrip("/")) 81 | job_type = 'test' 82 | wandb.init( 83 | id=f"{experiment_name}-{job_type}", 84 | name=f"{experiment_name}-{job_type}", 85 | group=experiment_name, 86 | job_type=job_type, 87 | project=FLAGS.wandb_project, 88 | entity=FLAGS.wandb_entity, 89 | tags=FLAGS.wandb_tags, 90 | resume="allow") 91 | wandb.tensorboard.patch(save=False, tensorboardX=False) 92 | 93 | eval_outcome = evaluate(FLAGS.logdir) 94 | 95 | with open(os.path.join(FLAGS.logdir, f'eval_outcome_{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.json'), 'w') as f: 96 | json.dump(eval_outcome, f, indent=4) 97 | 98 | if FLAGS.wandb_logging_enabled: 99 | wandb.log(eval_outcome) 100 | 101 | 102 | if __name__ == '__main__': 103 | os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' 104 | multiprocessing.set_start_method('spawn') 105 | app.run(main) 106 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/progress_logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import threading 3 | import timeit 4 | from abc import abstractmethod 5 | from collections import defaultdict 6 | from typing import Optional 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | from sc2_imitation_learning.common.utils import flatten_nested_dicts 12 | 13 | 14 | class ProgressLogger(object): 15 | def __init__(self, 16 | logging_interval: float = 10., 17 | initial_step: int = 0) -> None: 18 | super().__init__() 19 | self._logging_interval = logging_interval 20 | self._initial_step = initial_step 21 | self._step = self._initial_step 22 | self._lock = threading.Lock() 23 | self._terminated = threading.Event() 24 | self._logger_thread = threading.Thread(target=self._run) 25 | self._logs = defaultdict(list) 26 | 27 | def start(self): 28 | self._logger_thread.start() 29 | 30 | def shutdown(self, block: bool = True): 31 | assert self._logger_thread.is_alive() 32 | self._terminated.set() 33 | if block: 34 | self._logger_thread.join() 35 | 36 | def log_dict(self, values: dict, step: int): 37 | assert step >= self._step 38 | flattened = flatten_nested_dicts(values) 39 | with self._lock: 40 | for key, value in flattened.items(): 41 | self._logs[key].append(np.copy(value)) 42 | self._step = step 43 | 44 | @abstractmethod 45 | def _log(self, values: dict, step: int): 46 | pass 47 | 48 | def _run(self): 49 | last_log, last_step = timeit.default_timer(), 
self._initial_step 50 | while not self._terminated.isSet(): 51 | with self._lock: 52 | # only lock critical section 53 | if self._step != last_step: 54 | logs = {k: np.mean(v) for k, v in self._logs.items() if len(v) > 0} 55 | step = self._step 56 | self._logs = defaultdict(list) 57 | else: 58 | logs = None 59 | if logs is not None: 60 | self._log(logs, step) 61 | last_step = step 62 | now = timeit.default_timer() 63 | elapsed = now - last_log 64 | self._terminated.wait(max(0., self._logging_interval - elapsed)) 65 | last_log = timeit.default_timer() 66 | 67 | 68 | class TensorboardProgressLogger(ProgressLogger): 69 | def __init__(self, 70 | summary_writer: tf.summary.SummaryWriter, 71 | logging_interval: float = 10., 72 | initial_step: int = 0) -> None: 73 | super().__init__(logging_interval, initial_step) 74 | self._summary_writer = summary_writer 75 | 76 | def _log(self, values: dict, step: int): 77 | with self._summary_writer.as_default(): 78 | for key, value in values.items(): 79 | tf.summary.scalar(key, value, step=np.int64(step)) 80 | 81 | 82 | class ConsoleProgressLogger(ProgressLogger): 83 | def __init__(self, 84 | final_step: int, 85 | batch_samples: int, 86 | logging_interval: float = 10., 87 | initial_step: int = 0, 88 | start_time: Optional[float] = None) -> None: 89 | super().__init__(logging_interval, initial_step) 90 | self._final_step = final_step 91 | self._batch_samples = batch_samples 92 | self._start_time = timeit.default_timer() if start_time is None else start_time 93 | self._last_log_time = timeit.default_timer() 94 | 95 | def _log(self, values: dict, step: int): 96 | print(f"Train | " 97 | f"step={step} | " 98 | f"samples={self._batch_samples * step} | " 99 | f"progress={round(100 * step / float(self._final_step), 1):5.1f}% | " 100 | f"time={datetime.timedelta(seconds=round(timeit.default_timer() - self._start_time))} | " 101 | f"loss={values['loss/loss']:.3f} | " 102 | f"samples/sec={values['samples_per_second']:.2f} | " 103 | f"lr={values['learning_rate']:.3e}", 104 | flush=True) 105 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Union, Text, Callable, Optional 2 | 3 | import gin 4 | import sonnet as snt 5 | import tensorflow as tf 6 | 7 | 8 | @gin.register() 9 | class MLP(snt.Module): 10 | def __init__(self, 11 | output_sizes: Sequence[int], 12 | activation: Union[Callable[[tf.Tensor], tf.Tensor], Text] = tf.nn.relu, 13 | with_layer_norm: bool = False, 14 | activate_final: bool = False, 15 | name: Optional[Text] = None): 16 | super().__init__(name) 17 | self._output_sizes = output_sizes 18 | self._activate_final = activate_final 19 | self._with_layer_norm = with_layer_norm 20 | if isinstance(activation, str): 21 | self._activation = tf.keras.activations.deserialize(activation) 22 | else: 23 | self._activation = activation 24 | self._layers = [] 25 | if self._with_layer_norm: 26 | self._layer_norms = [] 27 | for output_size in self._output_sizes: 28 | self._layers.append(snt.Linear(output_size=output_size)) 29 | if self._with_layer_norm: 30 | self._layer_norms.append(snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)) 31 | 32 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 33 | num_layers = len(self._layers) 34 | mlp_out = inputs 35 | for i, layer in enumerate(self._layers): 36 | mlp_out = layer(mlp_out) 37 | if i < (num_layers - 1) or self._activate_final: 38 
| if self._with_layer_norm: 39 | mlp_out = self._layer_norms[i](mlp_out) 40 | mlp_out = self._activation(mlp_out) 41 | return mlp_out 42 | 43 | 44 | class ResMLPBlock(snt.Module): 45 | def __init__(self, 46 | output_size: int, 47 | with_projection: bool = False, 48 | with_layer_norm: bool = False, 49 | name: Optional[Text] = None): 50 | super().__init__(name) 51 | if with_projection: 52 | self._linear_proj = snt.Linear(output_size=output_size) 53 | else: 54 | self._linear_proj = None 55 | self._linear = snt.Linear(output_size=output_size) 56 | if with_layer_norm: 57 | self._layer_norm = snt.LayerNorm(axis=-1, create_scale=True, create_offset=True) 58 | else: self._layer_norm = None  # ensure the attribute exists when layer norm is disabled 59 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 60 | block_out = inputs 61 | if self._layer_norm is not None: 62 | block_out = self._layer_norm(block_out) 63 | block_out = tf.nn.relu(block_out) 64 | if self._linear_proj is not None: 65 | shortcut = self._linear_proj(block_out) 66 | else: 67 | shortcut = inputs 68 | block_out = self._linear(block_out) 69 | block_out = block_out + shortcut 70 | return block_out 71 | 72 | 73 | @gin.register() 74 | class ResMLP(snt.Module): 75 | def __init__(self, 76 | output_size: int, 77 | num_blocks: int, 78 | with_projection: bool = False, 79 | with_layer_norm: bool = False, 80 | activate_final: bool = False, 81 | name: Optional[Text] = None): 82 | super().__init__(name) 83 | self._output_size = output_size 84 | self._num_blocks = num_blocks 85 | self._with_projection = with_projection 86 | self._with_layer_norm = with_layer_norm 87 | self._activate_final = activate_final 88 | layers = [] 89 | for i in range(self._num_blocks): 90 | layers.append(ResMLPBlock( 91 | output_size=self._output_size, with_projection=self._with_projection and i == 0, 92 | with_layer_norm=self._with_layer_norm)) 93 | if self._activate_final: 94 | if self._with_layer_norm: 95 | layers.append(snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)) 96 | layers.append(tf.nn.relu) 97 | self._net = snt.Sequential(layers) 98 | 99 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 100 | mlp_out = self._net(inputs) 101 | return mlp_out 102 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Text, Union, Sequence 2 | 3 | import numpy as np 4 | import sonnet as snt 5 | import tensorflow as tf 6 | 7 | 8 | class MaskedGlobalAveragePooling1D(snt.Module): 9 | """ Global average pooling operation for masked temporal inputs. """ 10 | 11 | def __init__(self, mask_value=0, name: Optional[Text] = None): 12 | """ Initializes the masked average pooling module. 13 | 14 | Args: 15 | mask_value: input value that will be masked. 16 | name: An optional string name for the module. 17 | """ 18 | super().__init__(name) 19 | self._mask_value = mask_value 20 | 21 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 22 | """ Applies the defined pooling operation on the inputs.
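A tiny worked example, added for illustration and assuming the default `mask_value=0`: an input of shape `(1, 2, 2)` equal to `[[[1., 1.], [0., 0.]]]` pools to `[[1., 1.]]`, because the all-zero timestep is treated as padding and is excluded from the average.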
23 | 24 | Args: 25 | inputs: A tf.Tensor with shape (batch_dim, sequence_length, channels) 26 | 27 | Returns: A tf.Tensor with shape (batch_dim, channels) 28 | """ 29 | mask = tf.reduce_any(tf.not_equal(inputs, self._mask_value), axis=-1) # B x T 30 | mask = tf.cast(mask, inputs.dtype) 31 | mask = tf.expand_dims(mask, axis=2) 32 | inputs *= mask 33 | return tf.math.divide_no_nan(tf.reduce_sum(inputs, axis=1), tf.reduce_sum(mask, axis=1)) 34 | 35 | 36 | class SparseOneHot(snt.Module): 37 | """ Embedding module for sparse vocabulary. Supports unknown tokens that lay between 0 and max(vocab). """ 38 | 39 | def __init__(self, 40 | vocab: Sequence[int], 41 | dtype: tf.DType = tf.float32, 42 | name: Optional[Text] = None): 43 | """ Initializes the sparse one hot module. 44 | 45 | Args: 46 | vocab: A list of non-negative integer vocabulary tokens. 47 | embed_dim: Embedding dimension. 48 | name: An optional string name for the module. 49 | """ 50 | super().__init__(name) 51 | if not all(i >= 0 for i in vocab): 52 | raise ValueError("Negative vocabulary tokens are not supported.") 53 | self._dtype = dtype 54 | self._vocab_size = len(vocab) + 1 55 | vocab_lookup = np.zeros((np.max(vocab) + 1,), dtype=np.int32) 56 | # start lookup range from one to keep zeros for unknowns 57 | vocab_lookup[vocab] = np.arange(1, len(vocab) + 1, dtype=np.int32) 58 | self._vocab_lookup = tf.constant(vocab_lookup, dtype=tf.int32) 59 | 60 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 61 | """ Embeds the given inputs. 62 | 63 | Args: 64 | inputs: Input tensor with vocabulary indices. 65 | 66 | Returns: 67 | The resulting embedding tensor. 68 | """ 69 | inputs = tf.expand_dims(inputs, axis=-1) 70 | indices = tf.gather_nd(params=self._vocab_lookup, indices=tf.cast(inputs, dtype=tf.int32)) 71 | return tf.one_hot(indices=indices, depth=self._vocab_size, dtype=self._dtype) 72 | 73 | 74 | class SparseEmbed(snt.Module): 75 | """ Embedding module for sparse vocabulary. Supports unknown tokens that lay between 0 and max(vocab). """ 76 | 77 | def __init__(self, vocab: Union[Sequence[int], tf.Tensor], embed_dim: int, densify_gradients: bool = False, 78 | name: Optional[Text] = None): 79 | """ Initializes the sparse embedding module 80 | 81 | Args: 82 | vocab: A list of non-negative vocabulary integer tokens. 83 | embed_dim: Embedding dimension. 84 | name: An optional string name for the module. 85 | """ 86 | super().__init__(name) 87 | if not all(i >= 0 for i in vocab): 88 | raise ValueError("Negative vocabulary tokens are not supported.") 89 | vocab_lookup = np.zeros((np.max(vocab) + 1,), dtype=np.int32) 90 | # start lookup range from one to keep zeros for unknowns 91 | vocab_lookup[vocab] = np.arange(1, len(vocab) + 1, dtype=np.int32) 92 | self._vocab_lookup = tf.constant(vocab_lookup, dtype=tf.int32) 93 | self._embed = snt.Embed(vocab_size=len(vocab) + 1, embed_dim=embed_dim, 94 | densify_gradients=densify_gradients, dtype=tf.float32) 95 | 96 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 97 | """ Embeds the given inputs. 98 | 99 | Args: 100 | inputs: Input tensor with vocabulary indices. 101 | 102 | Returns: 103 | The resulting embedding tensor. 
104 | """ 105 | inputs = tf.expand_dims(inputs, axis=-1) 106 | indices = tf.gather_nd(params=self._vocab_lookup, indices=tf.cast(inputs, dtype=tf.int32)) 107 | return self._embed(indices) 108 | -------------------------------------------------------------------------------- /sc2_imitation_learning/dataset/sc2_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import glob 3 | import logging 4 | import multiprocessing 5 | import os 6 | from typing import Optional 7 | from typing import Set, AbstractSet 8 | 9 | import gin 10 | import tensorflow as tf 11 | from pysc2 import run_configs 12 | from pysc2.env.sc2_env import Race 13 | 14 | from sc2_imitation_learning.common.utils import load_json 15 | from sc2_imitation_learning.dataset.dataset import DataLoader, get_dataset_specs, load_dataset_from_hdf5 16 | from sc2_imitation_learning.environment.sc2_environment import SC2ActionSpace, SC2ObservationSpace, SC2Maps 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | SC2REPLAY_RACES = { 22 | 'Prot': Race.protoss, 23 | 'Terr': Race.terran, 24 | 'Zerg': Race.zerg 25 | } 26 | 27 | 28 | @gin.register 29 | class SC2DataLoader(DataLoader): 30 | def __init__(self, 31 | path: str, 32 | action_space: SC2ActionSpace, 33 | observation_space: SC2ObservationSpace, 34 | min_duration: float = 0., 35 | min_mmr: int = 0, 36 | min_apm: int = 0, 37 | observed_player_races: AbstractSet[Race] = frozenset((Race.protoss, Race.terran, Race.zerg)), 38 | opponent_player_races: AbstractSet[Race] = frozenset((Race.protoss, Race.terran, Race.zerg)), 39 | map_names: Optional[Set[str]] = None) -> None: 40 | super().__init__() 41 | assert os.path.isdir(path), f"Not a valid dataset path: '{path}'" 42 | 43 | sc2_maps = SC2Maps(run_configs.get().data_dir) 44 | 45 | def filter_replay_info(episode_info): 46 | replay_info = episode_info['replay_info'] 47 | observed_player_info = next( 48 | filter(lambda p: p['PlayerID'] == episode_info['observed_player_id'], replay_info['Players'])) 49 | if len(replay_info['Players']) > 1: 50 | opponent_player_info = next( 51 | filter(lambda p: p['PlayerID'] != episode_info['observed_player_id'], replay_info['Players'])) 52 | return (replay_info['Duration'] >= min_duration 53 | and observed_player_info.get('MMR', 0) >= min_mmr 54 | and observed_player_info['APM'] >= min_apm 55 | and SC2REPLAY_RACES[observed_player_info['AssignedRace']] in observed_player_races 56 | and (len(replay_info['Players']) == 1 or 57 | SC2REPLAY_RACES[opponent_player_info['AssignedRace']] in opponent_player_races) 58 | and (map_names is None or sc2_maps.normalize_map_name(replay_info['Title']) in map_names)) 59 | 60 | with multiprocessing.Pool() as p: 61 | meta_infos = p.map(load_json, glob.glob(os.path.join(path, '*.meta'))) 62 | 63 | logger.info(f"Found {len(meta_infos)} episodes.") 64 | 65 | meta_infos = [meta_info for meta_info in meta_infos if filter_replay_info(meta_info['episode_info'])] 66 | 67 | logger.info(f"Filtered {len(meta_infos)} episodes (Filter: " 68 | f"min_duration={min_duration}, " 69 | f"min_mmr={min_mmr}, " 70 | f"min_apm={min_apm}, " 71 | f"observed_player_races={list(observed_player_races)}, " 72 | f"opponent_player_races={list(opponent_player_races)}, " 73 | f"map_names={map_names if map_names is None else list(map_names)}).") 74 | 75 | self._file_paths = [os.path.join(path, meta_info['data_file']) for meta_info in meta_infos] 76 | self._num_samples = sum([meta_info['episode_length'] for meta_info in meta_infos]) 77 | 
self._num_episodes = len(self._file_paths) 78 | 79 | assert self._num_episodes > 0, "Empty dataset" 80 | 81 | logger.info(f"Loaded dataset with {self._num_episodes} episodes ({self._num_samples} samples).") 82 | 83 | self._specs = get_dataset_specs(action_space=action_space, observation_space=observation_space) 84 | 85 | @property 86 | def num_samples(self) -> int: 87 | return self._num_samples 88 | 89 | @property 90 | def num_episodes(self) -> int: 91 | return self._num_episodes 92 | 93 | def load(self, 94 | batch_size: int, 95 | sequence_length: int, 96 | offset_episodes: int = 0, 97 | num_episodes: int = 0, 98 | num_workers: int = os.cpu_count(), 99 | chunk_size: int = 4, 100 | seed: Optional[int] = None) -> tf.data.Dataset: 101 | if num_episodes > 0: 102 | file_paths = self._file_paths[offset_episodes:offset_episodes+num_episodes] 103 | else: 104 | file_paths = self._file_paths[offset_episodes:] 105 | return load_dataset_from_hdf5( 106 | file_paths, self._specs, batch_size, sequence_length, num_workers, chunk_size, seed) 107 | -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/common/spatial_decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence, Text, Optional 3 | 4 | import gin 5 | import sonnet as snt 6 | import tensorflow as tf 7 | 8 | from sc2_imitation_learning.common.conv import FiLMedResBlock, ResBlock, broadcast_nc_to_nhwc 9 | 10 | 11 | class SpatialDecoder(snt.Module, ABC): 12 | @abstractmethod 13 | def __call__(self, autoregressive_embedding: tf.Tensor, map_skip: Sequence[tf.Tensor]) -> tf.Tensor: 14 | pass 15 | 16 | 17 | @gin.register 18 | class ResSpatialDecoder(SpatialDecoder): 19 | def __init__(self, 20 | out_channels: int = gin.REQUIRED, 21 | num_blocks: int = gin.REQUIRED, 22 | name: Optional[Text] = None): 23 | super().__init__(name) 24 | self._out_channels = out_channels 25 | self._num_blocks = num_blocks 26 | self._input_transform = snt.Conv2D(output_channels=self._out_channels, kernel_shape=1, stride=1) 27 | self._layers = [] 28 | for i in range(self._num_blocks): 29 | self._layers.append(ResBlock(out_channels=self._out_channels, stride=1)) 30 | 31 | def __call__(self, autoregressive_embedding: tf.Tensor, map_skip: Sequence[tf.Tensor]) -> tf.Tensor: 32 | map_skip = list(reversed(map_skip)) 33 | inputs, map_skip = map_skip[0], map_skip[1:] 34 | 35 | broadcast_embedding = broadcast_nc_to_nhwc(autoregressive_embedding, inputs.shape[1], inputs.shape[2]) 36 | inputs = tf.concat([broadcast_embedding, inputs], axis=-1) 37 | inputs = tf.nn.relu(inputs) 38 | inputs = self._input_transform(inputs) 39 | inputs = tf.nn.relu(inputs) 40 | 41 | conv_out = inputs 42 | if len(map_skip) == 0: 43 | for layer in self._layers: 44 | conv_out = layer(conv_out) 45 | else: 46 | assert self._num_blocks == len(map_skip), \ 47 | f"'num_blocks' must be equal to the lengths of 'map_skip' but got: {self._num_blocks}, {len(map_skip)}" 48 | for layer, skip_connection in zip(self._layers, reversed(map_skip)): 49 | conv_out = layer(conv_out) 50 | conv_out += skip_connection 51 | 52 | conv_out = tf.nn.relu(conv_out) 53 | 54 | return conv_out 55 | 56 | 57 | @gin.register 58 | class FiLMedSpatialDecoder(SpatialDecoder): 59 | def __init__(self, 60 | out_channels: int = gin.REQUIRED, 61 | num_blocks: int = gin.REQUIRED, 62 | name: Optional[Text] = None): 63 | super().__init__(name) 64 | self._out_channels = out_channels 65 | 
self._num_blocks = num_blocks 66 | self._input_transform = snt.Conv2D(output_channels=self._out_channels, kernel_shape=1, stride=1) 67 | self._layers = [] 68 | for i in range(self._num_blocks): 69 | self._layers.append(FiLMedResBlock(out_channels=self._out_channels, stride=1)) 70 | 71 | @snt.once 72 | def _initialize(self, inputs: tf.Tensor, autoregressive_embedding: tf.Tensor): 73 | assert self._out_channels * self._num_blocks * 2 == autoregressive_embedding.shape[-1], \ 74 | f"output_channels={self._out_channels} and num_blocks={self._num_blocks} are not " \ 75 | f"compatible with autoregressive_embedding of size {autoregressive_embedding.shape[-1]}" 76 | 77 | self._reshape_non_spatial = snt.Reshape(output_shape=(inputs.shape[1], inputs.shape[2], -1)) 78 | 79 | def __call__(self, autoregressive_embedding: tf.Tensor, map_skip: Sequence[tf.Tensor]) -> tf.Tensor: 80 | map_skip = list(reversed(map_skip)) 81 | inputs, map_skip = map_skip[0], map_skip[1:] 82 | 83 | self._initialize(inputs, autoregressive_embedding) 84 | 85 | inputs = tf.concat([self._reshape_non_spatial(autoregressive_embedding), inputs], axis=-1) 86 | inputs = tf.nn.relu(inputs) 87 | inputs = self._input_transform(inputs) 88 | inputs = tf.nn.relu(inputs) 89 | 90 | gammas, betas = [], [] 91 | for i in range(0, autoregressive_embedding.shape[-1], self._out_channels * 2): 92 | gammas.append(autoregressive_embedding[:, i:i+self._out_channels]) 93 | betas.append(autoregressive_embedding[:, i+self._out_channels:i+2*self._out_channels]) 94 | 95 | assert self._num_blocks == len(map_skip) == len(gammas) == len(betas),\ 96 | f"'num_blocks' must be equal to the lengths of 'map_skip', 'gammas' and 'betas' but got: " \ 97 | f"{self._num_blocks}, {len(map_skip)}, {len(gammas)} and {len(betas)}" 98 | 99 | conv_out = inputs 100 | for layer, gamma, beta, skip_connection in zip(self._layers, gammas, betas, map_skip): 101 | conv_out = layer(conv_out, gamma, beta) 102 | conv_out += skip_connection 103 | conv_out = tf.nn.relu(conv_out) 104 | 105 | return conv_out 106 | -------------------------------------------------------------------------------- /configs/1v1/evaluate.gin: -------------------------------------------------------------------------------- 1 | include 'configs/1v1/environment.gin' 2 | 3 | evaluate.envs = [ 4 | @Automaton_tvt_very_easy/SC2SingleAgentEnv, 5 | @Automaton_tvt_easy/SC2SingleAgentEnv, 6 | @Automaton_tvt_medium/SC2SingleAgentEnv, 7 | @Automaton_tvt_hard/SC2SingleAgentEnv, 8 | 9 | @Blueshift_tvt_very_easy/SC2SingleAgentEnv, 10 | @Blueshift_tvt_easy/SC2SingleAgentEnv, 11 | @Blueshift_tvt_medium/SC2SingleAgentEnv, 12 | @Blueshift_tvt_hard/SC2SingleAgentEnv, 13 | 14 | @CeruleanFall_tvt_very_easy/SC2SingleAgentEnv, 15 | @CeruleanFall_tvt_easy/SC2SingleAgentEnv, 16 | @CeruleanFall_tvt_medium/SC2SingleAgentEnv, 17 | @CeruleanFall_tvt_hard/SC2SingleAgentEnv, 18 | 19 | @KairosJunction_tvt_very_easy/SC2SingleAgentEnv, 20 | @KairosJunction_tvt_easy/SC2SingleAgentEnv, 21 | @KairosJunction_tvt_medium/SC2SingleAgentEnv, 22 | @KairosJunction_tvt_hard/SC2SingleAgentEnv, 23 | 24 | @ParaSite_tvt_very_easy/SC2SingleAgentEnv, 25 | @ParaSite_tvt_easy/SC2SingleAgentEnv, 26 | @ParaSite_tvt_medium/SC2SingleAgentEnv, 27 | @ParaSite_tvt_hard/SC2SingleAgentEnv, 28 | 29 | @PortAleksander_tvt_very_easy/SC2SingleAgentEnv, 30 | @PortAleksander_tvt_easy/SC2SingleAgentEnv, 31 | @PortAleksander_tvt_medium/SC2SingleAgentEnv, 32 | @PortAleksander_tvt_hard/SC2SingleAgentEnv, 33 | 34 | @Stasis_tvt_very_easy/SC2SingleAgentEnv, 35 | 
@Stasis_tvt_easy/SC2SingleAgentEnv, 36 | @Stasis_tvt_medium/SC2SingleAgentEnv, 37 | @Stasis_tvt_hard/SC2SingleAgentEnv 38 | ] 39 | evaluate.num_episodes = 100 40 | evaluate.random_seed = 42 41 | evaluate.num_evaluators = 20 42 | 43 | SC2SingleAgentEnv.bot_race = 'terran' 44 | SC2SingleAgentEnv.game_steps_per_episode = 0 45 | SC2SingleAgentEnv.save_replay_episodes = 1 46 | 47 | Automaton_tvt_very_easy/SC2SingleAgentEnv.map_name = 'Automaton' 48 | Automaton_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 49 | Automaton_tvt_easy/SC2SingleAgentEnv.map_name = 'Automaton' 50 | Automaton_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 51 | Automaton_tvt_medium/SC2SingleAgentEnv.map_name = 'Automaton' 52 | Automaton_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 53 | Automaton_tvt_hard/SC2SingleAgentEnv.map_name = 'Automaton' 54 | Automaton_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 55 | 56 | Blueshift_tvt_very_easy/SC2SingleAgentEnv.map_name = 'Blueshift' 57 | Blueshift_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 58 | Blueshift_tvt_easy/SC2SingleAgentEnv.map_name = 'Blueshift' 59 | Blueshift_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 60 | Blueshift_tvt_medium/SC2SingleAgentEnv.map_name = 'Blueshift' 61 | Blueshift_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 62 | Blueshift_tvt_hard/SC2SingleAgentEnv.map_name = 'Blueshift' 63 | Blueshift_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 64 | 65 | CeruleanFall_tvt_very_easy/SC2SingleAgentEnv.map_name = 'CeruleanFall' 66 | CeruleanFall_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 67 | CeruleanFall_tvt_easy/SC2SingleAgentEnv.map_name = 'CeruleanFall' 68 | CeruleanFall_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 69 | CeruleanFall_tvt_medium/SC2SingleAgentEnv.map_name = 'CeruleanFall' 70 | CeruleanFall_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 71 | CeruleanFall_tvt_hard/SC2SingleAgentEnv.map_name = 'CeruleanFall' 72 | CeruleanFall_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 73 | 74 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 75 | KairosJunction_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 76 | KairosJunction_tvt_easy/SC2SingleAgentEnv.map_name = 'KairosJunction' 77 | KairosJunction_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 78 | KairosJunction_tvt_medium/SC2SingleAgentEnv.map_name = 'KairosJunction' 79 | KairosJunction_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 80 | KairosJunction_tvt_hard/SC2SingleAgentEnv.map_name = 'KairosJunction' 81 | KairosJunction_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 82 | 83 | ParaSite_tvt_very_easy/SC2SingleAgentEnv.map_name = 'ParaSite' 84 | ParaSite_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 85 | ParaSite_tvt_easy/SC2SingleAgentEnv.map_name = 'ParaSite' 86 | ParaSite_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 87 | ParaSite_tvt_medium/SC2SingleAgentEnv.map_name = 'ParaSite' 88 | ParaSite_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 89 | ParaSite_tvt_hard/SC2SingleAgentEnv.map_name = 'ParaSite' 90 | ParaSite_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 91 | 92 | PortAleksander_tvt_very_easy/SC2SingleAgentEnv.map_name = 'PortAleksander' 93 | PortAleksander_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 94 | PortAleksander_tvt_easy/SC2SingleAgentEnv.map_name = 'PortAleksander' 95 | PortAleksander_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 96 | 
PortAleksander_tvt_medium/SC2SingleAgentEnv.map_name = 'PortAleksander' 97 | PortAleksander_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 98 | PortAleksander_tvt_hard/SC2SingleAgentEnv.map_name = 'PortAleksander' 99 | PortAleksander_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 100 | 101 | Stasis_tvt_very_easy/SC2SingleAgentEnv.map_name = 'Stasis' 102 | Stasis_tvt_very_easy/SC2SingleAgentEnv.bot_difficulty = 'very_easy' 103 | Stasis_tvt_easy/SC2SingleAgentEnv.map_name = 'Stasis' 104 | Stasis_tvt_easy/SC2SingleAgentEnv.bot_difficulty = 'easy' 105 | Stasis_tvt_medium/SC2SingleAgentEnv.map_name = 'Stasis' 106 | Stasis_tvt_medium/SC2SingleAgentEnv.bot_difficulty = 'medium' 107 | Stasis_tvt_hard/SC2SingleAgentEnv.map_name = 'Stasis' 108 | Stasis_tvt_hard/SC2SingleAgentEnv.bot_difficulty = 'hard' 109 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/evaluator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import sys 4 | import traceback 5 | from collections import defaultdict 6 | from queue import Queue 7 | from typing import Type, Optional, Callable, NamedTuple, List, Dict 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from absl.flags import FLAGS 12 | 13 | from sc2_imitation_learning.agents import Agent 14 | from sc2_imitation_learning.common.utils import make_dummy_action 15 | from sc2_imitation_learning.environment.sc2_environment import SC2SingleAgentEnv 16 | 17 | logger = logging.getLogger(__file__) 18 | 19 | 20 | class EvalEpisode(NamedTuple): 21 | matchup: str 22 | stats: 'EpisodeStats' 23 | 24 | 25 | class EpisodeStats(NamedTuple): 26 | num_frames: int 27 | num_steps: int 28 | reward: float 29 | 30 | 31 | class Evaluator(multiprocessing.Process): 32 | 33 | def __init__(self, 34 | agent_fn: Callable[[], Agent], 35 | env: SC2SingleAgentEnv, 36 | queue_in: Queue, 37 | queue_out: Queue, 38 | device: Optional = None) -> None: 39 | super().__init__() 40 | self._agent_fn = agent_fn 41 | self._env = env 42 | self._queue_in = queue_in 43 | self._queue_out = queue_out 44 | self._device = device 45 | 46 | def run(self) -> None: 47 | FLAGS(sys.argv) 48 | 49 | if self._device is not None: 50 | tf.config.set_visible_devices(self._device, 'GPU') 51 | 52 | agent = self._agent_fn() 53 | agent_state = agent.initial_state(1) 54 | 55 | seed = self._queue_in.get() 56 | while seed is not None: 57 | env: SC2SingleAgentEnv = self._env 58 | env.restart(seed) 59 | 60 | try: 61 | total_reward, frame, step = 0., 0, 0 62 | reward, done, observation = 0., False, env.reset() 63 | action = make_dummy_action(env.action_space, num_batch_dims=1) 64 | while not done: 65 | env_outputs = ( 66 | tf.constant([reward], dtype=tf.float32), 67 | tf.constant([step == 0], dtype=tf.bool), 68 | tf.nest.map_structure(lambda o: tf.constant([o], dtype=tf.dtypes.as_dtype(o.dtype)), observation)) 69 | agent_output, agent_state = agent(action, env_outputs, agent_state) 70 | action = tf.nest.map_structure(lambda t: t.numpy(), agent_output.actions) 71 | reward, _, done, observation = env.step(action) 72 | total_reward += reward 73 | frame += action['step_mul'] + 1 74 | step += 1 75 | 76 | self._queue_out.put( 77 | EvalEpisode(matchup=env.level_name, 78 | stats=EpisodeStats(num_frames=frame, num_steps=step, reward=total_reward))) 79 | 80 | except Exception as e: 81 | logger.error(f"Failed to evaluate episode (stacktrace below). 
Restart env.") 82 | traceback.print_exc() 83 | continue 84 | finally: 85 | env.close() 86 | 87 | seed = self._queue_in.get() 88 | 89 | 90 | def evaluate_on_single_env(agent_fn: Callable[[], Agent], 91 | env_cls: Type[SC2SingleAgentEnv], 92 | num_episodes: int, 93 | num_evaluators: int, 94 | random_seed: int = None, 95 | replay_dir: Optional[str] = None, 96 | available_gpus: Optional[List] = None) -> Dict[str, EpisodeStats]: 97 | logger.info(f"Start evaluation for {num_episodes} episodes using {num_evaluators} evaluator threads.") 98 | 99 | queue_in = multiprocessing.Queue() 100 | queue_out = multiprocessing.Queue() 101 | 102 | rngesus: np.random.Generator = np.random.default_rng(seed=random_seed) 103 | random_seeds = rngesus.integers(low=0, high=np.iinfo(np.int32).max, size=num_episodes) 104 | for random_seed in random_seeds: 105 | queue_in.put(int(random_seed)) 106 | 107 | for _ in range(num_evaluators): 108 | queue_in.put(None) 109 | 110 | env = env_cls(replay_dir=replay_dir) 111 | 112 | evaluators = [ 113 | Evaluator(agent_fn=agent_fn, env=env, queue_in=queue_in, queue_out=queue_out, 114 | device=None if available_gpus is None else available_gpus[i % len(available_gpus)]) 115 | for i in range(num_evaluators)] 116 | 117 | for evaluator in evaluators: 118 | evaluator.start() 119 | 120 | logger.info("All evaluator threads started.") 121 | 122 | all_episode_stats = defaultdict(list) 123 | total_episodes = 0 124 | while total_episodes < num_episodes: 125 | eval_episode: EvalEpisode = queue_out.get() 126 | all_episode_stats[eval_episode.matchup].append(eval_episode.stats) 127 | total_episodes += 1 128 | logger.info(f"Episode {total_episodes} completed: " 129 | f"matchup={eval_episode.matchup}, " 130 | f"reward={eval_episode.stats.reward:.2f}, " 131 | f"frames={eval_episode.stats.num_frames}, " 132 | f"matchup_rewards_mean={np.mean([s.reward for s in all_episode_stats[eval_episode.matchup]]):.2f}") 133 | 134 | logger.info(f"Evaluation completed:\n\t" + "\n\t".join([ 135 | f"matchup={matchup}, " 136 | f"mean_reward={np.mean([s.reward for s in stats]):.2f} (std={np.std([s.reward for s in stats]):.2f})" 137 | for matchup, stats in all_episode_stats.items() 138 | ])) 139 | 140 | for evaluator in evaluators: 141 | evaluator.join() 142 | 143 | logger.info("All evaluator threads stopped.") 144 | 145 | return dict(all_episode_stats) 146 | 147 | 148 | def evaluate_on_multiple_envs(agent_fn: Callable[[], Agent], 149 | envs: List[Type[SC2SingleAgentEnv]], 150 | num_episodes: int, 151 | num_evaluators: int, 152 | random_seed: int = None, 153 | replay_dir: Optional[str] = None, 154 | available_gpus: Optional[List] = None) -> Dict[str, EpisodeStats]: 155 | all_episode_stats = defaultdict(list) 156 | 157 | for env_cls in envs: 158 | eval_episodes = evaluate_on_single_env( 159 | agent_fn=agent_fn, 160 | env_cls=env_cls, 161 | num_episodes=num_episodes, 162 | num_evaluators=num_evaluators, 163 | random_seed=random_seed, 164 | replay_dir=replay_dir, 165 | available_gpus=available_gpus) 166 | 167 | for matchup, stats in eval_episodes.items(): 168 | all_episode_stats[matchup].extend(stats) 169 | 170 | return dict(all_episode_stats) 171 | -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/sc2_feature_layer_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Text, Union, Optional, Dict 2 | 3 | import gin 4 | import sonnet as snt 5 | import tensorflow as tf 6 | 7 | from 
sc2_imitation_learning.agents import Agent, AgentOutput 8 | from sc2_imitation_learning.agents.common.policy_head import AutoregressivePolicyHead, PolicyContextFeatures 9 | from sc2_imitation_learning.agents.common.scalar_encoder import ScalarEncoder 10 | from sc2_imitation_learning.agents.common.spatial_encoder import SpatialEncoder 11 | from sc2_imitation_learning.agents.common.unit_group_encoder import UnitGroupEncoder 12 | 13 | 14 | class SC2FeatureLayerAgentHead(snt.Module): 15 | def __init__(self, 16 | autoregressive_embed_dim: int, 17 | policy_heads: List[AutoregressivePolicyHead], 18 | name: Optional[Text] = None): 19 | super().__init__(name) 20 | self._core_outputs_embed = snt.Linear(autoregressive_embed_dim) # todo glu! 21 | self._policy_heads = policy_heads 22 | 23 | def __call__(self, 24 | core_outputs: tf.Tensor, 25 | scalar_context: tf.Tensor, 26 | unit_groups: Dict[str, tf.Tensor], 27 | available_actions: tf.Tensor, 28 | screen_skip: List[tf.Tensor], 29 | minimap_skip: List[tf.Tensor], 30 | teacher_actions: Optional[Dict[str, tf.Tensor]] = None) -> AgentOutput: 31 | 32 | context = PolicyContextFeatures( 33 | scalar_context=scalar_context, 34 | unit_groups=unit_groups, 35 | available_actions=available_actions, 36 | map_skip={ 37 | 'screen': screen_skip, 38 | 'minimap': minimap_skip 39 | }) 40 | 41 | autoregressive_embedding = self._core_outputs_embed(core_outputs) 42 | 43 | action, logits = {}, {} 44 | for policy_head in self._policy_heads: 45 | poliy_outputs, autoregressive_embedding = policy_head( 46 | autoregressive_embedding=autoregressive_embedding, 47 | context=context, 48 | partial_action=action if teacher_actions is None else teacher_actions, 49 | teacher_action=None if teacher_actions is None else teacher_actions[policy_head.action_name]) 50 | action[policy_head.action_name] = poliy_outputs.action 51 | logits[policy_head.action_name] = poliy_outputs.logits 52 | 53 | return AgentOutput(logits=logits, actions=action, values=None) 54 | 55 | 56 | @gin.register 57 | class SC2FeatureLayerAgent(Agent): 58 | """ An agent that operates on scalar features and spatial feature layers as sensory inputs. """ 59 | 60 | def __init__(self, 61 | scalar_encoder: ScalarEncoder = gin.REQUIRED, 62 | unit_group_encoder: Optional[UnitGroupEncoder] = None, 63 | screen_encoder: SpatialEncoder = gin.REQUIRED, 64 | minimap_encoder: SpatialEncoder = gin.REQUIRED, 65 | core: Union[tf.keras.layers.LSTMCell, tf.keras.layers.StackedRNNCells] = gin.REQUIRED, 66 | autoregressive_embed_dim: int = gin.REQUIRED, 67 | policy_heads: List[AutoregressivePolicyHead] = gin.REQUIRED, 68 | ): 69 | """ Constructs a feature layer agent. 70 | 71 | Args: 72 | scalar_encoder: A scalar encoder module. 73 | unit_group_encoder: A unit group encoder module. 74 | screen_encoder: A screen encoder module. 75 | minimap_encoder: A minimap encoder module. 76 | core: An LSTM core module. Both single layer (LSTMCell) and deep (StackedRNNCells) LSTMs are supported. 77 | autoregressive_embed_dim: The size of the autoregressive embedding vector used during action decoding. 78 | policy_heads: A list of autoregressive policy heads. IMPORTANT: ordering matters! The actions are decoded 79 | exactly in the provided order. 
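(Editor's illustration, with hypothetical head names: a function-identifier head would typically come first, and later heads such as a queued flag or screen coordinates are conditioned on the partially decoded action through the shared autoregressive embedding.)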
80 | """ 81 | super().__init__() 82 | self._scalar_encoder = snt.BatchApply(scalar_encoder) 83 | if unit_group_encoder is not None: 84 | self._unit_group_encoder = snt.BatchApply(unit_group_encoder) 85 | else: 86 | self._unit_group_encoder = None 87 | self._screen_encoder = snt.BatchApply(screen_encoder) 88 | self._minimap_encoder = snt.BatchApply(minimap_encoder) 89 | self._core = core 90 | self._head = snt.BatchApply(SC2FeatureLayerAgentHead(autoregressive_embed_dim, policy_heads)) 91 | 92 | def initial_state(self, batch_size: int, **kwargs): 93 | if isinstance(self._core, snt.RNNCore): 94 | return self._core.initial_state(batch_size=batch_size, **kwargs) 95 | else: 96 | return self._core.get_initial_state(batch_size=batch_size, dtype=tf.float32) 97 | 98 | def _unroll(self, prev_actions, env_outputs, core_state, teacher_actions=None) -> Tuple[AgentOutput, Tuple]: 99 | rewards, done, observations = env_outputs 100 | 101 | embedded_scalar, scalar_context = self._scalar_encoder( 102 | features=observations['scalar_features'], 103 | prev_actions=prev_actions) 104 | 105 | if self._unit_group_encoder is not None: 106 | _, unit_groups = self._unit_group_encoder(features=observations['scalar_features']) 107 | else: 108 | unit_groups = dict() 109 | 110 | embedded_screen, screen_skip = self._screen_encoder(features=observations['screen_features']) 111 | 112 | embedded_minimap, minimap_skip = self._minimap_encoder(features=observations['minimap_features']) 113 | 114 | embedded_observations = tf.concat(values=[embedded_scalar, embedded_screen, embedded_minimap], axis=-1) 115 | 116 | initial_core_state = self.initial_state(batch_size=tf.shape(done)[1]) 117 | 118 | core_output_list = [] 119 | for input_, d in zip(tf.unstack(embedded_observations), tf.unstack(done)): 120 | # If the episode ended, the core state should be reset before the next. 
121 | core_state = tf.nest.map_structure( 122 | lambda x, y, d=d: tf.where(tf.reshape(d, [tf.shape(d)[0]] + [1] * (x.shape.rank - 1)), x, y), 123 | initial_core_state, 124 | core_state) 125 | core_output, core_state = self._core(input_, core_state) 126 | core_output_list.append(core_output) 127 | 128 | core_outputs = tf.stack(core_output_list) 129 | 130 | agent_outputs: AgentOutput = self._head( 131 | core_outputs=core_outputs, 132 | scalar_context=scalar_context, 133 | unit_groups=unit_groups, 134 | available_actions=observations['scalar_features']['available_actions'], 135 | screen_skip=screen_skip, 136 | minimap_skip=minimap_skip, 137 | teacher_actions=teacher_actions) 138 | 139 | return agent_outputs, core_state 140 | -------------------------------------------------------------------------------- /scripts/behaviour_cloning.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import datetime 3 | import functools 4 | import logging 5 | import multiprocessing 6 | import os 7 | from typing import Type, Callable, List, Optional, Union, Dict 8 | 9 | import gin 10 | import numpy as np 11 | import tensorflow as tf 12 | from absl import app, flags 13 | 14 | from sc2_imitation_learning.agents import Agent 15 | from sc2_imitation_learning.behaviour_cloning.learner import learner_loop 16 | from sc2_imitation_learning.common.evaluator import evaluate_on_multiple_envs 17 | from sc2_imitation_learning.common.utils import gin_register_external_configurables, gin_config_str_to_dict 18 | from sc2_imitation_learning.dataset.dataset import DataLoader 19 | from sc2_imitation_learning.environment.environment import ObservationSpace, ActionSpace 20 | from sc2_imitation_learning.environment.sc2_environment import SC2SingleAgentEnv 21 | 22 | logging.basicConfig(level=logging.INFO) 23 | logger = logging.getLogger(__file__) 24 | 25 | flags.DEFINE_string('logdir', f"./experiments/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}", 26 | 'Experiment logging directory.') 27 | flags.DEFINE_multi_string('gin_file', ['./configs/1v1/behaviour_cloning.gin'], 'List of paths to Gin config files.') 28 | flags.DEFINE_multi_string('gin_param', None, 'List of Gin parameter bindings.') 29 | 30 | # logger config 31 | flags.DEFINE_bool('wandb_logging_enabled', False, 'If wandb logging should be enabled.') 32 | flags.DEFINE_string('wandb_project', 'sc2-il', 'Name of the wandb project.') 33 | flags.DEFINE_string('wandb_entity', None, 'Name of the wandb entity.') 34 | flags.DEFINE_list('wandb_tags', ['behaviour_cloning'], 'List of wandb tags.') 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | gin_register_external_configurables() 39 | 40 | 41 | def agent_fn(saved_model_path, *args, **kwargs) -> Agent: 42 | return tf.saved_model.load(saved_model_path) 43 | 44 | 45 | @gin.configurable 46 | def evaluate(saved_model_path: str, 47 | envs: List[Type[SC2SingleAgentEnv]] = gin.REQUIRED, 48 | num_episodes: int = gin.REQUIRED, 49 | random_seed: int = gin.REQUIRED, 50 | num_evaluators: int = gin.REQUIRED) -> Dict[str, Union[int, float, np.ndarray]]: 51 | with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: 52 | available_gpus = executor.submit(tf.config.list_physical_devices, 'GPU').result() 53 | 54 | episode_stats = evaluate_on_multiple_envs( 55 | agent_fn=functools.partial(agent_fn, saved_model_path), 56 | envs=envs, 57 | num_episodes=num_episodes, 58 | num_evaluators=num_evaluators, 59 | random_seed=random_seed, 60 | 
replay_dir=os.path.abspath(os.path.join(FLAGS.logdir, 'replays')), 61 | available_gpus=available_gpus) 62 | 63 | return { 64 | matchup: { 65 | 'num_episodes': len(stats), 66 | 'episode_frames/mean': np.mean([s.num_frames for s in stats]), 67 | 'episode_steps/mean': np.mean([s.num_steps for s in stats]), 68 | 'episode_reward/mean': np.mean([s.reward for s in stats]), 69 | 'episode_frames/min': np.min([s.num_frames for s in stats]), 70 | 'episode_steps/min': np.min([s.num_steps for s in stats]), 71 | 'episode_reward/min': np.min([s.reward for s in stats]), 72 | 'episode_frames/max': np.max([s.num_frames for s in stats]), 73 | 'episode_steps/max': np.max([s.num_steps for s in stats]), 74 | 'episode_reward/max': np.max([s.reward for s in stats]), 75 | } 76 | for matchup, stats in episode_stats.items() 77 | } 78 | 79 | 80 | @gin.configurable 81 | def train(action_space: ActionSpace = gin.REQUIRED, 82 | observation_space: ObservationSpace = gin.REQUIRED, 83 | data_loader: DataLoader = gin.REQUIRED, 84 | batch_size: int = gin.REQUIRED, 85 | sequence_length: int = gin.REQUIRED, 86 | total_train_samples: int = gin.REQUIRED, 87 | l2_regularization: float = gin.REQUIRED, 88 | update_frequency: int = gin.REQUIRED, 89 | agent_fn: Callable[[], Agent] = gin.REQUIRED, 90 | optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer] = gin.REQUIRED, 91 | eval_interval: int = gin.REQUIRED, 92 | max_to_keep_checkpoints: Optional[int] = None, 93 | save_checkpoint_interval: float = 1800, 94 | tensorboard_log_interval: float = 10, 95 | console_log_interval: float = 60): 96 | 97 | def dataset_fn(ctx: tf.distribute.InputContext) -> tf.data.Dataset: 98 | num_episodes = data_loader.num_episodes // ctx.num_input_pipelines 99 | start_index = ctx.input_pipeline_id * num_episodes 100 | dataset = data_loader.load( 101 | batch_size=ctx.get_per_replica_batch_size(global_batch_size=batch_size), 102 | sequence_length=sequence_length, 103 | offset_episodes=start_index, 104 | num_episodes=num_episodes, 105 | num_workers=min(num_episodes, os.cpu_count())) 106 | return dataset.prefetch(buffer_size=ctx.num_replicas_in_sync) 107 | 108 | training_strategy = tf.distribute.MirroredStrategy([]) 109 | 110 | learner_loop(log_dir=FLAGS.logdir, 111 | observation_space=observation_space, 112 | action_space=action_space, 113 | training_strategy=training_strategy, 114 | dataset_fn=dataset_fn, 115 | agent_fn=agent_fn, 116 | optimizer_fn=optimizer_fn, 117 | total_train_samples=total_train_samples, 118 | batch_size=batch_size, 119 | sequence_size=sequence_length, 120 | l2_regularization=l2_regularization, 121 | update_frequency=update_frequency, 122 | num_episodes=data_loader.num_episodes, 123 | eval_fn=evaluate, 124 | eval_interval=eval_interval, 125 | max_to_keep_checkpoints=max_to_keep_checkpoints, 126 | save_checkpoint_interval=save_checkpoint_interval, 127 | tensorboard_log_interval=tensorboard_log_interval, 128 | console_log_interval=console_log_interval) 129 | 130 | 131 | def main(argv): 132 | gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) 133 | 134 | os.makedirs(FLAGS.logdir, exist_ok=True) 135 | 136 | gin_config_str = gin.config_str(max_line_length=120) 137 | 138 | print("Loaded configuration:") 139 | print(gin_config_str) 140 | 141 | with open(os.path.join(FLAGS.logdir, 'config.gin'), mode='w') as f: 142 | f.write(gin_config_str) 143 | 144 | if FLAGS.wandb_logging_enabled: 145 | import wandb 146 | experiment_name = os.path.basename(FLAGS.logdir.rstrip("/")) 147 | job_type = 'train' 148 | wandb.init( 149 | 
id=f"{experiment_name}-{job_type}", 150 | name=f"{experiment_name}-{job_type}", 151 | group=experiment_name, 152 | job_type=job_type, 153 | project=FLAGS.wandb_project, 154 | entity=FLAGS.wandb_entity, 155 | tags=FLAGS.wandb_tags, 156 | resume="allow", 157 | config=gin_config_str_to_dict(gin_config_str)) 158 | wandb.tensorboard.patch(save=False, tensorboardX=False) 159 | 160 | train() 161 | 162 | 163 | if __name__ == '__main__': 164 | os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' 165 | multiprocessing.set_start_method('spawn') 166 | app.run(main) 167 | -------------------------------------------------------------------------------- /sc2_imitation_learning/agents/common/spatial_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence, Union, Text, Optional, List, Dict, NamedTuple, Type 3 | 4 | import gin 5 | import sonnet as snt 6 | import tensorflow as tf 7 | 8 | from sc2_imitation_learning.agents.common.feature_encoder import FeatureEncoder 9 | from sc2_imitation_learning.common.conv import ResBlock, ImpalaResBlock, ConvNet2D 10 | 11 | 12 | class SpatialEncoderOutputs(NamedTuple): 13 | embedded_spatial: tf.Tensor 14 | map_skip: List[tf.Tensor] 15 | 16 | 17 | class SpatialEncoder(snt.Module, ABC): 18 | @abstractmethod 19 | def __call__(self, features: Dict[str, tf.Tensor]) -> SpatialEncoderOutputs: 20 | """ Applies the spatial encoder transformation. 21 | 22 | Args: 23 | features: A Dict of spatial feature layers (tf.Tensor). 24 | 25 | Returns: 26 | A NamedTuple: 27 | - embedded_spatial: A 1D tensor that represents the spatial embedding. 28 | - map_skip: A List of 2D tensors that represent intermediate spatial representations before the 1D 29 | bottleneck (can be used by spatial policies that rely on spatial information). 30 | """ 31 | pass 32 | 33 | 34 | @gin.register 35 | class ImpalaCNNSpatialEncoder(SpatialEncoder): 36 | """ Spatial encoder based on the residual CNN architecture described in `IMPALA: Scalable Distributed Deep-RL 37 | with Importance Weighted Actor-Learner Architectures` (https://arxiv.org/abs/1802.01561). 38 | 39 | This architecture consists of a number of `ImpalaResModule` that transform spatial feature layers into a spatial 40 | representation. The output tensor of the last `ImpalaResModule` is stored in `map_skip` and embedded into a 1D 41 | tensor by a linear layer followed by ReLU activations.""" 42 | 43 | def __init__(self, 44 | feature_layer_encoders: Dict[str, Type[FeatureEncoder]] = gin.REQUIRED, 45 | input_projection_dim: int = gin.REQUIRED, 46 | num_blocks: Sequence[int] = gin.REQUIRED, 47 | output_channels: Sequence[int] = gin.REQUIRED, 48 | max_pool_padding: str = 'SAME', 49 | spatial_embedding_size: int = gin.REQUIRED, 50 | name: Optional[Text] = None): 51 | """ Constructs the Impala CNN module 52 | 53 | Args: 54 | feature_layer_encoders: A Dict of `FeatureEncoder`s that specify the feature layers and their encoding. 55 | input_projection_dim: A scalar that defines the channel dim of the input projection. 56 | num_blocks: A sequence of scalars that define the number residual conv blocks in each `ImpalaResModule`. 57 | output_channels: A sequence of scalars that defines the channel dimension in each `ImpalaResModule`. 58 | max_pool_padding: The padding applied to the inputs of the max-pooling layer (either 'SAME' or 'Valid'). 59 | spatial_embedding_size: The size of the 1D embedding. 60 | name: An optional module name. 
61 | """ 62 | super().__init__(name) 63 | self._feature_layer_encoders = {key: enc() for key, enc in feature_layer_encoders.items()} 64 | self._input_projection = snt.Conv2D( 65 | output_channels=input_projection_dim, kernel_shape=1, stride=1, padding='SAME') 66 | self._cnn = snt.Sequential([ 67 | ImpalaResBlock(num_blocks=b, out_channels=c, max_pool_padding=max_pool_padding, name=f'impala_block_{i}') 68 | for i, (b, c) in enumerate(zip(num_blocks, output_channels))]) 69 | self._flatten = snt.Flatten() 70 | self._final_linear = snt.Linear(output_size=spatial_embedding_size) 71 | 72 | def __call__(self, features: Dict[str, tf.Tensor]) -> SpatialEncoderOutputs: 73 | embedded_feature_layers = {key: enc(features[key]) for key, enc in self._feature_layer_encoders.items()} 74 | embedded_feature_layers = tf.concat(tf.nest.flatten(embedded_feature_layers), axis=-1) 75 | embedded_feature_layers = self._input_projection(embedded_feature_layers) 76 | 77 | conv_out = self._cnn(embedded_feature_layers) 78 | conv_out = tf.nn.relu(conv_out) 79 | 80 | embedded_spatial = self._flatten(conv_out) 81 | embedded_spatial = self._final_linear(embedded_spatial) 82 | embedded_spatial = tf.nn.relu(embedded_spatial) 83 | 84 | return SpatialEncoderOutputs(embedded_spatial, [conv_out]) 85 | 86 | 87 | @gin.register 88 | class AlphaStarSpatialEncoder(SpatialEncoder): 89 | """ Spatial encoder based on the spatial encoder architecture described in `Grandmaster level in StarCraft II using 90 | multi-agent reinforcement learning` (https://www.nature.com/articles/s41586-019-1724-z).""" 91 | 92 | def __init__(self, 93 | feature_layer_encoders: Dict[str, Type[FeatureEncoder]] = gin.REQUIRED, 94 | input_projection_dim: int = gin.REQUIRED, 95 | downscale_conv_net: ConvNet2D = gin.REQUIRED, 96 | res_out_channels: int = gin.REQUIRED, 97 | res_num_blocks: int = gin.REQUIRED, 98 | res_stride: Union[int, Sequence[int]] = gin.REQUIRED, 99 | spatial_embedding_size: int = gin.REQUIRED, 100 | name: Optional[Text] = None): 101 | """ Constructs the AlphaStar spatial encoder module. 102 | 103 | Args: 104 | feature_layer_encoders: A Dict of `FeatureEncoder`s that specify the feature layers and their encoding. 105 | input_projection_dim: A scalar that defines the channel dim of the input projection. 106 | downscale_conv_net: A ConvNet2D that initially downscales spatial inputs. 107 | res_out_channels: A scalar that defines the channel dimension of the `ResBlock`s that are applied after the 108 | downscale convolutions. 109 | res_num_blocks: A scalar that defines the number of the `ResBlock`s that are applied after the downscale 110 | convolutions. 111 | res_stride: A kernel stride (either scalar or sequence of scalars) that define the stride of the 112 | `ResBlock`s that are applied after the downscale convolutions. 113 | spatial_embedding_size: The size of the 1D embedding. 114 | name: An optional module name. 
115 | """ 116 | super().__init__(name) 117 | self._feature_layer_encoders = {key: enc() for key, enc in feature_layer_encoders.items()} 118 | self._input_projection = snt.Conv2D( 119 | output_channels=input_projection_dim, kernel_shape=1, stride=1, padding='SAME') 120 | self._downscale = downscale_conv_net 121 | self._res_blocks = [ 122 | ResBlock(out_channels=res_out_channels, stride=res_stride) for _ in range(res_num_blocks)] 123 | self._spatial_embed = snt.Sequential([ 124 | snt.Flatten(), 125 | snt.Linear(output_size=spatial_embedding_size), 126 | tf.nn.relu 127 | ]) 128 | 129 | def __call__(self, features: Dict[str, tf.Tensor]) -> SpatialEncoderOutputs: 130 | embedded_feature_layers = {key: enc(features[key]) for key, enc in self._feature_layer_encoders.items()} 131 | embedded_feature_layers = tf.concat(tf.nest.flatten(embedded_feature_layers), axis=-1) 132 | embedded_feature_layers = self._input_projection(embedded_feature_layers) 133 | 134 | conv_out = self._downscale(embedded_feature_layers) 135 | map_skip = [conv_out] 136 | 137 | for layer in self._res_blocks: 138 | conv_out = layer(conv_out) 139 | map_skip.append(conv_out) 140 | conv_out = tf.nn.relu(conv_out) 141 | 142 | embedded_spatial = self._spatial_embed(conv_out) 143 | 144 | return SpatialEncoderOutputs(embedded_spatial, map_skip) 145 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/replay_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Tuple, Iterator, Dict, Text, Optional 4 | 5 | import mpyq 6 | import six 7 | from absl import logging 8 | from pysc2 import run_configs, maps 9 | from pysc2.env import sc2_env, environment 10 | from pysc2.env.sc2_env import possible_results 11 | from pysc2.lib import features 12 | from pysc2.lib.actions import FunctionCall, FUNCTIONS 13 | from s2clientprotocol import sc2api_pb2 as sc_pb 14 | 15 | from sc2_imitation_learning.environment.sc2_environment import SC2ObservationSpace, SC2ActionSpace, SC2InterfaceConfig 16 | 17 | 18 | def get_replay_info(replay_path: Text) -> Dict: 19 | with open(replay_path, 'rb') as f: 20 | archive = mpyq.MPQArchive(f).extract() 21 | metadata = json.loads(archive[b"replay.gamemetadata.json"].decode("utf-8")) 22 | return metadata 23 | 24 | 25 | def get_game_version(replay_data: bytes) -> str: 26 | replay_io = six.BytesIO() 27 | replay_io.write(replay_data) 28 | replay_io.seek(0) 29 | archive = mpyq.MPQArchive(replay_io).extract() 30 | metadata = json.loads(archive[b"replay.gamemetadata.json"].decode("utf-8")) 31 | version = metadata["GameVersion"] 32 | return ".".join(version.split(".")[:-1]) 33 | 34 | 35 | class ReplayProcessor(object): 36 | 37 | def __init__( 38 | self, 39 | replay_path: str, 40 | interface_config: SC2InterfaceConfig, 41 | observation_space: SC2ObservationSpace, 42 | action_space: SC2ActionSpace, 43 | discount: float = 1., 44 | score_index: Optional[int] = None, 45 | score_multiplier: Optional[float] = None, 46 | disable_fog: bool = False, 47 | map_path: str = None, 48 | observed_player_id: int = 1, 49 | version: Optional[str] = None) -> None: 50 | super().__init__() 51 | self._replay_path = replay_path 52 | 53 | self.observation_space = observation_space 54 | self.action_space = action_space 55 | 56 | self._discount = discount 57 | self._disable_fog = disable_fog 58 | 59 | self._default_score_index = score_index or 0 60 | self._default_score_multiplier = score_multiplier 61 | 
self._default_episode_length = None
62 | 
63 | self._run_config = run_configs.get(version=version)
64 | 
65 | agent_interface_format = sc2_env.AgentInterfaceFormat(
66 | feature_dimensions=sc2_env.Dimensions(
67 | screen=interface_config.dimension_screen,
68 | minimap=interface_config.dimension_minimap
69 | ),
70 | use_unit_counts=True,
71 | hide_specific_actions=True,
72 | )
73 | interface = sc2_env.SC2Env._get_interface(agent_interface_format, False)
74 | 
75 | self._launch(replay_path, interface, map_path, observed_player_id)
76 | 
77 | self._finalize(agent_interface_format)
78 | 
79 | def _launch(self, replay_path, interface, map_path, observed_player_id):
80 | replay_data = self._run_config.replay_data(replay_path)
81 | 
82 | version = get_game_version(replay_data)
83 | logging.info(f"Start SC2 process (game version={version})...")
84 | 
85 | self._sc2_proc = self._run_config.start(
86 | # version=version,
87 | full_screen=False
88 | )
89 | self._controller = self._sc2_proc.controller
90 | self.replay_info = self._controller.replay_info(replay_data)
91 | 
92 | map_name = re.sub(r"[ '-]|[LTRS]E$", "", self.replay_info.map_name)
93 | if map_name == 'MacroEconomy':
94 | map_name = 'CollectMineralsAndGas'
95 | map_inst = maps.get(map_name)
96 | 
97 | start_replay = sc_pb.RequestStartReplay(
98 | replay_data=replay_data,
99 | options=interface,
100 | disable_fog=self._disable_fog,
101 | observed_player_id=observed_player_id
102 | )
103 | 
104 | def _default_if_none(value, default):
105 | return default if value is None else value
106 | 
107 | self._score_index = _default_if_none(self._default_score_index, map_inst.score_index)
108 | self._score_multiplier = _default_if_none(self._default_score_multiplier, map_inst.score_multiplier)
109 | self._episode_length = _default_if_none(self._default_episode_length, map_inst.game_steps_per_episode)
110 | 
111 | map_path = map_path or self.replay_info.local_map_path
112 | if map_path:
113 | start_replay.map_data = self._run_config.map_data(map_path)
114 | 
115 | self._controller.start_replay(start_replay)
116 | 
117 | def _finalize(
118 | self,
119 | agent_interface_format: sc2_env.AgentInterfaceFormat,
120 | ) -> None:
121 | self._features = features.features_from_game_info(
122 | game_info=self._controller.game_info(),
123 | agent_interface_format=agent_interface_format
124 | )
125 | 
126 | self._state = environment.StepType.FIRST
127 | 
128 | self._episode_steps = 0
129 | 
130 | logging.info('Replay environment is ready for replay: %s', self._replay_path)
131 | 
132 | def iterator(self) -> Iterator[Tuple[environment.TimeStep, dict]]:
133 | # returns ((s, r, d, \gamma), a) samples
134 | curr_time_step, _ = self.observe()
135 | while curr_time_step.step_type != environment.StepType.LAST:
136 | next_time_step, action = self.next(1)
137 | yield curr_time_step, action
138 | curr_time_step = next_time_step
139 | yield curr_time_step, self.action_space.no_op()
140 | 
141 | def next(self, step_mul: int) -> Tuple[environment.TimeStep, dict]:
142 | if step_mul <= 0:
143 | raise ValueError(f"expected step_mul > 0, got {step_mul}")
144 | 
145 | if self._state == environment.StepType.LAST:
146 | raise RuntimeError("Replay already ended.")
147 | 
148 | self._state = environment.StepType.MID
149 | 
150 | self._controller.step(step_mul)
151 | 
152 | observation = self._observe(step_mul)
153 | 
154 | return observation
155 | 
156 | def observe(self) -> Tuple[environment.TimeStep, dict]:
157 | return self._observe()
158 | 
159 | def _observe(self, step_mul=1) -> Tuple[environment.TimeStep, dict]:
160 | 
raw_observation = self._controller.observe() 161 | 162 | actions = [] 163 | try: 164 | actions = [self._features.reverse_action(action) for action in raw_observation.actions] 165 | except ValueError as e: 166 | logging.warning(f"Failed to reverse_action: {e}") 167 | actions = actions if len(actions) > 0 else [FunctionCall(FUNCTIONS["no_op"].id, [])] 168 | 169 | observation = self._features.transform_obs(raw_observation) 170 | 171 | self._episode_steps = observation['game_loop'][0] 172 | 173 | outcome = 0 174 | discount = self._discount 175 | episode_complete = raw_observation.player_result 176 | if episode_complete: 177 | self._state = environment.StepType.LAST 178 | discount = 0 179 | player_id = raw_observation.observation.player_common.player_id 180 | for result in raw_observation.player_result: 181 | if result.player_id == player_id: 182 | outcome = possible_results.get(result.result, 0) 183 | 184 | if self._score_index >= 0: # Game score, not win/loss reward. 185 | cur_score = observation["score_cumulative"][self._score_index] 186 | if self._episode_steps == 0: # First reward is always 0. 187 | reward = 0 188 | else: 189 | reward = max(0, cur_score - self._last_score) 190 | self._last_score = cur_score 191 | else: 192 | reward = outcome 193 | 194 | observation = self.observation_space.transform_back(observation) 195 | action = self.action_space.transform_back(actions[0], step_mul) 196 | 197 | time_step = environment.TimeStep( 198 | step_type=self._state, reward=reward * self._score_multiplier, discount=discount, observation=observation 199 | ) 200 | 201 | return time_step, action 202 | 203 | @property 204 | def state(self): 205 | return self._state 206 | 207 | def close(self) -> None: 208 | if self._controller is not None: 209 | self._controller.quit() 210 | self._controller = None 211 | 212 | if self._sc2_proc is not None: 213 | self._sc2_proc.close() 214 | self._sc2_proc = None 215 | 216 | def __enter__(self): 217 | return self 218 | 219 | def __exit__(self, exc_type, exc_val, exc_tb): 220 | self.close() 221 | -------------------------------------------------------------------------------- /scripts/play_agent_vs_human.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import platform 4 | import sys 5 | import time 6 | import traceback 7 | from typing import Type, Tuple, Optional 8 | 9 | import gin 10 | import portpicker 11 | import tensorflow as tf 12 | from absl import app 13 | from absl import flags 14 | from pysc2 import run_configs, maps 15 | from pysc2.env import sc2_env, lan_sc2_env 16 | from pysc2.lib import renderer_human 17 | from pysc2.lib.protocol import Status 18 | from s2clientprotocol import sc2api_pb2 as sc_pb 19 | 20 | from sc2_imitation_learning.common.utils import gin_register_external_configurables, make_dummy_action 21 | from sc2_imitation_learning.environment.sc2_environment import SC2LanEnv 22 | 23 | logging.basicConfig(level=logging.WARNING) 24 | logger = logging.getLogger(__name__) 25 | 26 | flags.DEFINE_string('agent_dir', default=None, help='Path to the directory where the agent is stored.') 27 | flags.DEFINE_multi_string('gin_file', ['configs/1v1/play_agent_vs_human.gin'], 'List of paths to Gin config files.') 28 | flags.DEFINE_multi_string('gin_param', None, 'List of Gin parameter bindings.') 29 | flags.DEFINE_bool('human', False, 'Human.') 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | gin_register_external_configurables() 34 | 35 | 36 | @gin.configurable 37 | def human(map_name: str = 
gin.REQUIRED, 38 | render: Optional[bool] = None, 39 | host: str = '127.0.0.1', 40 | config_port: int = 14380, 41 | remote: Optional[str] = None, 42 | realtime: bool = False, 43 | fps: float = 22.4, 44 | rgb_screen_size: Optional[Tuple[int, int]] = None, 45 | rgb_minimap_size: Optional[Tuple[int, int]] = None, 46 | feature_screen_size: Optional[Tuple[int, int]] = None, 47 | feature_minimap_size: Optional[Tuple[int, int]] = None, 48 | race: str = 'zerg', 49 | player_name: str = ''): 50 | if render is None: 51 | render = platform.system() == "Linux" 52 | 53 | run_config = run_configs.get() 54 | map_inst = maps.get(map_name) 55 | 56 | ports = [config_port + p for p in range(5)] # tcp + 2 * num_players 57 | if not all(portpicker.is_port_free(p) for p in ports): 58 | sys.exit("Need 5 free ports after the config port.") 59 | 60 | proc = None 61 | ssh_proc = None 62 | tcp_conn = None 63 | udp_sock = None 64 | try: 65 | proc = run_config.start(extra_ports=ports[1:], timeout_seconds=300, host=host, window_loc=(50, 50)) 66 | 67 | tcp_port = ports[0] 68 | settings = { 69 | "remote": remote, 70 | "game_version": proc.version.game_version, 71 | "realtime": realtime, 72 | "map_name": map_inst.name, 73 | "map_path": map_inst.path, 74 | "map_data": map_inst.data(run_config), 75 | "ports": { 76 | "server": {"game": ports[1], "base": ports[2]}, 77 | "client": {"game": ports[3], "base": ports[4]}, 78 | } 79 | } 80 | 81 | create = sc_pb.RequestCreateGame( 82 | realtime=settings["realtime"], 83 | local_map=sc_pb.LocalMap(map_path=settings["map_path"])) 84 | create.player_setup.add(type=sc_pb.Participant) 85 | create.player_setup.add(type=sc_pb.Participant) 86 | 87 | controller = proc.controller 88 | controller.save_map(settings["map_path"], settings["map_data"]) 89 | controller.create_game(create) 90 | 91 | if remote is not None: 92 | ssh_proc = lan_sc2_env.forward_ports( 93 | remote, proc.host, [settings["ports"]["client"]["base"]], 94 | [tcp_port, settings["ports"]["server"]["base"]]) 95 | 96 | tcp_conn = lan_sc2_env.tcp_server(lan_sc2_env.Addr(proc.host, tcp_port), settings) 97 | 98 | if remote is not None: 99 | udp_sock = lan_sc2_env.udp_server( 100 | lan_sc2_env.Addr(proc.host, settings["ports"]["client"]["game"])) 101 | 102 | lan_sc2_env.daemon_thread( 103 | lan_sc2_env.tcp_to_udp, 104 | (tcp_conn, udp_sock, lan_sc2_env.Addr(proc.host, settings["ports"]["server"]["game"]))) 105 | 106 | lan_sc2_env.daemon_thread(lan_sc2_env.udp_to_tcp, (udp_sock, tcp_conn)) 107 | 108 | join = sc_pb.RequestJoinGame() 109 | join.shared_port = 0 # unused 110 | join.server_ports.game_port = settings["ports"]["server"]["game"] 111 | join.server_ports.base_port = settings["ports"]["server"]["base"] 112 | join.client_ports.add(game_port=settings["ports"]["client"]["game"], 113 | base_port=settings["ports"]["client"]["base"]) 114 | 115 | # join.observed_player_id = 2 116 | join.race = sc2_env.Race[race] 117 | join.player_name = player_name 118 | if render: 119 | join.options.raw = True 120 | join.options.score = True 121 | join.options.raw_affects_selection = True 122 | join.options.raw_crop_to_playable_area = True 123 | join.options.show_cloaked = True 124 | join.options.show_burrowed_shadows = True 125 | join.options.show_placeholders = True 126 | if feature_screen_size and feature_minimap_size: 127 | fl = join.options.feature_layer 128 | fl.width = 24 129 | fl.resolution.x = feature_screen_size[0] 130 | fl.resolution.y = feature_screen_size[1] 131 | fl.minimap_resolution.x = feature_minimap_size[0] 132 | 
fl.minimap_resolution.y = feature_minimap_size[1] 133 | if rgb_screen_size and rgb_minimap_size: 134 | join.options.render.resolution.x = rgb_screen_size[0] 135 | join.options.render.resolution.y = rgb_screen_size[1] 136 | join.options.render.minimap_resolution.x = rgb_minimap_size[0] 137 | join.options.render.minimap_resolution.y = rgb_minimap_size[1] 138 | controller.join_game(join) 139 | 140 | if render: 141 | renderer = renderer_human.RendererHuman(fps=fps, render_feature_grid=False) 142 | while controller.status == Status.init_game: 143 | print("Waiting in status = Status.init_game...") 144 | time.sleep(1) 145 | renderer.run(run_configs.get(), controller, max_episodes=1) 146 | else: # Still step forward so the Mac/Windows renderer works. 147 | while True: 148 | frame_start_time = time.time() 149 | if not realtime: 150 | controller.step() 151 | obs = controller.observe() 152 | if obs.player_result: 153 | break 154 | time.sleep(max(0, frame_start_time - time.time() + 1 / fps)) 155 | except KeyboardInterrupt: 156 | pass 157 | finally: 158 | if tcp_conn: 159 | tcp_conn.close() 160 | if proc: 161 | proc.close() 162 | if udp_sock: 163 | udp_sock.close() 164 | if ssh_proc: 165 | ssh_proc.terminate() 166 | for _ in range(5): 167 | if ssh_proc.poll() is not None: 168 | break 169 | time.sleep(1) 170 | if ssh_proc.poll() is None: 171 | ssh_proc.kill() 172 | ssh_proc.wait() 173 | 174 | 175 | @gin.configurable 176 | def agent(env_fn: Type[SC2LanEnv] = gin.REQUIRED) -> None: 177 | agent = tf.saved_model.load(FLAGS.agent_dir) 178 | agent_state = agent.initial_state(1) 179 | 180 | env = env_fn() 181 | env.launch() 182 | episode_reward = 0. 183 | episode_frames = 0 184 | episode_steps = 0 185 | try: 186 | reward, done, observation = 0., False, env.reset() 187 | action = make_dummy_action(env.action_space, num_batch_dims=1) 188 | while not done: 189 | env_outputs = ( 190 | tf.constant([reward], dtype=tf.float32), 191 | tf.constant([episode_steps == 0], dtype=tf.bool), 192 | tf.nest.map_structure(lambda o: tf.constant([o], dtype=tf.dtypes.as_dtype(o.dtype)), observation)) 193 | agent_output, agent_state = agent(action, env_outputs, agent_state) 194 | action = tf.nest.map_structure(lambda t: t.numpy(), agent_output.actions) 195 | reward, _, done, observation = env.step(action) 196 | episode_reward += reward 197 | episode_frames += action['step_mul'] + 1 198 | episode_steps += 1 199 | except Exception as e: 200 | logger.error(f"Failed to play episode(stacktrace below).") 201 | traceback.print_exc() 202 | finally: 203 | logger.info(f"Episode completed: total reward={episode_reward}, frames={episode_frames}, " 204 | f"steps={episode_steps}") 205 | env.close() 206 | 207 | 208 | def main(_): 209 | gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) 210 | if FLAGS.human: 211 | human() 212 | else: 213 | agent() 214 | 215 | 216 | if __name__ == '__main__': 217 | os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' 218 | app.run(main) 219 | -------------------------------------------------------------------------------- /tests/test_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from sc2_imitation_learning.agents.common.policy_head import ActionTypePolicyHead, PolicyContextFeatures, ScalarPolicyHead, \ 5 | SpatialPolicyHead, UnitGroupPointerPolicyHead, ActionEmbedding, ActionArgumentMask, AutoregressivePolicyHead 6 | from sc2_imitation_learning.agents.common.spatial_decoder import ResSpatialDecoder 
7 | from sc2_imitation_learning.common.conv import ConvNet2DTranspose 8 | from sc2_imitation_learning.common.mlp import ResMLP, MLP 9 | from sc2_imitation_learning.environment.sc2_environment import SC2ActionSpace, SC2InterfaceConfig 10 | 11 | 12 | class Test(tf.test.TestCase): 13 | def test_action_type_policy_head(self): 14 | policy = ActionTypePolicyHead( 15 | num_actions=16, 16 | decoder=ResMLP( 17 | output_size=256, num_blocks=16, with_projection=True, with_layer_norm=True, activate_final=True)) 18 | inputs = tf.constant(np.random.rand(1, 256)) 19 | context = PolicyContextFeatures( 20 | scalar_context=tf.constant(np.random.rand(1, 256)), 21 | unit_groups={}, # unused 22 | available_actions=tf.constant([([False]*15) + [True]], dtype=tf.bool), # only action index 15 available 23 | map_skip={} # unused 24 | ) 25 | action, logits = policy(inputs, context) 26 | 27 | self.assertTrue(tf.squeeze(action) == 15) 28 | self.assertEqual(action.shape.as_list(), [1]) 29 | 30 | self.assertIn(logits.dtype, [tf.float32, tf.float64]) 31 | self.assertEqual(logits.shape.as_list(), [1, 16]) 32 | self.assertEqual(tf.reduce_any(tf.math.is_inf(logits)), False) 33 | self.assertEqual(tf.reduce_any(tf.math.is_nan(logits)), False) 34 | 35 | def test_scalar_policy_head(self): 36 | policy = ScalarPolicyHead( 37 | num_actions=16, 38 | decoder=MLP(output_sizes=[256, 256], with_layer_norm=False, activate_final=True)) 39 | inputs = tf.constant(np.random.rand(1, 256)) 40 | context = PolicyContextFeatures( 41 | scalar_context=tf.constant(np.random.rand(1, 256)), 42 | unit_groups={}, # unused 43 | available_actions=tf.constant([[False]*16], dtype=tf.bool), # unused 44 | map_skip={} # unused 45 | ) 46 | action, logits = policy(inputs, context) 47 | 48 | self.assertIn(action.dtype, [tf.int32, tf.int64]) 49 | self.assertEqual(action.shape.as_list(), [1]) 50 | self.assertTrue(0 <= tf.squeeze(action) < 16) 51 | 52 | self.assertIn(logits.dtype, [tf.float32, tf.float64]) 53 | self.assertEqual(logits.shape.as_list(), [1, 16]) 54 | self.assertEqual(tf.reduce_any(tf.math.is_inf(logits)), False) 55 | self.assertEqual(tf.reduce_any(tf.math.is_nan(logits)), False) 56 | 57 | def test_spatial_policy_head(self): 58 | policy = SpatialPolicyHead( 59 | num_actions=64*64, 60 | decoder=ResSpatialDecoder(out_channels=64, num_blocks=4), 61 | upsample_conv_net=ConvNet2DTranspose( 62 | output_channels=[32, 16, 16], 63 | kernel_shapes=[4, 4, 4], 64 | strides=[2, 2, 2], 65 | paddings=['SAME', 'SAME', 'SAME'] 66 | ), 67 | map_skip='screen') 68 | inputs = tf.constant(np.random.rand(1, 64)) 69 | context = PolicyContextFeatures( 70 | scalar_context=tf.constant(np.random.rand(1, 256)), 71 | unit_groups={}, # unused 72 | available_actions=tf.constant([[True]], dtype=tf.bool), # unused 73 | map_skip={ 74 | 'screen': [tf.constant(np.random.rand(1, 8, 8, 64))] 75 | } 76 | ) 77 | action, logits = policy(inputs, context) 78 | 79 | self.assertIn(action.dtype, [tf.int32, tf.int64]) 80 | self.assertEqual(action.shape.as_list(), [1]) 81 | self.assertTrue(0 <= tf.squeeze(action) < 64*64) 82 | 83 | self.assertIn(logits.dtype, [tf.float32, tf.float64]) 84 | self.assertEqual(logits.shape.as_list(), [1, 64*64]) 85 | self.assertEqual(tf.reduce_any(tf.math.is_inf(logits)), False) 86 | self.assertEqual(tf.reduce_any(tf.math.is_nan(logits)), False) 87 | 88 | def test_unit_group_pointer_policy_head(self): 89 | policy = UnitGroupPointerPolicyHead( 90 | num_actions=64*64, 91 | query_embedding_output_sizes=[256, 16], 92 | key_embedding_output_sizes=[16], 93 | 
target_group='multi_select', 94 | mask_zeros=True) 95 | inputs = tf.constant(np.random.randn(1, 64)) 96 | context = PolicyContextFeatures( 97 | scalar_context=tf.constant(np.random.randn(1, 256)), 98 | unit_groups={ 99 | 'multi_select': tf.constant(np.concatenate([ 100 | np.random.randn(1, 3, 32), 101 | np.zeros(shape=(1, 1, 32), dtype=np.float32), 102 | ], axis=1)) 103 | }, 104 | available_actions=tf.constant([[True]], dtype=tf.bool), # unused 105 | map_skip={} # unused 106 | ) 107 | action, logits = policy(inputs, context) 108 | 109 | self.assertIn(action.dtype, [tf.int32, tf.int64]) 110 | self.assertEqual(action.shape.as_list(), [1]) 111 | self.assertTrue(0 <= tf.squeeze(action) < 3) 112 | 113 | self.assertIn(logits.dtype, [tf.float32, tf.float64]) 114 | self.assertEqual(logits.shape.as_list(), [1, 4]) 115 | self.assertAllClose(logits[:, -1], [logits.dtype.min]) # all-zero entities should have masked logits 116 | self.assertEqual(tf.reduce_any(tf.math.is_inf(logits)), False) 117 | self.assertEqual(tf.reduce_any(tf.math.is_nan(logits)), False) 118 | 119 | def test_action_embedding(self): 120 | embed = ActionEmbedding(num_actions=16, output_sizes=[256, 64], with_layer_norm=False) 121 | embedded_action = embed(tf.constant([-1, 0, 12, 15])) 122 | 123 | self.assertIn(embedded_action.dtype, [tf.float32, tf.float64]) 124 | self.assertEqual(embedded_action.shape.as_list(), [4, 64]) 125 | self.assertEqual(tf.reduce_any(tf.math.is_inf(embedded_action)), False) 126 | self.assertEqual(tf.reduce_any(tf.math.is_nan(embedded_action)), False) 127 | 128 | def test_action_argument_mask(self): 129 | mask = ActionArgumentMask(argument_name='minimap', action_mask_value=-1) 130 | masked_action = mask({'action_type': tf.constant([0, 13])}, tf.constant([1, 2])) 131 | 132 | self.assertIn(masked_action.dtype, [tf.int32, tf.int64]) 133 | self.assertEqual(masked_action.shape.as_list(), [2]) 134 | self.assertAllEqual(masked_action, [-1, 2]) 135 | 136 | 137 | def test_autoregressive_policy_head(self): 138 | ar_policy = AutoregressivePolicyHead( 139 | action_space=SC2ActionSpace(SC2InterfaceConfig()), 140 | action_name='select_add', 141 | policy_head=lambda a: ScalarPolicyHead( 142 | num_actions=a, decoder=MLP(output_sizes=[256, 256], with_layer_norm=False, activate_final=True)), 143 | action_embed=lambda a: ActionEmbedding( 144 | num_actions=a, output_sizes=[256, 64], with_layer_norm=False), 145 | action_mask=lambda a, b: ActionArgumentMask(argument_name=a, action_mask_value=b), 146 | action_mask_value=-1 147 | ) 148 | ar_embedding = tf.constant(np.random.randn(2, 64), dtype=tf.float32) 149 | (action, logits), updated_ar_embedding = ar_policy( 150 | autoregressive_embedding=ar_embedding, 151 | context=PolicyContextFeatures( 152 | scalar_context=tf.constant(np.random.randn(2, 256), dtype=tf.float32), 153 | unit_groups={}, # unused 154 | available_actions=tf.constant([[True], [True]], dtype=tf.bool), # unused 155 | map_skip={} # unused 156 | ), 157 | partial_action={ 158 | 'action_type': tf.constant([0, 3]) 159 | } 160 | ) 161 | 162 | self.assertIn(action.dtype, [tf.int32, tf.int64]) 163 | self.assertEqual(action.shape.as_list(), [2]) 164 | self.assertEqual(action[0], -1) # action_type = 0 (no_op) does NOT require a select_add argument 165 | self.assertNotEqual(action[1], -1) # action_type = 3 (select_rect) does require a select_add argument 166 | 167 | self.assertIn(logits.dtype, [tf.float32, tf.float64]) 168 | self.assertEqual(logits.shape.as_list(), [2, 2]) 169 | 
self.assertEqual(tf.reduce_any(tf.math.is_inf(logits)), False) 170 | self.assertEqual(tf.reduce_any(tf.math.is_nan(logits)), False) 171 | 172 | self.assertIn(updated_ar_embedding.dtype, [tf.float32, tf.float64]) 173 | self.assertEqual(updated_ar_embedding.shape.as_list(), [2, 64]) 174 | self.assertEqual(tf.reduce_any(tf.math.is_inf(updated_ar_embedding)), False) 175 | self.assertEqual(tf.reduce_any(tf.math.is_nan(updated_ar_embedding)), False) 176 | self.assertAllClose(ar_embedding[0], updated_ar_embedding[0]) # masked actions should not update the embedding 177 | self.assertNotAllClose(ar_embedding[1], updated_ar_embedding[1]) # unmasked actions should update the embedding 178 | 179 | 180 | -------------------------------------------------------------------------------- /tests/test_sc2_environment.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import gym 4 | import numpy as np 5 | from pysc2.lib.static_data import UNIT_TYPES, UPGRADES 6 | 7 | from sc2_imitation_learning.environment.sc2_environment import get_scalar_feature, SC2InterfaceConfig, Feature, \ 8 | get_screen_feature, get_minimap_feature, LookupFeature, AvailableActionsFeature, PaddedSequenceFeature, \ 9 | SequenceLengthFeature, UnitCountsFeature, UpgradesFeature 10 | 11 | 12 | class SC2EnvironmentTest(TestCase): 13 | def _sc2_interface_config(self): 14 | return SC2InterfaceConfig( 15 | dimension_screen=(64, 64), 16 | dimension_minimap=(64, 64), 17 | screen_features=('visibility_map', 'player_relative', 'unit_type', 'selected', 'unit_hit_points_ratio', 18 | 'unit_energy_ratio', 'unit_density_aa'), 19 | minimap_features=('camera', 'player_relative', 'alerts'), 20 | scalar_features=('game_loop', 'available_actions', 'player'), 21 | available_actions=None, 22 | upgrade_set=None, 23 | max_step_mul=16, 24 | max_multi_select=64, 25 | max_cargo=8, 26 | max_build_queue=8, 27 | max_production_queue=16, 28 | ) 29 | 30 | def test_get_scalar_feature(self): 31 | interface_config = self._sc2_interface_config() 32 | player_feature = get_scalar_feature('player', interface_config) 33 | self.assertIsInstance(player_feature, Feature) 34 | with self.assertRaises(Exception): 35 | _ = get_scalar_feature('non_existing_feature', interface_config) 36 | 37 | def test_get_screen_feature(self): 38 | interface_config = self._sc2_interface_config() 39 | player_feature = get_screen_feature('visibility_map', interface_config) 40 | self.assertIsInstance(player_feature, Feature) 41 | with self.assertRaises(Exception): 42 | _ = get_screen_feature('non_existing_feature', interface_config) 43 | 44 | def test_get_minimap_feature(self): 45 | interface_config = self._sc2_interface_config() 46 | player_feature = get_minimap_feature('camera', interface_config) 47 | self.assertIsInstance(player_feature, Feature) 48 | with self.assertRaises(Exception): 49 | _ = get_minimap_feature('non_existing_feature', interface_config) 50 | 51 | def test_lookup_feature(self): 52 | obs = { 53 | 'a': np.random.uniform(1.0, 2.0, (2, 2)).astype(np.float64), 54 | 'b': np.random.uniform(0.0, 1.0, (2, 2)).astype(np.float64) 55 | } 56 | feature = LookupFeature(obs_key='b', low=0.0, high=1.0, shape=(2, 2), dtype=np.float32) 57 | 58 | self.assertEqual(feature.spec(), gym.spaces.Box(low=0.0, high=1.0, shape=(2, 2), dtype=np.float32)) 59 | 60 | extracted = feature.extract(obs) 61 | 62 | self.assertTrue(extracted.dtype == np.float32) 63 | 64 | def test_available_actions_feature(self): 65 | feature = 
AvailableActionsFeature(max_num_actions=3) 66 | self.assertEqual(feature.spec(), gym.spaces.Box(low=0, high=1, shape=(3,), dtype=np.uint16)) 67 | 68 | obs = {'available_actions': np.asarray([1, 2])} 69 | extracted = feature.extract(obs) 70 | self.assertTrue(extracted.dtype == np.uint16) 71 | self.assertTrue(np.allclose(extracted, [0, 1, 1])) 72 | 73 | obs = {'available_actions': np.asarray([])} 74 | with self.assertRaises(IndexError): 75 | feature.extract(obs) 76 | 77 | obs = {'available_actions': np.asarray([0, 1, 2])} 78 | extracted = feature.extract(obs) 79 | self.assertTrue(extracted.dtype == np.uint16) 80 | self.assertTrue(np.allclose(extracted, [1, 1, 1])) 81 | 82 | obs = {'available_actions': np.asarray([4])} 83 | with self.assertRaises(IndexError): 84 | feature.extract(obs) 85 | 86 | def test_padded_sequence_feature(self): 87 | feature = PaddedSequenceFeature( 88 | obs_key='a', max_length=2, feature_shape=(), low=0, high=np.iinfo(np.uint16).max, dtype=np.uint16) 89 | self.assertEqual(feature.spec(), gym.spaces.Box( 90 | low=0, high=np.iinfo(np.uint16).max, shape=(2,), dtype=np.uint16)) 91 | 92 | obs = {'a': np.arange(2)} 93 | extracted = feature.extract(obs) 94 | self.assertEqual(extracted.dtype, np.uint16) 95 | self.assertTrue(np.allclose(extracted, np.arange(2))) 96 | 97 | obs = {'a': np.arange(1)} 98 | extracted = feature.extract(obs) 99 | self.assertEqual(extracted.dtype, np.uint16) 100 | self.assertTrue(np.allclose(extracted, np.zeros((2,)))) 101 | 102 | obs = {'a': np.arange(3)} 103 | extracted = feature.extract(obs) 104 | self.assertEqual(extracted.dtype, np.uint16) 105 | self.assertTrue(np.allclose(extracted, np.arange(2))) 106 | 107 | def test_sequence_length_feature(self): 108 | feature = SequenceLengthFeature(obs_key='a', max_length=2) 109 | self.assertEqual(feature.spec(), gym.spaces.Box(low=0, high=2, shape=(1,), dtype=np.uint16)) 110 | 111 | obs = {'a': np.arange(0)} 112 | extracted = feature.extract(obs) 113 | self.assertEqual(extracted.dtype, np.uint16) 114 | self.assertEqual(extracted.shape, (1,)) 115 | self.assertEqual(extracted.squeeze(), 0) 116 | 117 | obs = {'a': np.arange(1)} 118 | extracted = feature.extract(obs) 119 | self.assertEqual(extracted.dtype, np.uint16) 120 | self.assertEqual(extracted.shape, (1,)) 121 | self.assertEqual(extracted.squeeze(), 1) 122 | 123 | obs = {'a': np.arange(2)} 124 | extracted = feature.extract(obs) 125 | self.assertEqual(extracted.dtype, np.uint16) 126 | self.assertEqual(extracted.shape, (1,)) 127 | self.assertEqual(extracted.squeeze(), 2) 128 | 129 | obs = {'a': np.arange(3)} 130 | extracted = feature.extract(obs) 131 | self.assertEqual(extracted.dtype, np.uint16) 132 | self.assertEqual(extracted.shape, (1,)) 133 | self.assertEqual(extracted.squeeze(), 2) 134 | 135 | def test_unit_counts_feature(self): 136 | feature = UnitCountsFeature() 137 | self.assertEqual( 138 | feature.spec(), 139 | gym.spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(len(UNIT_TYPES) + 1,), dtype=np.uint16)) 140 | 141 | # default case 142 | obs = {'unit_counts': np.stack([ 143 | np.asarray([UNIT_TYPES[0], UNIT_TYPES[3]]), 144 | np.asarray([1, 1]) 145 | ], axis=-1)} 146 | extracted = feature.extract(obs) 147 | expected = np.zeros((len(UNIT_TYPES) + 1,), dtype=np.uint16) 148 | expected[1] = 1 149 | expected[4] = 1 150 | self.assertEqual(extracted.dtype, np.uint16) 151 | self.assertEqual(extracted.shape, (len(UNIT_TYPES) + 1,)) 152 | self.assertTrue(np.allclose(extracted, expected)) 153 | 154 | # unknown unit id, < max(UNIT_TYPES) 155 | obs = 
{'unit_counts': np.stack([ 156 | np.asarray([UNIT_TYPES[0], next(i for i in range(len(UNIT_TYPES)) if i not in UNIT_TYPES)]), 157 | np.asarray([1, 1]) 158 | ], axis=-1)} 159 | extracted = feature.extract(obs) 160 | expected = np.zeros((len(UNIT_TYPES) + 1,), dtype=np.uint16) 161 | expected[1] = 1 162 | expected[0] = 1 163 | self.assertEqual(extracted.dtype, np.uint16) 164 | self.assertEqual(extracted.shape, (len(UNIT_TYPES) + 1,)) 165 | self.assertTrue(np.allclose(extracted, expected)) 166 | 167 | # unknown unit id, > max(UNIT_TYPES) 168 | obs = {'unit_counts': np.stack([ 169 | np.asarray([UNIT_TYPES[0], max(UNIT_TYPES) + 1]), 170 | np.asarray([1, 1]) 171 | ], axis=-1)} 172 | with self.assertRaises(IndexError): 173 | feature.extract(obs) 174 | 175 | def test_upgrades_feature(self): 176 | feature = UpgradesFeature() 177 | self.assertEqual(feature.spec(), gym.spaces.Box(low=False, high=True, shape=(len(UPGRADES),), dtype=np.bool)) 178 | 179 | # default case 180 | obs = {'upgrades': np.asarray([feature._upgrade_set[0], feature._upgrade_set[4]])} 181 | extracted = feature.extract(obs) 182 | expected = np.zeros((len(UPGRADES),), dtype=np.bool) 183 | expected[0] = True 184 | expected[4] = True 185 | self.assertEqual(extracted.dtype, np.bool) 186 | self.assertEqual(extracted.shape, (len(UPGRADES),)) 187 | self.assertTrue(np.allclose(extracted, expected)) 188 | 189 | # unknown upgrade 190 | obs = {'upgrades': np.asarray([feature._upgrade_set[0], max(UPGRADES) + 1])} 191 | extracted = feature.extract(obs) 192 | expected = np.zeros((len(UPGRADES),), dtype=np.bool) 193 | expected[0] = True 194 | self.assertEqual(extracted.dtype, np.bool) 195 | self.assertEqual(extracted.shape, (len(UPGRADES),)) 196 | self.assertTrue(np.allclose(extracted, expected)) 197 | 198 | feature = UpgradesFeature([UPGRADES[0], UPGRADES[4]]) 199 | 200 | # custom upgrade set, unknown upgrade 201 | obs = {'upgrades': np.asarray([UPGRADES[0], UPGRADES[1]])} 202 | extracted = feature.extract(obs) 203 | expected = np.zeros((2,), dtype=np.bool) 204 | expected[0] = True 205 | self.assertEqual(extracted.dtype, np.bool) 206 | self.assertEqual(extracted.shape, (2,)) 207 | self.assertTrue(np.allclose(extracted, expected)) 208 | -------------------------------------------------------------------------------- /sc2_imitation_learning/common/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import json 4 | import logging 5 | import re 6 | import traceback 7 | from typing import Iterable, Callable, Type, Union 8 | 9 | import gin 10 | import numpy as np 11 | import scipy.stats 12 | import tensorflow as tf 13 | import tensorflow_addons as tfa 14 | import yaml 15 | 16 | from sc2_imitation_learning.common.types import ShapeLike 17 | from sc2_imitation_learning.environment.environment import ObservationSpace, ActionSpace, Space 18 | 19 | 20 | def swap_leading_axes(tensor: tf.Tensor) -> tf.Tensor: 21 | return tf.transpose(tensor, perm=[1, 0] + list(range(2, len(tensor.get_shape())))) 22 | 23 | 24 | def prepend_leading_dims(shape: ShapeLike, leading_dims: ShapeLike) -> tf.TensorShape: 25 | return tf.TensorShape(leading_dims).concatenate(shape) 26 | 27 | 28 | def flatten_nested_dicts(d, parent_key='', sep='/'): 29 | items = [] 30 | for k, v in d.items(): 31 | new_key = parent_key + sep + k if parent_key else k 32 | if isinstance(v, collections.MutableMapping): 33 | items.extend(flatten_nested_dicts(v, new_key, sep=sep).items()) 34 | else: 35 | 
items.append((new_key, v)) 36 | return dict(items) 37 | 38 | 39 | def unflatten_nested_dicts(d, sep='/'): 40 | out_dict = {} 41 | for k, v in d.items(): 42 | dict_pointer = out_dict 43 | key_path = list(k.split(sep)) 44 | if len(key_path) > 1: 45 | for sub_k in key_path[:-1]: 46 | if sub_k not in dict_pointer: 47 | dict_pointer[sub_k] = {} 48 | dict_pointer = dict_pointer[sub_k] 49 | dict_pointer[key_path[-1]] = v 50 | return out_dict 51 | 52 | 53 | class Aggregator(tf.Module): 54 | """Utility module for keeping state and statistics for individual actors. 55 | Copied from: 56 | https://github.com/google-research/seed_rl/blob/f53c5be4ea083783fb10bdf26f11c3a80974fa03/common/utils.py""" 57 | 58 | def __init__(self, num_actors, specs, name='Aggregator'): 59 | """Inits an Aggregator. 60 | 61 | Args: 62 | num_actors: int, number of actors. 63 | specs: Structure (as defined by tf.nest) of tf.TensorSpecs that will be 64 | stored for each actor. 65 | name: Name of the scope for the operations. 66 | """ 67 | super(Aggregator, self).__init__(name=name) 68 | 69 | def create_variable(spec): 70 | z = tf.zeros([num_actors] + spec.shape.dims, dtype=spec.dtype) 71 | return tf.Variable(z, trainable=False, name=spec.name) 72 | 73 | self._state = tf.nest.map_structure(create_variable, specs) 74 | 75 | @tf.Module.with_name_scope 76 | def reset(self, actor_ids): 77 | """Fills the tensors for the given actors with zeros.""" 78 | with tf.name_scope('Aggregator_reset'): 79 | for s in tf.nest.flatten(self._state): 80 | s.scatter_update(tf.IndexedSlices(0, actor_ids)) 81 | 82 | @tf.Module.with_name_scope 83 | def add(self, actor_ids, values): 84 | """In-place adds values to the state associated to the given actors. 85 | 86 | Args: 87 | actor_ids: 1D tensor with the list of actor IDs we want to add values to. 88 | values: A structure of tensors following the input spec, with an added 89 | first dimension that must either have the same size as 'actor_ids', or 90 | should not exist (in which case, the value is broadcasted to all actor 91 | ids). 92 | """ 93 | tf.nest.assert_same_structure(values, self._state) 94 | for s, v in zip(tf.nest.flatten(self._state), tf.nest.flatten(values)): 95 | s.scatter_add(tf.IndexedSlices(v, actor_ids)) 96 | 97 | @tf.Module.with_name_scope 98 | def read(self, actor_ids): 99 | """Reads the values corresponding to a list of actors. 100 | 101 | Args: 102 | actor_ids: 1D tensor with the list of actor IDs we want to read. 103 | 104 | Returns: 105 | A structure of tensors with the same shapes as the input specs. A 106 | dimension is added in front of each tensor, with size equal to the number 107 | of actor_ids provided. 108 | """ 109 | return tf.nest.map_structure(lambda s: s.sparse_read(actor_ids), 110 | self._state) 111 | 112 | @tf.Module.with_name_scope 113 | def replace(self, actor_ids, values): 114 | """Replaces the state associated to the given actors. 115 | 116 | Args: 117 | actor_ids: 1D tensor with the list of actor IDs. 118 | values: A structure of tensors following the input spec, with an added 119 | first dimension that must either have the same size as 'actor_ids', or 120 | should not exist (in which case, the value is broadcasted to all actor 121 | ids). 
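        Example (a minimal, illustrative sketch of how this aggregator is used; the spec name and the
        values below are arbitrary, not taken from the training code):

            agg = Aggregator(num_actors=3, specs=tf.TensorSpec([], tf.float32, 'reward'))
            agg.add(tf.constant([0, 2]), tf.constant([1.0, 5.0]))  # state -> [1., 0., 5.]
            agg.replace(tf.constant([2]), tf.constant([7.0]))      # state -> [1., 0., 7.]
            agg.read(tf.constant([0, 2]))                          # returns [1., 7.]
            agg.reset(tf.constant([0]))                            # state -> [0., 0., 7.]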
122 | """ 123 | tf.nest.assert_same_structure(values, self._state) 124 | for s, v in zip(tf.nest.flatten(self._state), tf.nest.flatten(values)): 125 | s.scatter_update(tf.IndexedSlices(v, actor_ids)) 126 | 127 | 128 | def retry(max_tries: int, exceptions: Iterable[Type[Exception]] = (Exception,), exception_on_failure: bool = False): 129 | def wrapped(fn: Callable): 130 | @functools.wraps(fn) 131 | def _retry(*args, **kwargs): 132 | num_tries = 0 133 | while num_tries < max_tries: 134 | try: 135 | return fn(*args, **kwargs) 136 | except tuple(exceptions) as e: 137 | logging.warning(f"Failed to call '{fn.__name__}': {e}\n{traceback.format_exc()}") 138 | num_tries += 1 139 | logging.error(f"Failed to call '{fn.__name__}', retried {num_tries} of {max_tries} times.") 140 | if exception_on_failure: 141 | raise RuntimeError(f"Failed to call '{fn.__name__}', retried {num_tries} of {max_tries} times.") 142 | return None 143 | return _retry 144 | return wrapped 145 | 146 | 147 | def positional_encoding(max_position, embedding_size, add_batch_dim=False): 148 | positions = np.arange(max_position) 149 | angle_rates = 1 / np.power(10000, (2 * (np.arange(embedding_size)//2)) / np.float32(embedding_size)) 150 | angle_rads = positions[:, np.newaxis] * angle_rates[np.newaxis, :] 151 | 152 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) 153 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) 154 | 155 | if add_batch_dim: 156 | angle_rads = angle_rads[np.newaxis, ...] 157 | 158 | return tf.cast(angle_rads, dtype=tf.float32) 159 | 160 | 161 | def make_dummy(space: Space, num_batch_dims: int = 1): 162 | return tf.nest.map_structure(lambda s: tf.zeros((1,)*num_batch_dims + s.shape, s.dtype), space.specs) 163 | 164 | 165 | def make_dummy_action(action_space: ActionSpace, num_batch_dims: int = 1): 166 | return make_dummy(action_space, num_batch_dims) 167 | 168 | 169 | def make_dummy_observation(observation_space: ObservationSpace, num_batch_dims: int = 1): 170 | return make_dummy(observation_space, num_batch_dims) 171 | 172 | 173 | def make_dummy_batch(observation_space: ObservationSpace, action_space: ActionSpace, num_batch_dims: int = 2): 174 | prev_actions = make_dummy_action(action_space, num_batch_dims=num_batch_dims) 175 | rewards = tf.zeros((1,)*num_batch_dims, dtype=tf.float32) 176 | dones = tf.zeros((1,)*num_batch_dims, dtype=tf.bool) 177 | observations = make_dummy_observation(observation_space, num_batch_dims=num_batch_dims) 178 | return prev_actions, (rewards, dones, observations) 179 | 180 | 181 | def compute_stats_dict(samples: Union[list, np.ndarray]): 182 | conf_int = scipy.stats.t.interval(0.95, len(samples)-1, loc=np.mean(samples), scale=scipy.stats.sem(samples)) 183 | return { 184 | 'samples': [float(x) for x in samples], 185 | 'mean': float(np.mean(samples)), 186 | 'median': float(np.median(samples)), 187 | 'std': float(np.std(samples)), 188 | 'min': float(min(samples)), 189 | 'max': float(max(samples)), 190 | 'mean_ci_95': [float(x) for x in conf_int], 191 | } 192 | 193 | 194 | def gin_register_external_configurables(): 195 | gin.external_configurable(tf.nn.relu, 'tf.nn.relu') 196 | gin.external_configurable(tf.keras.layers.LSTMCell, 'tf.keras.layers.LSTMCell') 197 | gin.external_configurable(tf.keras.layers.StackedRNNCells, 'tf.keras.layers.StackedRNNCells') 198 | gin.external_configurable(tfa.rnn.LayerNormLSTMCell, 'tfa.rnn.LayerNormLSTMCell') 199 | gin.external_configurable(tf.keras.optimizers.Adam, 'tf.keras.optimizers.Adam') 200 | 
gin.external_configurable(tf.keras.optimizers.schedules.PolynomialDecay, 201 | 'tf.keras.optimizers.schedules.PolynomialDecay') 202 | gin.external_configurable(tf.keras.optimizers.schedules.ExponentialDecay, 203 | 'tf.keras.optimizers.schedules.ExponentialDecay') 204 | 205 | 206 | def gin_config_str_to_dict(gin_config_str: str) -> dict: 207 | gin_config_str = "\n".join([x for x in gin_config_str.split("\n") if not x.startswith("import")]) 208 | gin_config_str = re.compile(r"\\\n[^\S\r\n]+").sub(' ', gin_config_str) # collapse indented newlines to single line 209 | gin_config_str = gin_config_str.replace("@", "").replace(" = %", ": ").replace(" = ", ": ") 210 | gin_config_dict = yaml.safe_load(gin_config_str) 211 | return unflatten_nested_dicts(gin_config_dict) 212 | 213 | 214 | def load_json(file_path: str): 215 | with open(file_path) as f: 216 | return json.load(f) 217 | -------------------------------------------------------------------------------- /scripts/download_replays.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/Blizzard/s2client-proto/blob/5b42959a40a45cca290ce427b5522a35c8a59178/samples/replay-api/download_replays.py 2 | # Lint as: python3 3 | """Download replay packs via Blizzard Game Data APIs.""" 4 | 5 | # pylint: disable=bad-indentation, line-too-long 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import collections 13 | import itertools 14 | import json 15 | import logging 16 | import os 17 | import requests 18 | import shutil 19 | import subprocess 20 | import sys 21 | 22 | try: 23 | import mpyq 24 | except ImportError: 25 | logging.warning( 26 | 'Failed to import mpyq; version and corruption detection is disabled.') 27 | mpyq = None 28 | from six import print_ as print # To get access to `flush` in python 2. 29 | 30 | API_BASE_URL = 'https://us.api.blizzard.com' 31 | API_NAMESPACE = 's2-client-replays' 32 | 33 | 34 | class RequestError(Exception): 35 | pass 36 | 37 | 38 | def mkdirs(path): 39 | if not os.path.exists(path): 40 | os.makedirs(path) 41 | 42 | 43 | def print_part(*args): 44 | print(*args, end='', flush=True) 45 | 46 | 47 | class BnetAPI(object): 48 | """Represents a handle to the battle.net api.""" 49 | 50 | def __init__(self, key, secret): 51 | headers = {'Content-Type': 'application/json'} 52 | params = {'grant_type': 'client_credentials'} 53 | response = requests.post('https://us.battle.net/oauth/token', 54 | headers=headers, params=params, 55 | auth=requests.auth.HTTPBasicAuth(key, secret)) 56 | if response.status_code != requests.codes.ok: 57 | raise RequestError( 58 | 'Failed to get oauth access token. response={}'.format(response)) 59 | response = json.loads(response.text) 60 | if 'access_token' in response: 61 | self._token = response['access_token'] 62 | else: 63 | raise RequestError( 64 | 'Failed to get oauth access token. response={}'.format(response)) 65 | 66 | def get(self, url, params=None): 67 | """Make an autorized get request to the api by url.""" 68 | params = params or {} 69 | params['namespace'] = API_NAMESPACE, 70 | headers = {'Authorization': 'Bearer ' + self._token} 71 | response = requests.get(url, headers=headers, params=params) 72 | if response.status_code != requests.codes.ok: 73 | raise RequestError( 74 | 'Request to "{}" failed. 
response={}'.format(url, response)) 75 | response_json = json.loads(response.text) 76 | if response_json.get('status') == 'nok': 77 | raise RequestError( 78 | 'Request to "{}" failed. response={}'.format( 79 | url, response_json.get('reason'))) 80 | return response_json 81 | 82 | def url(self, path): 83 | return requests.compat.urljoin(API_BASE_URL, path) 84 | 85 | def get_base_url(self): 86 | return self.get(self.url('/data/sc2/archive_url/base_url'))['base_url'] 87 | 88 | def search_by_client_version(self, client_version): 89 | """Search for replay archives by version.""" 90 | meta_urls = [] 91 | for page in itertools.count(1): 92 | params = { 93 | 'client_version': client_version, 94 | '_pageSize': 100, 95 | '_page': page, 96 | } 97 | response = self.get(self.url('/data/sc2/search/archive'), params) 98 | for result in response['results']: 99 | assert result['data']['client_version'] == client_version 100 | meta_urls.append(result['key']['href']) 101 | if response['pageCount'] <= page: 102 | break 103 | return meta_urls 104 | 105 | 106 | def download(key, secret, version, replays_dir, download_dir, extract=False, 107 | remove=False, filter_version='keep'): 108 | """Download the replays for a specific vesion. Check help below.""" 109 | # Get OAuth token from us region 110 | api = BnetAPI(key, secret) 111 | 112 | # Get meta file infos for the give client version 113 | print('Searching replay packs with client version:', version) 114 | meta_file_urls = api.search_by_client_version(version) 115 | if len(meta_file_urls) == 0: 116 | sys.exit('No matching replay packs found for the client version!') 117 | 118 | # Download replay packs. 119 | download_base_url = api.get_base_url() 120 | print('Found {} replay packs'.format(len(meta_file_urls))) 121 | print('Downloading to:', download_dir) 122 | print('Extracting to:', replays_dir) 123 | mkdirs(download_dir) 124 | for i, meta_file_url in enumerate(sorted(meta_file_urls), 1): 125 | # Construct full url to download replay packs 126 | meta_file_info = api.get(meta_file_url) 127 | archive_url = requests.compat.urljoin(download_base_url, 128 | meta_file_info['path']) 129 | 130 | print_part('{}/{}: {} ... '.format(i, len(meta_file_urls), archive_url)) 131 | 132 | file_name = archive_url.split('/')[-1] 133 | file_path = os.path.join(download_dir, file_name) 134 | 135 | with requests.get(archive_url, stream=True) as response: 136 | content_length = int(response.headers['Content-Length']) 137 | print_part(content_length // 1024**2, 'Mb ... ') 138 | if (not os.path.exists(file_path) or 139 | os.path.getsize(file_path) != content_length): 140 | with open(file_path, 'wb') as f: 141 | shutil.copyfileobj(response.raw, f) 142 | print_part('downloaded') 143 | else: 144 | print_part('found') 145 | 146 | if extract: 147 | print_part(' ... extracting') 148 | if os.path.getsize(file_path) <= 22: # Size of an empty zip file. 149 | print_part(' ... 
zip file is empty') 150 | else: 151 | subprocess.call(['unzip', '-P', 'iagreetotheeula', '-u', '-o', 152 | '-q', '-d', replays_dir, file_path]) 153 | if remove: 154 | os.remove(file_path) 155 | print() 156 | 157 | if mpyq is not None and filter_version != 'keep': 158 | print('Filtering replays.') 159 | found_versions = collections.defaultdict(int) 160 | found_str = lambda: ', '.join('{}: {}'.format(v, c) 161 | for v, c in sorted(found_versions.items())) 162 | all_replays = [f for f in os.listdir(replays_dir) if f.endswith('.SC2Replay')] 163 | for i, file_name in enumerate(all_replays): 164 | if i % 100 == 0: 165 | print_part('\r{}/{}: {:.1f}%, found: {}'.format( 166 | i, len(all_replays), 100 * i / len(all_replays), found_str())) 167 | file_path = os.path.join(replays_dir, file_name) 168 | with open(file_path, 'rb') as fd: 169 | try: 170 | archive = mpyq.MPQArchive(fd).extract() 171 | metadata = json.loads( 172 | archive[b'replay.gamemetadata.json'].decode('utf-8')) 173 | except KeyboardInterrupt: 174 | raise 175 | except: # pylint: disable=bare-except 176 | found_versions['corrupt'] += 1 177 | os.remove(file_path) 178 | continue 179 | game_version = '.'.join(metadata['GameVersion'].split('.')[:-1]) 180 | found_versions[game_version] += 1 181 | if filter_version == 'sort': 182 | version_dir = os.path.join(replays_dir, game_version) 183 | if found_versions[game_version] == 1: # First one of this version. 184 | mkdirs(version_dir) 185 | os.rename(file_path, os.path.join(version_dir, file_name)) 186 | elif filter_version == 'delete': 187 | if game_version != version: 188 | os.remove(file_path) 189 | print('\nFound replays:', found_str()) 190 | 191 | 192 | def main(): 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('--key', required=True, help='Battle.net API key.') 195 | parser.add_argument('--secret', required=True, help='Battle.net API secret.') 196 | parser.add_argument('--version', required=True, 197 | help=('Download all replays from this StarCraft 2 game' 198 | 'version, eg: "4.8.3".')) 199 | parser.add_argument('--replays_dir', default='./replays', 200 | help='Where to save the replays.') 201 | parser.add_argument('--download_dir', default='./download', 202 | help='Where to save the zip files.') 203 | parser.add_argument('--extract', action='store_true', 204 | help='Whether to extract the zip files.') 205 | parser.add_argument('--remove', action='store_true', 206 | help='Whether to delete the zip files after extraction.') 207 | parser.add_argument('--filter_version', default='keep', 208 | choices=['keep', 'delete', 'sort'], 209 | help=("What to do with replays that don't match the " 210 | "requested version. Keep is fast, but does no " 211 | "filtering. Delete deletes any that don't match. 
" 212 | "Sort puts them in sub-directories based on " 213 | "their version.")) 214 | args = parser.parse_args() 215 | args_dict = dict(vars(args).items()) 216 | download(**args_dict) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /sc2_imitation_learning/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import logging 4 | import multiprocessing 5 | import os 6 | import pickle 7 | import random 8 | from abc import ABC, abstractmethod 9 | from typing import List, Tuple, Iterator 10 | from typing import NamedTuple, Dict, Optional, Sequence 11 | 12 | import gym 13 | import h5py 14 | import numpy as np 15 | import tensorflow as tf 16 | import tqdm 17 | 18 | from sc2_imitation_learning.common.utils import unflatten_nested_dicts, flatten_nested_dicts, load_json 19 | from sc2_imitation_learning.environment.environment import ObservationSpace, ActionSpace 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class ActionTimeStep(NamedTuple): 25 | observation: Dict 26 | action: Dict 27 | reward: float 28 | done: bool 29 | 30 | 31 | class EpisodeSlice(NamedTuple): 32 | episode_id: int 33 | episode_path: str 34 | start: int 35 | length: int 36 | wrap_at_end: bool 37 | 38 | 39 | class EpisodeIterator(Iterator): 40 | def __init__(self, episode_id: int, episode_path: str, episode_length: int, sequence_length: int, 41 | start_index: int = 0) -> None: 42 | super().__init__() 43 | self._episode_id = episode_id 44 | self._episode_path = episode_path 45 | self._episode_length = episode_length 46 | self._sequence_length = sequence_length 47 | self._index = start_index 48 | 49 | def __next__(self) -> EpisodeSlice: 50 | episode_slice = EpisodeSlice( 51 | episode_id=self._episode_id, episode_path=self._episode_path, start=self._index, 52 | length=self._sequence_length, wrap_at_end=True) 53 | self._index = (self._index + self._sequence_length) % self._episode_length 54 | return episode_slice 55 | 56 | 57 | class Batcher(object): 58 | def __init__(self, batch_size: int, sequence_length: int, max_queue_size: int, seed: Optional[int] = None) -> None: 59 | super().__init__() 60 | self._batch_size = batch_size 61 | self._sequence_length = sequence_length 62 | self._queue_out = multiprocessing.Queue(maxsize=max_queue_size) 63 | self._seed = seed 64 | 65 | def __call__(self, file_paths: Sequence[str]) -> Iterator[List[EpisodeSlice]]: 66 | process = multiprocessing.Process(target=self._run, args=(file_paths,), daemon=True) 67 | try: 68 | process.start() 69 | while True: 70 | yield self._queue_out.get() 71 | finally: 72 | process.terminate() 73 | 74 | def _run(self, file_paths: Sequence[str]): 75 | rng = random.Random(self._seed) 76 | episode_iterators = [] 77 | for i, path in enumerate(tqdm.tqdm(file_paths, total=len(file_paths))): 78 | # with open(path.replace('.hdf5', '.pkl'), mode='rb') as f: 79 | # meta = pickle.load(f) 80 | with open(path.replace('.hdf5', '.meta'), mode='r') as f: 81 | meta = json.load(f) 82 | if meta['episode_length'] >= self._sequence_length: 83 | episode_iterators.append(EpisodeIterator( 84 | i, path, meta['episode_length'], self._sequence_length, rng.randint(0, meta['episode_length'] - 1))) 85 | while True: 86 | batch_episodes = rng.sample(episode_iterators, k=self._batch_size) 87 | batch = [next(it) for it in batch_episodes] 88 | self._queue_out.put(batch) 89 | 90 | 91 | def h5py_dataset_iterator(g, 
prefix=None): 92 | for key in g.keys(): 93 | item = g[key] 94 | path = key 95 | if prefix is not None: 96 | path = f'{prefix}/{path}' 97 | if isinstance(item, h5py.Dataset): # test for dataset 98 | yield path, item 99 | elif isinstance(item, h5py.Group): # test for group (go down) 100 | yield from h5py_dataset_iterator(item, path) 101 | 102 | 103 | def load_episode_slice(episode_slice: EpisodeSlice) -> Tuple[int, Dict]: 104 | with h5py.File(episode_slice.episode_path, 'r') as f: 105 | episode_length = f['reward'].shape[0] 106 | if episode_slice.wrap_at_end and episode_slice.start + episode_slice.length > episode_length: 107 | sequence = { 108 | key: np.concatenate([ 109 | dataset[episode_slice.start:], 110 | dataset[:(episode_slice.start + episode_slice.length) % episode_length] 111 | ], axis=0) 112 | for key, dataset in h5py_dataset_iterator(f)} 113 | else: 114 | sequence = { 115 | key: dataset[episode_slice.start:episode_slice.start+episode_slice.length] 116 | for key, dataset in h5py_dataset_iterator(f)} 117 | for k, v in sequence.items(): 118 | assert v.shape[0] == episode_slice.length, f"{k}, {v.shape}, {episode_length}, {episode_slice}" 119 | return episode_slice.episode_id, unflatten_nested_dicts(sequence) 120 | 121 | 122 | def load_batch(semaphore, batch: List[EpisodeSlice]) -> Dict: 123 | semaphore.acquire() 124 | batch = map(load_episode_slice, batch) 125 | return tf.nest.map_structure(lambda *x: np.stack(x), *batch) 126 | 127 | 128 | def load_dataset_from_hdf5(file_paths: Sequence[str], 129 | specs: Dict[str, gym.spaces.Space], 130 | batch_size: int, 131 | sequence_length: int, 132 | num_workers: int = os.cpu_count(), 133 | chunk_size: int = 4, 134 | seed: Optional[int] = None) -> tf.data.Dataset: 135 | 136 | output_types = (tf.int32, unflatten_nested_dicts(tf.nest.map_structure(lambda s: s.dtype, specs))) 137 | output_shapes = ((batch_size,), unflatten_nested_dicts(tf.nest.map_structure( 138 | lambda s: tf.TensorShape([batch_size, sequence_length]).concatenate(s.shape), specs))) 139 | 140 | def _gen(): 141 | batcher = Batcher(batch_size, sequence_length=sequence_length, max_queue_size=num_workers, seed=seed) 142 | manager = multiprocessing.Manager() 143 | load_batch_semaphore = manager.Semaphore(2*chunk_size*num_workers) 144 | load_batch_ = functools.partial(load_batch, load_batch_semaphore) 145 | with multiprocessing.Pool(processes=num_workers) as pool: 146 | for batch in pool.imap(load_batch_, batcher(file_paths=file_paths), chunksize=chunk_size): 147 | yield batch 148 | load_batch_semaphore.release() 149 | return tf.data.Dataset.from_generator(_gen, args=[], output_types=output_types, output_shapes=output_shapes) 150 | 151 | 152 | def get_dataset_specs(action_space: ActionSpace, observation_space: ObservationSpace): 153 | return { 154 | **{ 155 | f'observation/{key}': space 156 | for key, space in flatten_nested_dicts(observation_space.specs).items() 157 | }, 158 | **{ 159 | f'action/{key}': space 160 | for key, space in flatten_nested_dicts(action_space.specs).items() 161 | }, 162 | **{ 163 | f'prev_action/{key}': space 164 | for key, space in flatten_nested_dicts(action_space.specs).items() 165 | }, 166 | 'reward': gym.spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32), 167 | 'prev_reward': gym.spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32), 168 | 'done': gym.spaces.Box(-np.inf, np.inf, shape=(), dtype=np.bool), 169 | } 170 | 171 | 172 | def store_episode_to_hdf5(path: str, 173 | name: str, 174 | episode: List[ActionTimeStep], 175 | episode_info: Dict, 176 | 
specs: Dict[str, gym.spaces.Space]) -> str: 177 | os.makedirs(path, exist_ok=True) 178 | 179 | assert os.path.exists(os.path.join(path, f'{name}.hdf5')) is False, \ 180 | f"'{name}.hdf5' already exists in '{path}'." 181 | 182 | with h5py.File(os.path.join(path, f'{name}.hdf5'), mode='w') as f: 183 | datasets = { 184 | key: f.create_dataset(name=key, shape=(len(episode),) + space.shape, dtype=space.dtype) 185 | for key, space in specs.items() 186 | } 187 | for i, time_step in enumerate(episode): 188 | for key, value in flatten_nested_dicts(time_step.observation).items(): 189 | datasets[f'observation/{key}'][i] = np.asarray(value) 190 | for key, value in flatten_nested_dicts(time_step.action).items(): 191 | datasets[f'action/{key}'][i] = np.asarray(value) 192 | for key, value in flatten_nested_dicts(time_step.action).items(): 193 | datasets[f'prev_action/{key}'][i] = -1 if i == 0 else datasets[f'action/{key}'][i - 1] 194 | datasets['reward'][i] = time_step.reward 195 | datasets['prev_reward'][i] = 0. if i == 0 else datasets['reward'][i - 1] 196 | datasets['done'][i] = time_step.done 197 | 198 | with open(os.path.join(path, f'{name}.meta'), mode='w') as f: 199 | json.dump({ 200 | 'data_file': f'{name}.hdf5', 201 | 'episode_return': sum([float(time_step.reward) for time_step in episode]), 202 | 'episode_length': len(episode), 203 | 'episode_info': episode_info 204 | }, f, indent=4) 205 | 206 | return os.path.join(path, f'{name}') 207 | 208 | 209 | class DataLoader(ABC): 210 | @property 211 | @abstractmethod 212 | def num_samples(self) -> int: 213 | pass 214 | 215 | @property 216 | @abstractmethod 217 | def num_episodes(self) -> int: 218 | pass 219 | 220 | @abstractmethod 221 | def load(self, 222 | batch_size: int, 223 | sequence_length: int, 224 | offset_episodes: int = 0, 225 | num_episodes: int = 0, 226 | num_workers: int = os.cpu_count(), 227 | chunk_size: int = 4, 228 | seed: Optional[int] = None) -> tf.data.Dataset: 229 | pass 230 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StarCraft II Imitation Learning 2 | 3 | This repository provides code to train neural network based StarCraft II agents from human demonstrations. 4 | It emerged as a side-product of my Master's thesis, where I looked at representation learning from demonstrations for 5 | task transfer in reinforcement learning. 
6 | 
7 | The main features are:
8 | - Behaviour cloning from StarCraft II replays
9 | - Modular and extensible agents, inspired by the architecture of [AlphaStar](https://deepmind.com/blog/article/AlphaStar-Grandmaster-level-in-StarCraft-II-using-multi-agent-reinforcement-learning) but using the feature-layer interface instead of the raw game interface
10 | - Hierarchical configurations using [Gin Config](https://github.com/google/gin-config) that provide a great degree of flexibility and configurability
11 | - Pre-processing of large-scale replay datasets
12 | - Multi-GPU training
13 | - Playing against trained agents (Windows / Mac)
14 | - [Pretrained agents](#download-pre-trained-agents) for the Terran vs Terran match-up
15 | 
16 | ## Table of Contents
17 | - [Installation](#installation)
18 | - [Train your own agent](#train-your-own-agent)
19 | - [Play against trained agents](#play-against-trained-agents)
20 | - [Download pre-trained agents](#download-pre-trained-agents)
21 | 
22 | ## Installation
23 | 
24 | ### Requirements
25 | 
26 | - Python >= 3.6
27 | - StarCraft II >= 3.16.1 (**4.7.1 strongly recommended**)
28 | 
29 | To install StarCraft II, you can follow the instructions at https://github.com/deepmind/pysc2#get-starcraft-ii.
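After installing, you can quickly check that PySC2 finds your StarCraft II installation. The snippet below is a minimal sketch; it only assumes a standard PySC2 setup and prints the installation directory that PySC2 resolved:

```python
# Sanity check: ask PySC2 for the resolved StarCraft II installation directory.
from pysc2 import run_configs

run_config = run_configs.get()  # platform-specific run configuration
print(run_config.data_dir)      # should point at your StarCraftII directory
```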
30 | 
31 | On Linux: From the available versions, version 4.7.1 is strongly recommended.
32 | Other versions are not tested and might run into compatibility issues with this code or the PySC2 library.
33 | Also, replays are tied to the StarCraft II version in which they were recorded, and of all the binaries available, version 4.7.1 has the largest number of replays currently available through the Blizzard Game Data APIs.
34 | 
35 | On Windows/MacOS: The binaries for a certain game version will be downloaded automatically when opening a replay of that version via the game client.
36 | 
37 | 
38 | ### Get the StarCraft II Maps
39 | 
40 | Download the [ladder maps](https://github.com/Blizzard/s2client-proto#map-packs) and extract them to the `StarCraftII/Maps/` directory.
41 | 
42 | ### Get the Code
43 | 
44 | ```shell script
45 | git clone https://github.com/metataro/sc2_imitation_learning.git
46 | ```
47 | 
48 | ### Install the Python Libraries
49 | 
50 | ```shell script
51 | pip install -r requirements.txt
52 | ```
53 | 
54 | ## Train Your Own Agent
55 | 
56 | ### Download Replay Packs
57 | 
58 | There are replay packs available for [direct download](https://github.com/Blizzard/s2client-proto#replay-packs);
59 | however, a much larger number of replays can be downloaded via the [Blizzard Game Data APIs](https://develop.battle.net/documentation/guides/game-data-apis).
60 | 
61 | The download of StarCraft II replays from the Blizzard Game Data APIs is described [here](https://github.com/Blizzard/s2client-proto/tree/master/samples/replay-api).
62 | For example, the following command will download all available replays of game version 4.7.1:
63 | 
64 | ```shell script
65 | python -m scripts.download_replays \
66 |     --key <your Battle.net API key> \
67 |     --secret <your Battle.net API secret> \
68 |     --version 4.7.1 \
69 |     --extract \
70 |     --filter_version sort
71 | ```
72 | 
73 | ### Prepare the Dataset
74 | 
75 | Having downloaded the replay packs, you can preprocess and combine them into a dataset as follows:
76 | 
77 | ```shell script
78 | python -m scripts.build_dataset \
79 |     --gin_file ./configs/1v1/build_dataset.gin \
80 |     --replays_path ./data/replays/4.7.1/ \
81 |     --dataset_path ./data/datasets/v1
82 | ```
83 | 
84 | Note that depending on the configuration, the resulting dataset may require large amounts of disk space (> 1TB).
85 | For example, the configuration defined in `./configs/1v1/build_dataset.gin` results in a dataset of about 4.5TB,
86 | even though it uses less than 5% of the available 4.7.1 replays.
87 | 
88 | 
89 | ### Run the Training
90 | 
91 | After preparing the dataset, you can run behaviour cloning training as follows:
92 | 
93 | ```shell script
94 | python -m scripts.behaviour_cloning --gin_file ./configs/1v1/behaviour_cloning.gin
95 | ```
96 | 
97 | By default, the training will be parallelized across all available GPUs.
98 | You can limit the number of used GPUs by setting the environment variable `CUDA_VISIBLE_DEVICES`.
99 | 
100 | The parameters in `configs/1v1/behaviour_cloning.gin` are optimized for a hardware setup with four Nvidia GTX 1080Ti GPUs
101 | and 20 physical CPUs (40 logical CPUs), where the training takes around one week to complete.
102 | You may need to adjust these configurations to fit your hardware specifications.
103 | 
104 | Logs are written to a TensorBoard log file inside the experiment directory.
105 | You can additionally enable logging to [Weights & Biases](https://wandb.ai/) by setting the `--wandb_logging_enabled` flag.
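As noted above, you can limit the GPUs used for training via `CUDA_VISIBLE_DEVICES`. The sketch below (an illustrative example, not part of the repository) shows the equivalent from within Python; note that the variable has to be set before TensorFlow is imported:

```python
# Make only GPUs 0 and 1 visible to TensorFlow, then verify what it sees.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # must be set before importing TensorFlow

import tensorflow as tf
print(tf.config.list_physical_devices("GPU"))  # should list two GPU devices
```

Setting the variable in the shell before launching `scripts.behaviour_cloning` has the same effect.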
106 | 
107 | 
108 | ### Run the Evaluation
109 | 
110 | You can evaluate trained agents against built-in A.I. as follows:
111 | 
112 | ```shell script
113 | python -m scripts.evaluate --gin_file configs/1v1/evaluate.gin --logdir <experiment_dir>
114 | ```
115 | 
116 | Replace `<experiment_dir>` with the path to the experiment folder of the agent.
117 | This will run the evaluation as configured in `configs/1v1/evaluate.gin`.
118 | Again, you may need to adjust these configurations to fit your hardware specifications.
119 | 
120 | By default, all available GPUs will be considered and evaluators will be split evenly across them.
121 | You can limit the number of used GPUs by setting the environment variable `CUDA_VISIBLE_DEVICES`.
122 | 
123 | ## Play Against Trained Agents
124 | 
125 | You can challenge yourself to play against trained agents.
126 | 
127 | First, start a game as a human player:
128 | 
129 | ```shell script
130 | python -m scripts.play_agent_vs_human --human
131 | ```
132 | 
133 | Then, in a second console, let the agent join the game:
134 | 
135 | ```shell script
136 | python -m scripts.play_agent_vs_human --agent_dir <agent_dir>
137 | ```
138 | 
139 | Replace `<agent_dir>` with the path to where the model is stored (e.g. `/path/to/experiment/saved_model`).
140 | 
141 | ## Download Pre-Trained Agents
142 | 
143 | There are pre-trained agents available for download:
144 | 
145 | https://drive.google.com/drive/folders/1PNhOYeA4AkxhTzexQc-urikN4RDhWEUO?usp=sharing
146 | 
147 | ### Agent 1v1/tvt_all_maps
148 | 
149 | #### Evaluation Results
150 | 
151 | The table below shows the win rates of the agent when evaluated in TvT against built-in A.I. with randomly selected builds.
152 | Win rates for each map and difficulty level were determined from 100 evaluation matches.
153 | 
154 | 
155 | | Map | Very Easy | Easy | Medium | Hard |
156 | |:----|---------:|----:|------:|----:|
157 | | KairosJunction | 0.86 | 0.27 | 0.07 | 0.00 |
158 | | Automaton | 0.82 | 0.33 | 0.07 | 0.00 |
159 | | Blueshift | 0.84 | 0.41 | 0.03 | 0.00 |
160 | | CeruleanFall | 0.72 | 0.28 | 0.03 | 0.00 |
161 | | ParaSite | 0.75 | 0.41 | 0.02 | 0.01 |
162 | | PortAleksander | 0.72 | 0.34 | 0.05 | 0.00 |
163 | | Stasis | 0.73 | 0.44 | 0.08 | 0.00 |
164 | | **Overall** | **0.78** | **0.35** | **0.05** | ~ **0.00** |
165 | 
166 | 
167 | #### Recordings
168 | 
169 | Video recordings of cherry-picked evaluation games:
170 | 
174 | - Midgame win vs easy A.I.
180 | - Marine rush win vs easy A.I.
188 | - Basetrade win vs hard A.I.
196 | 
197 | #### Training Data
198 | 
199 | - Matchups: TvT
200 | - Minimum MMR: 3500
201 | - Minimum APM: 60
202 | - Minimum duration: 30
203 | - Maps: KairosJunction, Automaton, Blueshift, CeruleanFall, ParaSite, PortAleksander, Stasis
204 | - Episodes: 35'051 (102'792'317 timesteps)
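These filters correspond to the `FilterReplay` stage in `scripts/build_dataset.py`. The snippet below is only an illustrative sketch of how a comparable filter could be instantiated directly in Python; the values actually used for this agent are bound via Gin in `configs/1v1/build_dataset.gin` and may differ in details such as the unit of `min_duration`:

```python
# Illustrative only: a replay filter roughly matching the statistics above.
from pysc2.env.sc2_env import Race

from scripts.build_dataset import FilterReplay

tvt_filter = FilterReplay(
    min_duration=30,                                  # minimum replay duration
    min_mmr=3500,                                     # minimum player MMR
    min_apm=60,                                       # minimum player APM
    observed_player_races=frozenset((Race.terran,)),  # observed player plays Terran
    opponent_player_races=frozenset((Race.terran,)),  # opponent plays Terran
    map_names={'KairosJunction', 'Automaton', 'Blueshift', 'CeruleanFall',
               'ParaSite', 'PortAleksander', 'Stasis'},
)
```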
225 | 
226 | #### Interface
227 | 
228 | - Interface type: Feature layers
229 | - Dimensions: 64 x 64 (screen), 64 x 64 (minimap)
230 | - Screen features: visibility_map, player_relative, unit_type, selected, unit_hit_points_ratio, unit_energy_ratio, unit_density_aa
231 | - Minimap features: camera, player_relative, alerts
232 | - Scalar features: player, home_race_requested, away_race_requested, upgrades, game_loop, available_actions, unit_counts, build_queue, cargo, cargo_slots_available, control_groups, multi_select, production_queue
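In PySC2 terms this is a 64 x 64 feature-layer interface. As a rough illustration (assuming the standard PySC2 API; the repository itself configures the interface through `SC2InterfaceConfig` and Gin), the dimensions above map to:

```python
# Rough PySC2 equivalent of the interface dimensions listed above.
from pysc2.lib import features

interface_format = features.AgentInterfaceFormat(
    feature_dimensions=features.Dimensions(screen=64, minimap=64),
)
```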
250 | 251 | #### Agent Architecture 252 | 253 | ![SC2 Featuer Layer Agent Architecture](docs/sc2_feature_layer_agent_architecture.png) 254 | 255 | -------------------------------------------------------------------------------- /scripts/build_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from typing import Iterator, NamedTuple, List, Dict, Text, AbstractSet, Optional, Set 5 | 6 | import gin 7 | import pypeln as pl 8 | from absl import app 9 | from absl import flags 10 | from pypeln.process import IterableQueue 11 | from pypeln.process.api.filter import FilterFn 12 | from pypeln.process.api.map import MapFn 13 | from pysc2 import run_configs 14 | from pysc2.env.environment import StepType 15 | from pysc2.env.sc2_env import Race 16 | from tqdm import tqdm 17 | 18 | from sc2_imitation_learning.common.replay_processor import ReplayProcessor, get_replay_info 19 | from sc2_imitation_learning.common.utils import retry 20 | from sc2_imitation_learning.dataset.dataset import ActionTimeStep, store_episode_to_hdf5, get_dataset_specs 21 | from sc2_imitation_learning.dataset.sc2_dataset import SC2REPLAY_RACES 22 | from sc2_imitation_learning.environment.environment import ActionSpace, ObservationSpace 23 | from sc2_imitation_learning.environment.sc2_environment import SC2ActionSpace, SC2ObservationSpace, SC2InterfaceConfig, \ 24 | SC2Maps 25 | 26 | logging.basicConfig(level=logging.WARNING) 27 | logger = logging.getLogger(__name__) 28 | 29 | flags.DEFINE_string('replays_path', default='./data/replays/4.7.1/', 30 | help='Path to the directory where the replays are stored.') 31 | flags.DEFINE_string('dataset_path', default='./data/datasets/1v1/v1', 32 | help='Path to the directory where the dataset will be stored.') 33 | flags.DEFINE_integer('num_workers', os.cpu_count(), help='Number of parallel workers.') 34 | flags.DEFINE_multi_string('gin_file', ['./configs/1v1/build_dataset.gin'], help='List of paths to Gin config files.') 35 | flags.DEFINE_multi_string('gin_param', None, help='List of Gin parameter bindings.') 36 | 37 | FLAGS = flags.FLAGS 38 | 39 | 40 | class ReplayMeta(NamedTuple): 41 | observed_player_id: int 42 | replay_info: Dict 43 | replay_path: Text 44 | 45 | 46 | class Replay(NamedTuple): 47 | time_steps: List[ActionTimeStep] 48 | replay_meta: ReplayMeta 49 | 50 | 51 | def is_not_none(x): return x is not None 52 | 53 | 54 | def find_replays(replay_path: Text) -> Iterator[str]: 55 | for entry in os.scandir(os.path.abspath(replay_path)): 56 | if entry.name.endswith('.SC2Replay'): 57 | yield entry.path 58 | 59 | 60 | def load_replay_meta(replay_path: Text) -> List[ReplayMeta]: 61 | replay_info = get_replay_info(replay_path) 62 | return [ReplayMeta(player['PlayerID'], replay_info, replay_path) for player in replay_info['Players']] 63 | 64 | 65 | @gin.register 66 | class FilterReplay(FilterFn): 67 | def __init__(self, 68 | min_duration: float = 0., 69 | min_mmr: int = 0, 70 | min_apm: int = 0, 71 | observed_player_races: AbstractSet[Race] = frozenset((Race.protoss, Race.terran, Race.zerg)), 72 | opponent_player_races: AbstractSet[Race] = frozenset((Race.protoss, Race.terran, Race.zerg)), 73 | wins_only: bool = False, 74 | map_names: Optional[Set[str]] = None) -> None: 75 | super().__init__() 76 | self.min_duration = min_duration 77 | self.min_mmr = min_mmr 78 | self.min_apm = min_apm 79 | self.observed_player_races = observed_player_races 80 | self.opponent_player_races = opponent_player_races 81 
| self.wins_only = wins_only 82 | self.map_names = map_names 83 | 84 | def __call__(self, replay_meta: ReplayMeta, **kwargs) -> bool: 85 | if not FLAGS.is_parsed(): 86 | FLAGS(sys.argv) 87 | observed_player_info = next( 88 | filter(lambda p: p['PlayerID'] == replay_meta.observed_player_id, replay_meta.replay_info['Players'])) 89 | if len(replay_meta.replay_info['Players']) > 1: 90 | opponent_player_info = next( 91 | filter(lambda p: p['PlayerID'] != replay_meta.observed_player_id, replay_meta.replay_info['Players'])) 92 | else: 93 | opponent_player_info = None 94 | sc2_maps = SC2Maps(run_configs.get().data_dir) 95 | return (replay_meta.replay_info['Duration'] >= self.min_duration 96 | and observed_player_info.get('MMR', 0) >= self.min_mmr 97 | and observed_player_info['APM'] >= self.min_apm 98 | and SC2REPLAY_RACES[observed_player_info['AssignedRace']] in self.observed_player_races 99 | and (opponent_player_info is None or 100 | SC2REPLAY_RACES[opponent_player_info['AssignedRace']] in self.opponent_player_races) 101 | and (not self.wins_only or observed_player_info['Result'] == 'Win') 102 | and (self.map_names is None or sc2_maps.normalize_map_name(replay_meta.replay_info['Title']) in self.map_names)) 103 | 104 | 105 | @gin.register 106 | class ProcessReplay(MapFn): 107 | def __init__(self, 108 | interface_config: SC2InterfaceConfig = gin.REQUIRED, 109 | action_space: SC2ActionSpace = gin.REQUIRED, 110 | observation_space: SC2ObservationSpace = gin.REQUIRED, 111 | sc2_version: str = gin.REQUIRED) -> None: 112 | super().__init__() 113 | self.interface_config = interface_config 114 | self.action_space = action_space 115 | self.observation_space = observation_space 116 | self.sc2_version = sc2_version 117 | 118 | @retry(max_tries=3) 119 | def __call__(self, replay_meta: ReplayMeta, **kwargs) -> Replay: 120 | if not FLAGS.is_parsed(): 121 | FLAGS(sys.argv) 122 | 123 | def _valid_or_fallback_action(o: dict, a: Dict): 124 | if o['scalar_features']['available_actions'][a['action_type']] == 0: 125 | return self.action_space.no_op() # action_type not available 126 | elif 'build_queue_length' in o['scalar_features'] and \ 127 | o['scalar_features']['build_queue_length'] <= a['build_queue_id']: 128 | return self.action_space.no_op() # build_queue_id not available 129 | elif 'multi_select_length' in o['scalar_features'] and \ 130 | o['scalar_features']['multi_select_length'] <= a['select_unit_id']: 131 | return self.action_space.no_op() # select_unit_id not available 132 | elif 'cargo_length' in o['scalar_features'] and \ 133 | o['scalar_features']['cargo_length'] <= a['unload_id']: 134 | return self.action_space.no_op() # unload_id not available 135 | else: 136 | return a 137 | 138 | with ReplayProcessor( 139 | replay_path=replay_meta.replay_path, 140 | interface_config=self.interface_config, 141 | observation_space=self.observation_space, 142 | action_space=self.action_space, 143 | observed_player_id=replay_meta.observed_player_id, 144 | version=self.sc2_version) as replay_processor: 145 | sampled_replay: List[ActionTimeStep] = [] 146 | reward = 0. 
147 | for curr_ts, curr_act in replay_processor.iterator(): 148 | action = _valid_or_fallback_action(curr_ts.observation, curr_act) 149 | reward += curr_ts.reward 150 | if ( # add timestep to replay if: 151 | len(sampled_replay) == 0 # a) it is the first timestep of an episode, 152 | or curr_ts.step_type == StepType.LAST # b) it is the last timestep of an episode, 153 | or action['action_type'] != 0 # c) an action other than noop is executed or 154 | or sampled_replay[-1].action['step_mul'] == self.interface_config.max_step_mul - 1 # d) max_step_mul is reached 155 | ): 156 | sampled_replay.append(ActionTimeStep(observation=curr_ts.observation, action=action, reward=reward, 157 | done=len(sampled_replay) == 0)) 158 | reward = 0. 159 | else: # if timestep is skipped, increment step_mul of most recent action 160 | sampled_replay[-1].action['step_mul'] += 1 161 | 162 | return Replay(time_steps=sampled_replay, replay_meta=replay_meta) 163 | 164 | 165 | @gin.register 166 | class StoreReplay(MapFn): 167 | 168 | def __init__(self, 169 | dataset_path: str, 170 | action_space: ActionSpace = gin.REQUIRED, 171 | observation_space: ObservationSpace = gin.REQUIRED) -> None: 172 | super().__init__() 173 | self.dataset_path = dataset_path 174 | self.action_space = action_space 175 | self.observation_space = observation_space 176 | 177 | def __call__(self, replay: Replay, **kwargs) -> str: 178 | replay_name = os.path.splitext(os.path.basename(replay.replay_meta.replay_path))[0] 179 | replay_name = f"{replay_name}_{replay.replay_meta.observed_player_id}" 180 | specs = get_dataset_specs(self.action_space, self.observation_space) 181 | file_name = store_episode_to_hdf5( 182 | path=self.dataset_path, 183 | name=replay_name, 184 | episode=replay.time_steps, 185 | episode_info={ 186 | 'observed_player_id': replay.replay_meta.observed_player_id, 187 | 'replay_path': replay.replay_meta.replay_path, 188 | 'replay_info': replay.replay_meta.replay_info 189 | }, 190 | specs=specs) 191 | return file_name 192 | 193 | 194 | def patch_iterable_queue(): 195 | """ Patches __getstate__ and __setstate__ of IterableQueues such that namespace and exception_queue attributes get 196 | exported/restored. See PR: https://github.com/cgarciae/pypeln/pull/74 """ 197 | orig_getstate = IterableQueue.__getstate__ 198 | orig_setstate = IterableQueue.__setstate__ 199 | 200 | def new_getstate(self): 201 | return orig_getstate(self) + (self.namespace, self.exception_queue) 202 | 203 | def new_setstate(self, state): 204 | orig_setstate(self, state[:-2]) 205 | self.namespace, self.exception_queue = state[-2:] 206 | 207 | IterableQueue.__getstate__ = new_getstate 208 | IterableQueue.__setstate__ = new_setstate 209 | 210 | logger.info("Pickle patch for IterableQueue applied.") 211 | 212 | 213 | patch_iterable_queue() 214 | 215 | 216 | def main(_): 217 | gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) 218 | 219 | os.makedirs(FLAGS.dataset_path, exist_ok=True) 220 | assert len(os.listdir(FLAGS.dataset_path)) == 0, f'dataset_path directory ({FLAGS.dataset_path}) must be empty.' 
221 | 222 | gin_config_str = gin.config_str(max_line_length=120) 223 | 224 | print("Loaded configuration:") 225 | print(gin_config_str) 226 | 227 | with open(os.path.join(FLAGS.dataset_path, 'config.gin'), mode='w') as f: 228 | f.write(gin_config_str) 229 | 230 | filter_replay = gin.get_configurable(FilterReplay)() 231 | process_replay = gin.get_configurable(ProcessReplay)() 232 | store_replay = gin.get_configurable(StoreReplay)(dataset_path=FLAGS.dataset_path) 233 | 234 | dataset_files = [] 235 | for dataset_file in tqdm( 236 | find_replays(FLAGS.replays_path) 237 | | pl.process.flat_map(load_replay_meta, workers=FLAGS.num_workers, maxsize=0) 238 | | pl.process.filter(filter_replay, workers=1, maxsize=0) 239 | | pl.process.map(process_replay, workers=FLAGS.num_workers, maxsize=FLAGS.num_workers) 240 | | pl.process.filter(is_not_none, workers=1, maxsize=0) 241 | | pl.process.map(store_replay, workers=FLAGS.num_workers, maxsize=0) 242 | ): 243 | dataset_files.append(dataset_file) 244 | 245 | 246 | if __name__ == '__main__': 247 | app.run(main) 248 | -------------------------------------------------------------------------------- /configs/mini_games/agents/sc2_feature_layer_agent.gin: -------------------------------------------------------------------------------- 1 | import sc2_imitation_learning.agents.sc2_feature_layer_agent 2 | import sc2_imitation_learning.common.conv 3 | import sc2_imitation_learning.common.mlp 4 | import sc2_imitation_learning.common.transformer 5 | 6 | 7 | # SC2FeatureLayerAgent 8 | # ============================================================================== 9 | 10 | sc2_feature_layer_agent.SC2FeatureLayerAgent.scalar_encoder = @scalar_encoder/ConcatScalarEncoder() 11 | sc2_feature_layer_agent.SC2FeatureLayerAgent.screen_encoder = @screen_encoder/ImpalaCNNSpatialEncoder() 12 | sc2_feature_layer_agent.SC2FeatureLayerAgent.minimap_encoder = @minimap_encoder/ImpalaCNNSpatialEncoder() 13 | sc2_feature_layer_agent.SC2FeatureLayerAgent.core = @core/tfa.rnn.LayerNormLSTMCell() 14 | sc2_feature_layer_agent.SC2FeatureLayerAgent.autoregressive_embed_dim = 32 15 | sc2_feature_layer_agent.SC2FeatureLayerAgent.policy_heads = [ 16 | @action_type/AutoregressivePolicyHead(), 17 | @step_mul/AutoregressivePolicyHead(), 18 | @queued/AutoregressivePolicyHead(), 19 | @control_group_act/AutoregressivePolicyHead(), 20 | @control_group_id/AutoregressivePolicyHead(), 21 | @select_point_act/AutoregressivePolicyHead(), 22 | @select_add/AutoregressivePolicyHead(), 23 | @select_unit_act/AutoregressivePolicyHead(), 24 | @select_unit_id/AutoregressivePolicyHead(), 25 | @select_worker/AutoregressivePolicyHead(), 26 | @build_queue_id/AutoregressivePolicyHead(), 27 | @unload_id/AutoregressivePolicyHead(), 28 | @screen/AutoregressivePolicyHead(), 29 | @screen2/AutoregressivePolicyHead(), 30 | @minimap/AutoregressivePolicyHead(), 31 | ] 32 | 33 | # ScalarEncoder 34 | # ============================================================================== 35 | 36 | scalar_encoder/ConcatScalarEncoder.action_space = %ACTION_SPACE 37 | scalar_encoder/ConcatScalarEncoder.feature_encoders = { 38 | 'player': @PlayerEncoder, 39 | 'available_actions': @AvailableActionsEncoder, 40 | } 41 | scalar_encoder/ConcatScalarEncoder.prev_action_encoders = { 42 | 'action_type': @action_type/ActionEncoder, 43 | 'step_mul': @step_mul/ActionEncoder 44 | } 45 | scalar_encoder/ConcatScalarEncoder.context_feature_names = ['available_actions'] 46 | 47 | PlayerEncoder.embedding_size = 64 48 | AvailableActionsEncoder.embedding_size = 
64 49 | action_type/ActionEncoder.embedding_size = 128 50 | step_mul/ActionEncoder.embedding_size = 32 51 | 52 | 53 | # Screen Encoder 54 | # ============================================================================== 55 | 56 | screen_encoder/ImpalaCNNSpatialEncoder.feature_layer_encoders = { 57 | 'visibility_map': @screen_visibility_map/OneHotEncoder, 58 | 'player_relative': @screen_player_relative/OneHotEncoder, 59 | 'unit_type': @screen_unit_type/UnitTypeEncoder, 60 | 'selected': @screen_selected/OneHotEncoder, 61 | 'unit_hit_points_ratio': @screen_unit_hit_points_ratio/ScaleEncoder, 62 | 'unit_energy_ratio': @screen_unit_energy_ratio/ScaleEncoder, 63 | 'unit_density_aa': @screen_unit_density_aa/ScaleEncoder 64 | } 65 | screen_encoder/ImpalaCNNSpatialEncoder.input_projection_dim = 16 66 | screen_encoder/ImpalaCNNSpatialEncoder.num_blocks = [2, 2] 67 | screen_encoder/ImpalaCNNSpatialEncoder.output_channels = [16, 32] 68 | screen_encoder/ImpalaCNNSpatialEncoder.max_pool_padding = 'SAME' 69 | screen_encoder/ImpalaCNNSpatialEncoder.spatial_embedding_size = 256 70 | 71 | screen_visibility_map/OneHotEncoder.depth = 4 72 | screen_player_relative/OneHotEncoder.depth = 5 73 | screen_selected/OneHotEncoder.depth = 2 74 | screen_unit_hit_points_ratio/ScaleEncoder.factor = 0.003921569 # 1/255 75 | screen_unit_energy_ratio/ScaleEncoder.factor = 0.003921569 # 1/255 76 | screen_unit_density_aa/ScaleEncoder.factor = 0.003921569 # 1/255 77 | 78 | screen_unit_type/UnitTypeEncoder.embed_dim = 10 79 | screen_unit_type/UnitTypeEncoder.max_unit_count = 512 80 | screen_unit_type/UnitTypeEncoder.encoder = @screen_unit_type_encoder/ConvNet1D() 81 | screen_unit_type_encoder/ConvNet1D.output_channels = [10] 82 | screen_unit_type_encoder/ConvNet1D.kernel_shapes = [1] 83 | screen_unit_type_encoder/ConvNet1D.strides = [1] 84 | screen_unit_type_encoder/ConvNet1D.paddings = ['SAME'] 85 | screen_unit_type_encoder/ConvNet1D.activate_final = True 86 | 87 | 88 | # Minimap Encoder 89 | # ============================================================================== 90 | 91 | minimap_encoder/ImpalaCNNSpatialEncoder.feature_layer_encoders = { 92 | 'camera': @minimap_camera/OneHotEncoder, 93 | 'player_relative': @minimap_player_relative/OneHotEncoder, 94 | 'alerts': @minimap_alerts/OneHotEncoder, 95 | } 96 | minimap_encoder/ImpalaCNNSpatialEncoder.input_projection_dim = 16 97 | minimap_encoder/ImpalaCNNSpatialEncoder.num_blocks = [2, 2] 98 | minimap_encoder/ImpalaCNNSpatialEncoder.output_channels = [16, 32] 99 | minimap_encoder/ImpalaCNNSpatialEncoder.max_pool_padding = 'SAME' 100 | minimap_encoder/ImpalaCNNSpatialEncoder.spatial_embedding_size = 256 101 | 102 | minimap_camera/OneHotEncoder.depth = 2 103 | minimap_player_relative/OneHotEncoder.depth = 5 104 | minimap_alerts/OneHotEncoder.depth = 2 105 | 106 | 107 | # LSTM core 108 | # ============================================================================== 109 | 110 | core/tfa.rnn.LayerNormLSTMCell.units = 1024 111 | 112 | 113 | # Policy heads 114 | # ============================================================================== 115 | 116 | # Policy heads: general setup 117 | # ------------------------------------------------------------------------------ 118 | 119 | AutoregressivePolicyHead.action_space = %ACTION_SPACE 120 | 121 | ActionEmbedding.output_sizes = [32] 122 | ActionEmbedding.with_layer_norm = False 123 | 124 | ActionTypePolicyHead.decoder = @action_type_policy_head/MLP() 125 | action_type_policy_head/MLP.output_sizes = [256, 256] 126 | 
action_type_policy_head/MLP.with_layer_norm = False 127 | action_type_policy_head/MLP.activate_final = True 128 | 129 | ScalarPolicyHead.decoder = @scalar_policy_head/MLP() 130 | scalar_policy_head/MLP.output_sizes = [256] 131 | scalar_policy_head/MLP.with_layer_norm = False 132 | scalar_policy_head/MLP.activate_final = True 133 | 134 | 135 | ResSpatialDecoder.out_channels = 32 136 | ResSpatialDecoder.num_blocks = 6 137 | 138 | SpatialPolicyHead.upsample_conv_net = @spatial_policy/ConvNet2DTranspose() 139 | spatial_policy/ConvNet2DTranspose.output_channels = [32, 16] 140 | spatial_policy/ConvNet2DTranspose.kernel_shapes = [4, 4] 141 | spatial_policy/ConvNet2DTranspose.strides = [2, 2] 142 | spatial_policy/ConvNet2DTranspose.paddings = ['SAME', 'SAME'] 143 | spatial_policy/ConvNet2DTranspose.activate_final = True 144 | 145 | 146 | # Policy heads: head specific setup 147 | # ------------------------------------------------------------------------------ 148 | 149 | action_type/AutoregressivePolicyHead.action_name = 'action_type' 150 | action_type/AutoregressivePolicyHead.policy_head = @action_type/ActionTypePolicyHead 151 | action_type/AutoregressivePolicyHead.action_embed = @action_type/ActionEmbedding 152 | action_type/AutoregressivePolicyHead.action_mask = None 153 | 154 | step_mul/AutoregressivePolicyHead.action_name = 'step_mul' 155 | step_mul/AutoregressivePolicyHead.policy_head = @step_mul/ScalarPolicyHead 156 | step_mul/AutoregressivePolicyHead.action_embed = @step_mul/ActionEmbedding 157 | step_mul/AutoregressivePolicyHead.action_mask = None 158 | 159 | queued/AutoregressivePolicyHead.action_name = 'queued' 160 | queued/AutoregressivePolicyHead.policy_head = @queued/ScalarPolicyHead 161 | queued/AutoregressivePolicyHead.action_embed = @queued/ActionEmbedding 162 | queued/AutoregressivePolicyHead.action_mask = @queued/ActionArgumentMask 163 | 164 | control_group_act/AutoregressivePolicyHead.action_name = 'control_group_act' 165 | control_group_act/AutoregressivePolicyHead.policy_head = @control_group_act/ScalarPolicyHead 166 | control_group_act/AutoregressivePolicyHead.action_embed = @control_group_act/ActionEmbedding 167 | control_group_act/AutoregressivePolicyHead.action_mask = @control_group_act/ActionArgumentMask 168 | 169 | control_group_id/AutoregressivePolicyHead.action_name = 'control_group_id' 170 | control_group_id/AutoregressivePolicyHead.policy_head = @control_group_id/ScalarPolicyHead 171 | control_group_id/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 172 | control_group_id/AutoregressivePolicyHead.action_mask = @control_group_id/ActionArgumentMask 173 | 174 | select_point_act/AutoregressivePolicyHead.action_name = 'select_point_act' 175 | select_point_act/AutoregressivePolicyHead.policy_head = @select_point_act/ScalarPolicyHead 176 | select_point_act/AutoregressivePolicyHead.action_embed = @select_point_act/ActionEmbedding 177 | select_point_act/AutoregressivePolicyHead.action_mask = @select_point_act/ActionArgumentMask 178 | 179 | select_add/AutoregressivePolicyHead.action_name = 'select_add' 180 | select_add/AutoregressivePolicyHead.policy_head = @select_add/ScalarPolicyHead 181 | select_add/AutoregressivePolicyHead.action_embed = @select_add/ActionEmbedding 182 | select_add/AutoregressivePolicyHead.action_mask = @select_add/ActionArgumentMask 183 | 184 | select_point_act/AutoregressivePolicyHead.action_name = 'select_point_act' 185 | select_point_act/AutoregressivePolicyHead.policy_head = @select_point_act/ScalarPolicyHead 186 | 
select_point_act/AutoregressivePolicyHead.action_embed = @select_point_act/ActionEmbedding 187 | select_point_act/AutoregressivePolicyHead.action_mask = @select_point_act/ActionArgumentMask 188 | 189 | select_add/AutoregressivePolicyHead.action_name = 'select_add' 190 | select_add/AutoregressivePolicyHead.policy_head = @select_add/ScalarPolicyHead 191 | select_add/AutoregressivePolicyHead.action_embed = @select_add/ActionEmbedding 192 | select_add/AutoregressivePolicyHead.action_mask = @select_add/ActionArgumentMask 193 | 194 | select_unit_act/AutoregressivePolicyHead.action_name = 'select_unit_act' 195 | select_unit_act/AutoregressivePolicyHead.policy_head = @select_unit_act/ScalarPolicyHead 196 | select_unit_act/AutoregressivePolicyHead.action_embed = @select_unit_act/ActionEmbedding 197 | select_unit_act/AutoregressivePolicyHead.action_mask = @select_unit_act/ActionArgumentMask 198 | 199 | select_unit_id/AutoregressivePolicyHead.action_name = 'select_unit_id' 200 | select_unit_id/AutoregressivePolicyHead.policy_head = @select_unit_id/ScalarPolicyHead 201 | select_unit_id/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 202 | select_unit_id/AutoregressivePolicyHead.action_mask = @select_unit_id/ActionArgumentMask 203 | 204 | select_worker/AutoregressivePolicyHead.action_name = 'select_worker' 205 | select_worker/AutoregressivePolicyHead.policy_head = @select_worker/ScalarPolicyHead 206 | select_worker/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 207 | select_worker/AutoregressivePolicyHead.action_mask = @select_worker/ActionArgumentMask 208 | 209 | build_queue_id/AutoregressivePolicyHead.action_name = 'build_queue_id' 210 | build_queue_id/AutoregressivePolicyHead.policy_head = @build_queue_id/ScalarPolicyHead 211 | build_queue_id/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 212 | build_queue_id/AutoregressivePolicyHead.action_mask = @build_queue_id/ActionArgumentMask 213 | 214 | unload_id/AutoregressivePolicyHead.action_name = 'unload_id' 215 | unload_id/AutoregressivePolicyHead.policy_head = @unload_id/ScalarPolicyHead 216 | unload_id/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 217 | unload_id/AutoregressivePolicyHead.action_mask = @unload_id/ActionArgumentMask 218 | 219 | screen/AutoregressivePolicyHead.action_name = 'screen' 220 | screen/AutoregressivePolicyHead.policy_head = @screen/SpatialPolicyHead 221 | screen/AutoregressivePolicyHead.action_embed = @screen/ActionEmbedding 222 | screen/AutoregressivePolicyHead.action_mask = @screen/ActionArgumentMask 223 | screen/SpatialPolicyHead.decoder = @ResSpatialDecoder() 224 | screen/SpatialPolicyHead.map_skip = 'screen' 225 | 226 | screen2/AutoregressivePolicyHead.action_name = 'screen2' 227 | screen2/AutoregressivePolicyHead.policy_head = @screen2/SpatialPolicyHead 228 | screen2/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 229 | screen2/AutoregressivePolicyHead.action_mask = @screen2/ActionArgumentMask 230 | screen2/SpatialPolicyHead.decoder = @ResSpatialDecoder() 231 | screen2/SpatialPolicyHead.map_skip = 'screen' 232 | 233 | minimap/AutoregressivePolicyHead.action_name = 'minimap' 234 | minimap/AutoregressivePolicyHead.policy_head = @minimap/SpatialPolicyHead 235 | minimap/AutoregressivePolicyHead.action_embed = None # final argument, no need to embed 236 | minimap/AutoregressivePolicyHead.action_mask = @minimap/ActionArgumentMask 237 | 
minimap/SpatialPolicyHead.decoder = @ResSpatialDecoder() 238 | minimap/SpatialPolicyHead.map_skip = 'minimap' --------------------------------------------------------------------------------