├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── smac-official.png └── smac.md ├── pyproject.toml ├── setup.py └── smac ├── __init__.py ├── bin ├── __init__.py └── map_list.py ├── env ├── __init__.py ├── multiagentenv.py ├── pettingzoo │ ├── StarCraft2PZEnv.py │ ├── __init__.py │ └── test │ │ ├── __init__.py │ │ ├── all_test.py │ │ └── smac_pettingzoo_test.py └── starcraft2 │ ├── __init__.py │ ├── maps │ ├── SMAC_Maps │ │ ├── 10m_vs_11m.SC2Map │ │ ├── 1c3s5z.SC2Map │ │ ├── 25m.SC2Map │ │ ├── 27m_vs_30m.SC2Map │ │ ├── 2c_vs_64zg.SC2Map │ │ ├── 2m_vs_1z.SC2Map │ │ ├── 2s3z.SC2Map │ │ ├── 2s_vs_1sc.SC2Map │ │ ├── 3m.SC2Map │ │ ├── 3s5z.SC2Map │ │ ├── 3s5z_vs_3s6z.SC2Map │ │ ├── 3s_vs_3z.SC2Map │ │ ├── 3s_vs_4z.SC2Map │ │ ├── 3s_vs_5z.SC2Map │ │ ├── 5m_vs_6m.SC2Map │ │ ├── 6h_vs_8z.SC2Map │ │ ├── 8m.SC2Map │ │ ├── 8m_vs_9m.SC2Map │ │ ├── MMM.SC2Map │ │ ├── MMM2.SC2Map │ │ ├── bane_vs_bane.SC2Map │ │ ├── corridor.SC2Map │ │ └── so_many_baneling.SC2Map │ ├── __init__.py │ └── smac_maps.py │ ├── render.py │ └── starcraft2.py └── examples ├── __init__.py ├── pettingzoo ├── README.rst ├── __init__.py └── pettingzoo_demo.py ├── random_agents.py └── rllib ├── README.rst ├── __init__.py ├── env.py ├── model.py ├── run_ppo.py └── run_qmix.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | # env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: '3.8.4' 9 | hooks: 10 | - id: flake8 11 | additional_dependencies: [flake8-bugbear] 12 | args: ["--show-source"] 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 whirl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | > __Note__ 6 | > SMACv2 is out! Check it out [here](https://github.com/oxwhirl/smacv2). 7 | 8 | > __Warning__ 9 | > **Please pay attention to the version of SC2 used for your experiments.** Performance is **not** always comparable between versions. The results in the [SMAC paper](https://arxiv.org/abs/1902.04043) use `SC2.4.6.2.69232` not `SC2.4.10`. 10 | 11 | # SMAC - StarCraft Multi-Agent Challenge 12 | 13 | [SMAC](https://github.com/oxwhirl/smac) is [WhiRL](http://whirl.cs.ox.ac.uk)'s environment for research in the field of cooperative multi-agent reinforcement learning (MARL) based on [Blizzard](http://blizzard.com)'s [StarCraft II](https://en.wikipedia.org/wiki/StarCraft_II:_Wings_of_Liberty) RTS game. SMAC makes use of Blizzard's [StarCraft II Machine Learning API](https://github.com/Blizzard/s2client-proto) and [DeepMind](https://deepmind.com)'s [PySC2](https://github.com/deepmind/pysc2) to provide a convenient interface for autonomous agents to interact with StarCraft II, getting observations and performing actions. Unlike [PySC2](https://github.com/deepmind/pysc2), SMAC concentrates on *decentralised micromanagement* scenarios, where each unit of the game is controlled by an individual RL agent. 14 | 15 | Please refer to the accompanying [paper](https://arxiv.org/abs/1902.04043) and [blogpost](https://blog.ucldark.com/2019/02/12/smac.html) for the outline of our motivation for using SMAC as a testbed for MARL research and the initial experimental results. 16 | 17 | ## About 18 | 19 | Together with SMAC we also release [PyMARL](https://github.com/oxwhirl/pymarl) - our [PyTorch](https://github.com/pytorch/pytorch) framework for MARL research, which includes implementations of several state-of-the-art algorithms, such as [QMIX](https://arxiv.org/abs/1803.11485) and [COMA](https://arxiv.org/abs/1705.08926). 20 | 21 | Data from the runs used in the paper is included [here](https://github.com/oxwhirl/smac/releases/download/v1/smac_run_data.json). **These runs are outdated due to subsequent changes in StarCraft II. If you ran your experiments using the current version of SMAC, you should not compare your results with the ones provided here.** 22 | 23 | # Quick Start 24 | 25 | ## Installing SMAC 26 | 27 | You can install SMAC by using the following command: 28 | 29 | ```shell 30 | pip install git+https://github.com/oxwhirl/smac.git 31 | ``` 32 | 33 | Alternatively, you can clone the SMAC repository and then install `smac` with its dependencies: 34 | 35 | ```shell 36 | git clone https://github.com/oxwhirl/smac.git 37 | pip install -e smac/ 38 | ``` 39 | 40 | *NOTE*: If you want to extend SMAC, please install the package as follows: 41 | 42 | ```shell 43 | git clone https://github.com/oxwhirl/smac.git 44 | cd smac 45 | pip install -e ".[dev]" 46 | pre-commit install 47 | ``` 48 | 49 | You may also need to upgrade pip (`pip install --upgrade pip`) for the install to work. 50 | 51 | ## Installing StarCraft II 52 | 53 | SMAC is based on the full game of StarCraft II (versions >= 3.16.1). To install the game, follow the commands below. 54 | 55 | ### Linux 56 | 57 | Please use Blizzard's [repository](https://github.com/Blizzard/s2client-proto#downloads) to download the Linux version of StarCraft II. By default, the game is expected to be in the `~/StarCraftII/` directory. This can be changed by setting the environment variable `SC2PATH`. 58 | 59 | ### MacOS/Windows 60 | 61 | Please install StarCraft II from [Battle.net](https://battle.net).
The free [Starter Edition](http://battle.net/sc2/en/legacy-of-the-void/) also works. PySC2 will find the latest binary should you use the default install location. Otherwise, similar to the Linux version, you would need to set the `SC2PATH` environment variable to the correct location of the game. 62 | 63 | ## SMAC maps 64 | 65 | SMAC is composed of many combat scenarios with pre-configured maps. Before SMAC can be used, these maps need to be downloaded into the `Maps` directory of StarCraft II. 66 | 67 | Download the [SMAC Maps](https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip) and extract them to your `$SC2PATH/Maps` directory. If you installed SMAC via git, simply copy the `SMAC_Maps` directory from `smac/env/starcraft2/maps/` into the `$SC2PATH/Maps` directory. 68 | 69 | ### List the maps 70 | 71 | To see the list of SMAC maps, together with the number of ally and enemy units and the episode limit, run: 72 | 73 | ```shell 74 | python -m smac.bin.map_list 75 | ``` 76 | 77 | ### Creating new maps 78 | 79 | Users can extend SMAC by adding new maps/scenarios. To this end, one needs to: 80 | 81 | - Design a new map/scenario using the StarCraft II Editor: 82 | - Please take a close look at the existing maps to understand the basics that we use (e.g. Triggers, Units, etc.), 83 | - We make use of special RL units which never automatically start attacking the enemy. [Here](https://docs.google.com/document/d/1BfAM_AtZWBRhUiOBcMkb_uK4DAZW3CpvO79-vnEOKxA/edit?usp=sharing) is the step-by-step guide on how to create new RL units based on existing SC2 units, 84 | - Add the map information in [smac_maps.py](https://github.com/oxwhirl/smac/blob/master/smac/env/starcraft2/maps/smac_maps.py), 85 | - The newly designed RL units have new ids which need to be handled in [starcraft2.py](https://github.com/oxwhirl/smac/blob/master/smac/env/starcraft2/starcraft2.py). Specifically, for heterogeneous maps containing more than one unit type, one needs to manually set the unit ids in the `_init_ally_unit_types()` function. 86 | 87 | ## Testing SMAC 88 | 89 | Please run the following command to make sure that `smac` and its maps are properly installed. 90 | 91 | ```bash 92 | python -m smac.examples.random_agents 93 | ``` 94 | 95 | ## Saving and Watching StarCraft II Replays 96 | 97 | ### Saving a replay 98 | 99 | If you’re using our [PyMARL](https://github.com/oxwhirl/pymarl) framework for multi-agent RL, here’s what needs to be done: 100 | 1. **Saving models**: We run experiments on *Linux* servers with the `save_model = True` setting (`save_model_interval` is also relevant) so that training checkpoints (the parameters of the neural networks) are saved (click [here](https://github.com/oxwhirl/pymarl#saving-and-loading-learnt-models) for more details). 101 | 2. **Loading models**: Learnt models can be loaded using the `checkpoint_path` parameter. If you run PyMARL on *MacOS* (or *Windows*) while also setting `save_replay=True`, this will save a .SC2Replay file for `test_nepisode` episodes in test mode (no exploration) in the Replay directory of StarCraft II (click [here](https://github.com/oxwhirl/pymarl#watching-starcraft-ii-replays) for more details). 102 | 103 | If you want to save replays without using PyMARL, simply call the `save_replay()` function of SMAC's StarCraft2Env in your training/testing code, as in the sketch below. This will save a replay of all episodes since the launch of the StarCraft II client. 104 | 105 | The easiest way to save and later watch a replay on Linux is to use [Wine](https://www.winehq.org/).
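The snippet below is a minimal sketch of saving a replay without PyMARL. It follows the same pattern as `smac.examples.random_agents`; the `8m` map and the `replay_dir`/`replay_prefix` values are illustrative assumptions rather than required settings.

```python
from smac.env import StarCraft2Env
import numpy as np

# Illustrative sketch: run one episode with random valid actions, then ask
# the StarCraft II client to write a .SC2Replay file.
env = StarCraft2Env(map_name="8m", replay_dir="replays", replay_prefix="smac_demo")
env.reset()

terminated = False
while not terminated:
    actions = []
    for agent_id in range(env.n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id)
        actions.append(np.random.choice(np.nonzero(avail_actions)[0]))
    _, terminated, _ = env.step(actions)

env.save_replay()  # covers all episodes run since the client was launched
env.close()
```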
106 | 107 | ### Watching a replay 108 | 109 | You can watch the saved replay directly within the StarCraft II client on MacOS/Windows by *clicking on the corresponding Replay file*. 110 | 111 | You can also watch saved replays by running: 112 | 113 | ```shell 114 | python -m pysc2.bin.play --norender --replay <path-to-replay> 115 | ``` 116 | 117 | This works for any replay as long as the map can be found by the game. 118 | 119 | For more information, please refer to the [PySC2](https://github.com/deepmind/pysc2) documentation. 120 | 121 | # Documentation 122 | 123 | For the detailed description of the environment, read the [SMAC documentation](docs/smac.md). 124 | 125 | The initial results of our experiments using SMAC can be found in the [accompanying paper](https://arxiv.org/abs/1902.04043). 126 | 127 | # Citing SMAC 128 | 129 | If you use SMAC in your research, please cite the [SMAC paper](https://arxiv.org/abs/1902.04043). 130 | 131 | *M. Samvelyan, T. Rashid, C. Schroeder de Witt, G. Farquhar, N. Nardelli, T.G.J. Rudner, C.-M. Hung, P.H.S. Torr, J. Foerster, S. Whiteson. The StarCraft Multi-Agent Challenge, CoRR abs/1902.04043, 2019.* 132 | 133 | In BibTeX format: 134 | 135 | ```tex 136 | @article{samvelyan19smac, 137 | title = {{The} {StarCraft} {Multi}-{Agent} {Challenge}}, 138 | author = {Mikayel Samvelyan and Tabish Rashid and Christian Schroeder de Witt and Gregory Farquhar and Nantas Nardelli and Tim G. J. Rudner and Chia-Man Hung and Philip H. S. Torr and Jakob Foerster and Shimon Whiteson}, 139 | journal = {CoRR}, 140 | volume = {abs/1902.04043}, 141 | year = {2019}, 142 | } 143 | ``` 144 | 145 | # Code Examples 146 | 147 | Below is a small code example which illustrates how SMAC can be used. Here, individual agents execute random policies after receiving the observations and global state from the environment. 148 | 149 | If you want to try the state-of-the-art algorithms (such as [QMIX](https://arxiv.org/abs/1803.11485) and [COMA](https://arxiv.org/abs/1705.08926)) on SMAC, make use of [PyMARL](https://github.com/oxwhirl/pymarl) - our framework for MARL research. 150 | 151 | ```python 152 | from smac.env import StarCraft2Env 153 | import numpy as np 154 | 155 | 156 | def main(): 157 | env = StarCraft2Env(map_name="8m") 158 | env_info = env.get_env_info() 159 | 160 | n_actions = env_info["n_actions"] 161 | n_agents = env_info["n_agents"] 162 | 163 | n_episodes = 10 164 | 165 | for e in range(n_episodes): 166 | env.reset() 167 | terminated = False 168 | episode_reward = 0 169 | 170 | while not terminated: 171 | obs = env.get_obs() 172 | state = env.get_state() 173 | # env.render() # Uncomment for rendering 174 | 175 | actions = [] 176 | for agent_id in range(n_agents): 177 | avail_actions = env.get_avail_agent_actions(agent_id) 178 | avail_actions_ind = np.nonzero(avail_actions)[0] 179 | action = np.random.choice(avail_actions_ind) 180 | actions.append(action) 181 | 182 | reward, terminated, _ = env.step(actions) 183 | episode_reward += reward 184 | 185 | print("Total reward in episode {} = {}".format(e, episode_reward)) 186 | 187 | env.close() 188 | 189 | ``` 190 | 191 | ## RLlib Examples 192 | 193 | You can also run SMAC environments in [RLlib](https://rllib.io), which includes scalable algorithms such as [PPO](https://ray.readthedocs.io/en/latest/rllib-algorithms.html#proximal-policy-optimization-ppo) and [IMPALA](https://ray.readthedocs.io/en/latest/rllib-algorithms.html#importance-weighted-actor-learner-architecture-impala).
Check out the example code [here](https://github.com/oxwhirl/smac/tree/master/smac/examples/rllib). 194 | 195 | ## PettingZoo Example 196 | 197 | Thanks to [Rodrigo de Lazcano](https://github.com/rodrigodelazcano), SMAC now supports [PettingZoo API](https://github.com/PettingZoo-Team/PettingZoo) and PyGame environment rendering. Check out the example code [here](https://github.com/oxwhirl/smac/tree/master/smac/examples/pettingzoo). 198 | 199 | -------------------------------------------------------------------------------- /docs/smac-official.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/docs/smac-official.png -------------------------------------------------------------------------------- /docs/smac.md: -------------------------------------------------------------------------------- 1 | ## Table of Contents 2 | 3 | - [StarCraft II](#starcraft-ii) 4 | - [Micromanagement](#micromanagement) 5 | - [SMAC](#smac) 6 | - [Scenarios](#scenarios) 7 | - [State and Observations](#state-and-observations) 8 | - [Action Space](#action-space) 9 | - [Rewards](#rewards) 10 | - [Environment Settings](#environment-settings) 11 | 12 | ## StarCraft II 13 | 14 | SMAC is based on the popular real-time strategy (RTS) game [StarCraft II](http://us.battle.net/sc2/en/game/guide/whats-sc2) written by [Blizzard](http://blizzard.com/). 15 | In a regular full game of StarCraft II, one or more humans compete against each other or against a built-in game AI to gather resources, construct buildings, and build armies of units to defeat their opponents. 16 | 17 | Akin to most RTSs, StarCraft has two main gameplay components: macromanagement and micromanagement. 18 | - _Macromanagement_ (macro) refers to high-level strategic considerations, such as economy and resource management. 19 | - _Micromanagement_ (micro) refers to fine-grained control of individual units. 20 | 21 | ### Micromanagement 22 | 23 | StarCraft has been used as a research platform for AI, and more recently, RL. Typically, the game is framed as a competitive problem: an agent takes the role of a human player, making macromanagement decisions and performing micromanagement as a puppeteer that issues orders to individual units from a centralised controller. 24 | 25 | In order to build a rich multi-agent testbed, we instead focus solely on micromanagement. 26 | Micro is a vital aspect of StarCraft gameplay with a high skill ceiling, and is practiced in isolation by amateur and professional players. 27 | For SMAC, we leverage the natural multi-agent structure of micromanagement by proposing a modified version of the problem designed specifically for decentralised control. 28 | In particular, we require that each unit be controlled by an independent agent that conditions only on local observations restricted to a limited field of view centred on that unit. 29 | Groups of these agents must be trained to solve challenging combat scenarios, battling an opposing army under the centralised control of the game's built-in scripted AI. 30 | 31 | Proper micro of units during battles will maximise the damage dealt to enemy units while minimising damage received, and requires a range of skills. 32 | For example, one important technique is _focus fire_, i.e., ordering units to jointly attack and kill enemy units one after another. When focusing fire, it is important to avoid _overkill_: inflicting more damage to units than is necessary to kill them. 
33 | 34 | Other common micromanagement techniques include: assembling units into formations based on their armour types, making enemy units give chase while maintaining enough distance so that little or no damage is incurred (_kiting_), coordinating the positioning of units to attack from different directions or taking advantage of the terrain to defeat the enemy. 35 | 36 | Learning these rich cooperative behaviours under partial observability is a challenging task, which can be used to evaluate the effectiveness of multi-agent reinforcement learning (MARL) algorithms. 37 | 38 | ## SMAC 39 | 40 | SMAC uses the [StarCraft II Learning Environment](https://github.com/deepmind/pysc2) to introduce a cooperative MARL environment. 41 | 42 | ### Scenarios 43 | 44 | SMAC consists of a set of StarCraft II micro scenarios which aim to evaluate how well independent agents are able to learn coordination to solve complex tasks. 45 | These scenarios are carefully designed to necessitate the learning of one or more micromanagement techniques to defeat the enemy. 46 | Each scenario is a confrontation between two armies of units. 47 | The initial position, number, and type of units in each army varies from scenario to scenario, as does the presence or absence of elevated or impassable terrain. 48 | 49 | The first army is controlled by the learned allied agents. 50 | The second army consists of enemy units controlled by the built-in game AI, which uses carefully handcrafted non-learned heuristics. 51 | At the beginning of each episode, the game AI instructs its units to attack the allied agents using its scripted strategies. 52 | An episode ends when all units of either army have died or when a pre-specified time limit is reached (in which case the game is counted as a defeat for the allied agents). 53 | The goal for each scenario is to maximise the win rate of the learned policies, i.e., the expected ratio of games won to games played. 54 | To speed up learning, the enemy AI units are ordered to attack the agents' spawning point at the beginning of each episode. 55 | 56 | Perhaps the simplest scenarios are _symmetric_ battle scenarios. 57 | The most straightforward of these scenarios are _homogeneous_, i.e., each army is composed of only a single unit type (e.g., Marines). 58 | A winning strategy in this setting would be to focus fire, ideally without overkill. 59 | _Heterogeneous_ symmetric scenarios, in which there is more than a single unit type on each side (e.g., Stalkers and Zealots), are more difficult to solve. 60 | Such challenges are particularly interesting when some of the units are extremely effective against others (this is known as _countering_), for example, by dealing bonus damage to a particular armour type. 61 | In such a setting, allied agents must deduce this property of the game and design an intelligent strategy to protect teammates vulnerable to certain enemy attacks. 62 | 63 | SMAC also includes more challenging scenarios, for example, in which the enemy army outnumbers the allied army by one or more units. In such _asymmetric_ scenarios, it is essential to consider the health of enemy units in order to effectively target the desired opponent. 64 | 65 | Lastly, SMAC offers a set of interesting _micro-trick_ challenges that require a higher level of cooperation and a specific micromanagement trick to defeat the enemy. 66 | An example of a challenge scenario is _2m_vs_1z_ (aka Marine Double Team), where two Terran Marines need to defeat an enemy Zealot.
In this setting, the Marines must design a strategy which does not allow the Zealot to reach them, otherwise they will die almost immediately. 67 | Another example is _so_many_banelings_ where 7 allied Zealots face 32 enemy Baneling units. Banelings attack by running towards a target and exploding when they reach it, causing damage to a certain area around the target. Hence, if a large number of Banelings attack a handful of Zealots located close to each other, the Zealots will be defeated instantly. The optimal strategy, therefore, is to cooperatively spread out around the map far from each other so that the Banelings' damage is distributed as thinly as possible. 68 | The _corridor_ scenario, in which 6 friendly Zealots face 24 enemy Zerglings, requires agents to make effective use of the terrain features. Specifically, agents should collectively wall off the choke point (the narrow region of the map) to block enemy attacks from different directions. Some of the micro-trick challenges are inspired by [StarCraft Master](http://us.battle.net/sc2/en/blog/4544189/new-blizzard-custom-game-starcraft-master-3-1-2012) challenge missions released by Blizzard. 69 | 70 | The complete list of challenges is presented below. The difficulty of the game AI is set to _very difficult_ (7). Our experiments, however, suggest that this setting does not significantly impact the unit micromanagement of the built-in heuristics. 71 | 72 | | Name | Ally Units | Enemy Units | Type | 73 | | :---: | :---: | :---: | :---: | 74 | | 3m | 3 Marines | 3 Marines | homogeneous & symmetric | 75 | | 8m | 8 Marines | 8 Marines | homogeneous & symmetric | 76 | | 25m | 25 Marines | 25 Marines | homogeneous & symmetric | 77 | | 2s3z | 2 Stalkers & 3 Zealots | 2 Stalkers & 3 Zealots | heterogeneous & symmetric | 78 | | 3s5z | 3 Stalkers & 5 Zealots | 3 Stalkers & 5 Zealots | heterogeneous & symmetric | 79 | | MMM | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 2 Marauders & 7 Marines | heterogeneous & symmetric | 80 | | 5m_vs_6m | 5 Marines | 6 Marines | homogeneous & asymmetric | 81 | | 8m_vs_9m | 8 Marines | 9 Marines | homogeneous & asymmetric | 82 | | 10m_vs_11m | 10 Marines | 11 Marines | homogeneous & asymmetric | 83 | | 27m_vs_30m | 27 Marines | 30 Marines | homogeneous & asymmetric | 84 | | 3s5z_vs_3s6z | 3 Stalkers & 5 Zealots | 3 Stalkers & 6 Zealots | heterogeneous & asymmetric | 85 | | MMM2 | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 3 Marauders & 8 Marines | heterogeneous & asymmetric | 86 | | 2m_vs_1z | 2 Marines | 1 Zealot | micro-trick: alternating fire | 87 | | 2s_vs_1sc | 2 Stalkers | 1 Spine Crawler | micro-trick: alternating fire | 88 | | 3s_vs_3z | 3 Stalkers | 3 Zealots | micro-trick: kiting | 89 | | 3s_vs_4z | 3 Stalkers | 4 Zealots | micro-trick: kiting | 90 | | 3s_vs_5z | 3 Stalkers | 5 Zealots | micro-trick: kiting | 91 | | 6h_vs_8z | 6 Hydralisks | 8 Zealots | micro-trick: focus fire | 92 | | corridor | 6 Zealots | 24 Zerglings | micro-trick: wall off | 93 | | bane_vs_bane | 20 Zerglings & 4 Banelings | 20 Zerglings & 4 Banelings | micro-trick: positioning | 94 | | so_many_banelings | 7 Zealots | 32 Banelings | micro-trick: positioning | 95 | | 2c_vs_64zg | 2 Colossi | 64 Zerglings | micro-trick: positioning | 96 | | 1c3s5z | 1 Colossus & 3 Stalkers & 5 Zealots | 1 Colossus & 3 Stalkers & 5 Zealots | heterogeneous & symmetric | 97 | 98 | ### State and Observations 99 | 100 | At each timestep, agents receive local observations drawn within their field of view.
This encompasses information about the map within a circular area around each unit and with a radius equal to the _sight range_. The sight range makes the environment partially observable from the standpoint of each agent. Agents can only observe other agents if they are both alive and located within the sight range. Hence, there is no way for agents to determine whether their teammates are far away or dead. 101 | 102 | The feature vector observed by each agent contains the following attributes for both allied and enemy units within the sight range: _distance_, _relative x_, _relative y_, _health_, _shield_, and _unit\_type_ [1](#myfootnote1). Shields serve as an additional source of protection that needs to be removed before any damage can be done to the health of units. 103 | All Protoss units have shields, which can regenerate if no new damage is dealt 104 | (units of the other two races do not have this attribute). 105 | In addition, agents have access to the last actions of allied units that are in the field of view. Lastly, agents can observe the terrain features surrounding them; particularly, the values of eight points at a fixed radius indicating height and walkability. 106 | 107 | The global state, which is only available to agents during centralised training, contains information about all units on the map. Specifically, the state vector includes the coordinates of all agents relative to the centre of the map, together with unit features present in the observations. Additionally, the state stores the _energy_ of Medivacs and _cooldown_ of the rest of allied units, which represents the minimum delay between attacks. Finally, the last actions of all agents are attached to the central state. 108 | 109 | All features, both in the state as well as in the observations of individual agents, are normalised by their maximum values. The sight range is set to 9 for all agents. 110 | 111 | ### Action Space 112 | 113 | The discrete set of actions which agents are allowed to take consists of _move[direction]_ (four directions: north, south, east, or west), _attack[enemy_id]_, _stop_ and _no-op_. Dead agents can only take the _no-op_ action, while live agents cannot take it. 114 | As healer units, Medivacs must use _heal[agent\_id]_ actions instead of _attack[enemy\_id]_. The maximum number of actions an agent can take ranges between 7 and 70, depending on the scenario. 115 | 116 | To ensure decentralisation of the task, agents are restricted to using the _attack[enemy\_id]_ action only against enemies within their _shooting range_. 117 | This additionally constrains the unit's ability to use the built-in _attack-move_ macro-actions on enemies that are far away. We set the shooting range equal to 6. Having a sight range larger than the shooting range forces agents to make use of move commands before starting to fire. 118 | 119 | ### Rewards 120 | 121 | The overall goal is to have the highest win rate for each battle scenario. 122 | We provide a corresponding option for _sparse rewards_, which will cause the environment to return only a reward of +1 for winning and -1 for losing an episode. 123 | However, we also provide a default setting for a shaped reward signal calculated from the hit-point damage dealt and received by agents, some positive (negative) reward after having enemy (allied) units killed and/or a positive (negative) bonus for winning (losing) the battle.
124 | The exact values and scales of this shaped reward can be configured using a range of flags, but we strongly discourage disingenuous engineering of the reward function (e.g. tuning different reward functions for different scenarios). 125 | 126 | ### Environment Settings 127 | 128 | SMAC makes use of the [StarCraft II Learning Environment](https://arxiv.org/abs/1708.04782) (SC2LE) to communicate with the StarCraft II engine. SC2LE provides full control of the game by allowing users to send commands and receive observations from the game. However, SMAC is conceptually different from the RL environment of SC2LE. The goal of SC2LE is to learn to play the full game of StarCraft II. This is a competitive task where a centralised RL agent receives RGB pixels as input and performs both macro and micro with player-level control, similar to human players. SMAC, on the other hand, represents a set of cooperative multi-agent micro challenges where each learning agent controls a single military unit. 129 | 130 | SMAC uses the _raw API_ of SC2LE. Raw API observations do not have any graphical component and include information about the units on the map such as health, location coordinates, etc. The raw API also allows sending action commands to individual units using their unit IDs. This setting differs from how humans play the actual game, but is convenient for designing decentralised multi-agent learning tasks. 131 | 132 | Since our micro-scenarios are shorter than actual StarCraft II games, restarting the game after each episode presents a computational bottleneck. To overcome this issue, we make use of the API's debug commands. Specifically, when all units of either army have been killed, we kill all remaining units by sending a debug action. Having no units left launches a trigger programmed with the StarCraft II Editor that re-spawns all units in their original location with full health, thereby restarting the scenario quickly and efficiently. 133 | 134 | Furthermore, to encourage agents to explore interesting micro-strategies themselves, we limit the influence of the StarCraft AI on our agents. Specifically, we disable the automatic unit attack against enemies that attack agents or that are located nearby. 135 | To do so, we make use of new units created with the StarCraft II Editor that are exact copies of existing units with two attributes modified: _Combat: Default Acquire Level_ is set to _Passive_ (default _Offensive_) and _Behaviour: Response_ is set to _No Response_ (default _Acquire_). These fields are only modified for allied units; enemy units are unchanged. 136 | 137 | The sight and shooting range values might differ from the built-in _sight_ or _range_ attribute of some StarCraft II units. Our goal is not to master the original full StarCraft game, but rather to benchmark MARL methods for decentralised control.
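As a rough illustration of how the settings described in this section are exposed in code: most of them are constructor arguments of `StarCraft2Env`. The flag names below (`reward_sparse`, `obs_terrain_height`, `difficulty`) are taken from `smac/env/starcraft2/starcraft2.py` to the best of our knowledge and should be checked against that file; the chosen values are purely illustrative.

```python
from smac.env import StarCraft2Env

# Illustrative sketch: the 3m scenario with the sparse +1/-1 reward instead
# of the default shaped reward, terrain-height observations enabled, and the
# built-in AI left at the default "very difficult" setting (7).
env = StarCraft2Env(
    map_name="3m",
    reward_sparse=True,
    obs_terrain_height=True,
    difficulty="7",
)

env_info = env.get_env_info()
print(env_info["n_agents"], env_info["n_actions"], env_info["episode_limit"])
env.close()
```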
138 | 139 | 1: _health_, _shield_ and _unit\_type_ of the unit the agent controls is also included in observations 140 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | )/ 16 | ''' 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from setuptools import setup 6 | 7 | description = """SMAC - StarCraft Multi-Agent Challenge 8 | 9 | SMAC offers a diverse set of decentralised micromanagement challenges based on 10 | StarCraft II game. In these challenges, each of the units is controlled by an 11 | independent, learning agent that has to act based only on local observations, 12 | while the opponent's units are controlled by the built-in StarCraft II AI. 13 | 14 | The accompanying paper which outlines the motivation for using SMAC as well as 15 | results using the state-of-the-art deep multi-agent reinforcement learning 16 | algorithms can be found at https://www.arxiv.link 17 | 18 | Read the README at https://github.com/oxwhirl/smac for more information. 19 | """ 20 | 21 | extras_deps = { 22 | "dev": [ 23 | "pre-commit>=2.0.1", 24 | "black>=19.10b0", 25 | "flake8>=3.7", 26 | "flake8-bugbear>=20.1", 27 | ], 28 | } 29 | 30 | 31 | setup( 32 | name="SMAC", 33 | version="1.0.0", 34 | description="SMAC - StarCraft Multi-Agent Challenge.", 35 | long_description=description, 36 | author="WhiRL", 37 | author_email="mikayel@samvelyan.com", 38 | license="MIT License", 39 | keywords="StarCraft, Multi-Agent Reinforcement Learning", 40 | url="https://github.com/oxwhirl/smac", 41 | packages=[ 42 | "smac", 43 | "smac.env", 44 | "smac.env.starcraft2", 45 | "smac.env.starcraft2.maps", 46 | "smac.env.pettingzoo", 47 | "smac.bin", 48 | "smac.examples", 49 | "smac.examples.rllib", 50 | "smac.examples.pettingzoo", 51 | ], 52 | extras_require=extras_deps, 53 | install_requires=[ 54 | "protobuf<3.21", 55 | "pysc2>=3.0.0", 56 | "s2clientprotocol>=4.10.1.75800.0", 57 | "absl-py>=0.1.0", 58 | "numpy>=1.10", 59 | "pygame>=2.0.0", 60 | ], 61 | ) 62 | -------------------------------------------------------------------------------- /smac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/__init__.py -------------------------------------------------------------------------------- /smac/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/bin/__init__.py -------------------------------------------------------------------------------- /smac/bin/map_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smac.env.starcraft2.maps import smac_maps 6 | 7 | from pysc2 import maps as pysc2_maps 8 | 9 | 10 | def 
main(): 11 | smac_map_registry = smac_maps.get_smac_map_registry() 12 | all_maps = pysc2_maps.get_maps() 13 | print("{:<15} {:7} {:7} {:7}".format("Name", "Agents", "Enemies", "Limit")) 14 | for map_name, map_params in smac_map_registry.items(): 15 | map_class = all_maps[map_name] 16 | if map_class.path: 17 | print( 18 | "{:<15} {:<7} {:<7} {:<7}".format( 19 | map_name, 20 | map_params["n_agents"], 21 | map_params["n_enemies"], 22 | map_params["limit"], 23 | ) 24 | ) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /smac/env/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smac.env.multiagentenv import MultiAgentEnv 6 | from smac.env.starcraft2.starcraft2 import StarCraft2Env 7 | 8 | __all__ = ["MultiAgentEnv", "StarCraft2Env"] 9 | -------------------------------------------------------------------------------- /smac/env/multiagentenv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | class MultiAgentEnv(object): 7 | def step(self, actions): 8 | """Returns reward, terminated, info.""" 9 | raise NotImplementedError 10 | 11 | def get_obs(self): 12 | """Returns all agent observations in a list.""" 13 | raise NotImplementedError 14 | 15 | def get_obs_agent(self, agent_id): 16 | """Returns observation for agent_id.""" 17 | raise NotImplementedError 18 | 19 | def get_obs_size(self): 20 | """Returns the size of the observation.""" 21 | raise NotImplementedError 22 | 23 | def get_state(self): 24 | """Returns the global state.""" 25 | raise NotImplementedError 26 | 27 | def get_state_size(self): 28 | """Returns the size of the global state.""" 29 | raise NotImplementedError 30 | 31 | def get_avail_actions(self): 32 | """Returns the available actions of all agents in a list.""" 33 | raise NotImplementedError 34 | 35 | def get_avail_agent_actions(self, agent_id): 36 | """Returns the available actions for agent_id.""" 37 | raise NotImplementedError 38 | 39 | def get_total_actions(self): 40 | """Returns the total number of actions an agent could ever take.""" 41 | raise NotImplementedError 42 | 43 | def reset(self): 44 | """Returns initial observations and states.""" 45 | raise NotImplementedError 46 | 47 | def render(self): 48 | raise NotImplementedError 49 | 50 | def close(self): 51 | raise NotImplementedError 52 | 53 | def seed(self): 54 | raise NotImplementedError 55 | 56 | def save_replay(self): 57 | """Save a replay.""" 58 | raise NotImplementedError 59 | 60 | def get_env_info(self): 61 | env_info = { 62 | "state_shape": self.get_state_size(), 63 | "obs_shape": self.get_obs_size(), 64 | "n_actions": self.get_total_actions(), 65 | "n_agents": self.n_agents, 66 | "episode_limit": self.episode_limit, 67 | } 68 | return env_info 69 | -------------------------------------------------------------------------------- /smac/env/pettingzoo/StarCraft2PZEnv.py: -------------------------------------------------------------------------------- 1 | from smac.env import StarCraft2Env 2 | from smac.env.starcraft2.maps import get_map_params 3 | from gymnasium.utils import EzPickle 4 | from gymnasium.utils import seeding 5 | from gymnasium import spaces 6 | from pettingzoo.utils.env 
import ParallelEnv 7 | from pettingzoo.utils.conversions import ( 8 | parallel_to_aec as from_parallel_wrapper, 9 | ) 10 | from pettingzoo.utils import wrappers 11 | import numpy as np 12 | 13 | 14 | def parallel_env(max_cycles=None, **smac_args): 15 | if max_cycles is None: 16 | map_name = smac_args.get("map_name", "8m") 17 | max_cycles = get_map_params(map_name)["limit"] 18 | return _parallel_env(max_cycles, **smac_args) 19 | 20 | 21 | def raw_env(max_cycles=None, **smac_args): 22 | return from_parallel_wrapper(parallel_env(max_cycles, **smac_args)) 23 | 24 | 25 | def make_env(raw_env): 26 | def env_fn(**kwargs): 27 | env = raw_env(**kwargs) 28 | # env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) 29 | env = wrappers.AssertOutOfBoundsWrapper(env) 30 | env = wrappers.OrderEnforcingWrapper(env) 31 | return env 32 | 33 | return env_fn 34 | 35 | 36 | class smac_parallel_env(ParallelEnv): 37 | def __init__(self, env, max_cycles): 38 | self.max_cycles = max_cycles 39 | self.env = env 40 | self.env.reset() 41 | self.reset_flag = 0 42 | self.agents, self.action_spaces = self._init_agents() 43 | self.possible_agents = self.agents[:] 44 | 45 | observation_size = env.get_obs_size() 46 | self.observation_spaces = { 47 | name: spaces.Dict( 48 | { 49 | "observation": spaces.Box( 50 | low=-1, 51 | high=1, 52 | shape=(observation_size,), 53 | dtype="float32", 54 | ), 55 | "action_mask": spaces.Box( 56 | low=0, 57 | high=1, 58 | shape=(self.action_spaces[name].n,), 59 | dtype=np.int8, 60 | ), 61 | } 62 | ) 63 | for name in self.agents 64 | } 65 | state_size = env.get_state_size() 66 | self.state_space = spaces.Box(low=-1, high=1, shape=(state_size,), dtype="float32") 67 | self._reward = 0 68 | 69 | def observation_space(self, agent): 70 | return self.observation_spaces[agent] 71 | 72 | def action_space(self, agent): 73 | return self.action_spaces[agent] 74 | 75 | def _init_agents(self): 76 | last_type = "" 77 | agents = [] 78 | action_spaces = {} 79 | self.agents_id = {} 80 | i = 0 81 | for agent_id, agent_info in self.env.agents.items(): 82 | unit_action_space = spaces.Discrete( 83 | self.env.get_total_actions() - 1 84 | ) # no-op in dead units is not an action 85 | if agent_info.unit_type == self.env.marine_id: 86 | agent_type = "marine" 87 | elif agent_info.unit_type == self.env.marauder_id: 88 | agent_type = "marauder" 89 | elif agent_info.unit_type == self.env.medivac_id: 90 | agent_type = "medivac" 91 | elif agent_info.unit_type == self.env.hydralisk_id: 92 | agent_type = "hydralisk" 93 | elif agent_info.unit_type == self.env.zergling_id: 94 | agent_type = "zergling" 95 | elif agent_info.unit_type == self.env.baneling_id: 96 | agent_type = "baneling" 97 | elif agent_info.unit_type == self.env.stalker_id: 98 | agent_type = "stalker" 99 | elif agent_info.unit_type == self.env.colossus_id: 100 | agent_type = "colossus" 101 | elif agent_info.unit_type == self.env.zealot_id: 102 | agent_type = "zealot" 103 | else: 104 | raise AssertionError(f"agent type {agent_type} not supported") 105 | 106 | if agent_type == last_type: 107 | i += 1 108 | else: 109 | i = 0 110 | 111 | agents.append(f"{agent_type}_{i}") 112 | self.agents_id[agents[-1]] = agent_id 113 | action_spaces[agents[-1]] = unit_action_space 114 | last_type = agent_type 115 | 116 | return agents, action_spaces 117 | 118 | def seed(self, seed=None): 119 | if seed is None: 120 | self.env._seed = seeding.create_seed(seed, max_bytes=4) 121 | else: 122 | self.env._seed = seed 123 | self.env.full_restart() 124 | 125 | def render(self, 
mode="human"): 126 | self.env.render(mode) 127 | 128 | def close(self): 129 | self.env.close() 130 | 131 | def reset(self, seed=None, options=None): 132 | self.env._episode_count = 1 133 | self.env.reset() 134 | 135 | self.agents = self.possible_agents[:] 136 | self.frames = 0 137 | self.terminations = {agent: False for agent in self.possible_agents} 138 | self.truncations = {agent: False for agent in self.possible_agents} 139 | return self._observe_all() 140 | 141 | def get_agent_smac_id(self, agent): 142 | return self.agents_id[agent] 143 | 144 | def _all_rewards(self, reward): 145 | all_rewards = [reward] * len(self.agents) 146 | return { 147 | agent: reward for agent, reward in zip(self.agents, all_rewards) 148 | } 149 | 150 | def _observe_all(self): 151 | all_obs = [] 152 | for agent in self.agents: 153 | agent_id = self.get_agent_smac_id(agent) 154 | obs = self.env.get_obs_agent(agent_id) 155 | action_mask = self.env.get_avail_agent_actions(agent_id) 156 | action_mask = action_mask[1:] 157 | action_mask = np.array(action_mask).astype(np.int8) 158 | obs = np.asarray(obs, dtype=np.float32) 159 | all_obs.append({"observation": obs, "action_mask": action_mask}) 160 | return {agent: obs for agent, obs in zip(self.agents, all_obs)} 161 | 162 | def _all_terms_truncs(self, terminated=False, truncated=False): 163 | terminations = [True] * len(self.agents) 164 | 165 | if not terminated: 166 | for i, agent in enumerate(self.agents): 167 | agent_done = False 168 | agent_id = self.get_agent_smac_id(agent) 169 | agent_info = self.env.get_unit_by_id(agent_id) 170 | if agent_info.health == 0: 171 | agent_done = True 172 | terminations[i] = agent_done 173 | 174 | terminations = {a: bool(t) for a, t in zip(self.agents, terminations)} 175 | truncations = {a: truncated for a in self.agents} 176 | 177 | return terminations, truncations 178 | 179 | def step(self, all_actions): 180 | action_list = [0] * self.env.n_agents 181 | for agent in self.agents: 182 | agent_id = self.get_agent_smac_id(agent) 183 | if agent in all_actions: 184 | if all_actions[agent] is None: 185 | action_list[agent_id] = 0 186 | else: 187 | action_list[agent_id] = all_actions[agent] + 1 188 | self._reward, terminated, smac_info = self.env.step(action_list) 189 | self.frames += 1 190 | 191 | all_infos = {agent: {} for agent in self.agents} 192 | # all_infos.update(smac_info) 193 | all_terms, all_truncs = self._all_terms_truncs( 194 | terminated=terminated, truncated=(self.frames >= self.max_cycles) 195 | ) 196 | all_rewards = self._all_rewards(self._reward) 197 | all_observes = self._observe_all() 198 | 199 | self.agents = [agent for agent in self.agents if not all_terms[agent]] 200 | self.agents = [agent for agent in self.agents if not all_truncs[agent]] 201 | 202 | return all_observes, all_rewards, all_terms, all_truncs, all_infos 203 | 204 | def state(self): 205 | return self.env.get_state() 206 | 207 | def __del__(self): 208 | self.env.close() 209 | 210 | 211 | env = make_env(raw_env) 212 | 213 | 214 | class _parallel_env(smac_parallel_env, EzPickle): 215 | metadata = {"render.modes": ["human"], "name": "sc2"} 216 | 217 | def __init__(self, max_cycles, **smac_args): 218 | EzPickle.__init__(self, max_cycles, **smac_args) 219 | env = StarCraft2Env(**smac_args) 220 | super().__init__(env, max_cycles) 221 | -------------------------------------------------------------------------------- /smac/env/pettingzoo/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/pettingzoo/__init__.py -------------------------------------------------------------------------------- /smac/env/pettingzoo/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/pettingzoo/test/__init__.py -------------------------------------------------------------------------------- /smac/env/pettingzoo/test/all_test.py: -------------------------------------------------------------------------------- 1 | from smac.env.starcraft2.maps import smac_maps 2 | from pysc2 import maps as pysc2_maps 3 | from smac.env.pettingzoo import StarCraft2PZEnv as sc2 4 | import pytest 5 | from pettingzoo import test 6 | import pickle 7 | 8 | smac_map_registry = smac_maps.get_smac_map_registry() 9 | all_maps = pysc2_maps.get_maps() 10 | map_names = [] 11 | for map_name in smac_map_registry.keys(): 12 | map_class = all_maps[map_name] 13 | if map_class.path: 14 | map_names.append(map_name) 15 | 16 | 17 | @pytest.mark.parametrize(("map_name"), map_names) 18 | def test_env(map_name): 19 | env = sc2.env(map_name=map_name) 20 | test.api_test(env) 21 | # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to 22 | # illegal actions test.seed_test(sc2.env, 50) # not required, sc2 env only 23 | # allows reseeding at initialization 24 | test.render_test(env) 25 | 26 | recreated_env = pickle.loads(pickle.dumps(env)) 27 | test.api_test(recreated_env) 28 | -------------------------------------------------------------------------------- /smac/env/pettingzoo/test/smac_pettingzoo_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import inspect 4 | from pettingzoo import test 5 | from smac.env.pettingzoo import StarCraft2PZEnv as sc2 6 | import pickle 7 | 8 | current_dir = os.path.dirname( 9 | os.path.abspath(inspect.getfile(inspect.currentframe())) 10 | ) 11 | parent_dir = os.path.dirname(current_dir) 12 | sys.path.insert(0, parent_dir) 13 | 14 | 15 | if __name__ == "__main__": 16 | env = sc2.env(map_name="corridor") 17 | test.api_test(env) 18 | # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to 19 | # illegal actions test.seed_test(sc2_v0.env, 50) # not required, sc2 env 20 | # only allows reseeding at initialization 21 | 22 | recreated_env = pickle.loads(pickle.dumps(env)) 23 | test.api_test(recreated_env) 24 | -------------------------------------------------------------------------------- /smac/env/starcraft2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import flags 6 | 7 | FLAGS = flags.FLAGS 8 | FLAGS(["main.py"]) 9 | -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smac.env.starcraft2.maps import smac_maps 6 | 7 | 8 | def get_map_params(map_name): 9 | map_param_registry = smac_maps.get_smac_map_registry() 10 | return map_param_registry[map_name] 11 | -------------------------------------------------------------------------------- /smac/env/starcraft2/maps/smac_maps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.maps import lib 6 | 7 | 8 | class SMACMap(lib.Map): 9 | directory = "SMAC_Maps" 10 | download = "https://github.com/oxwhirl/smac#smac-maps" 11 | players = 2 12 | step_mul = 8 13 | game_steps_per_episode = 0 14 | 15 | 16 | map_param_registry = { 17 | "3m": { 18 | "n_agents": 3, 19 | "n_enemies": 3, 20 | "limit": 60, 21 | "a_race": "T", 22 | "b_race": "T", 23 | "unit_type_bits": 0, 24 | "map_type": "marines", 25 | }, 26 | "8m": { 27 | "n_agents": 8, 28 | "n_enemies": 8, 29 | "limit": 120, 30 | "a_race": "T", 31 | "b_race": "T", 32 | "unit_type_bits": 0, 33 | "map_type": "marines", 34 | }, 35 | "25m": { 36 | "n_agents": 25, 37 | "n_enemies": 25, 38 | "limit": 150, 39 | "a_race": "T", 40 | "b_race": "T", 41 | "unit_type_bits": 0, 42 | "map_type": "marines", 43 | }, 44 | "5m_vs_6m": { 45 | "n_agents": 5, 46 | "n_enemies": 6, 47 | "limit": 70, 48 | "a_race": "T", 49 | "b_race": "T", 50 | "unit_type_bits": 0, 51 | "map_type": "marines", 52 | }, 53 | "8m_vs_9m": { 54 | "n_agents": 8, 55 | "n_enemies": 9, 56 | "limit": 120, 57 | "a_race": "T", 58 | "b_race": "T", 59 | "unit_type_bits": 0, 60 | "map_type": "marines", 61 | }, 62 | "10m_vs_11m": { 63 | "n_agents": 10, 64 | "n_enemies": 11, 65 | "limit": 150, 66 | "a_race": "T", 67 | "b_race": "T", 68 | "unit_type_bits": 0, 69 | "map_type": "marines", 70 | }, 71 | "27m_vs_30m": { 72 | "n_agents": 27, 73 | "n_enemies": 30, 74 | "limit": 180, 75 | "a_race": "T", 76 | "b_race": "T", 77 | "unit_type_bits": 0, 78 | "map_type": "marines", 79 | }, 80 | "MMM": { 81 | "n_agents": 10, 82 | "n_enemies": 10, 83 | "limit": 150, 84 | "a_race": "T", 85 | "b_race": "T", 86 | "unit_type_bits": 3, 87 | "map_type": "MMM", 88 | }, 89 | "MMM2": { 90 | "n_agents": 10, 91 | "n_enemies": 12, 92 | "limit": 180, 93 | "a_race": "T", 94 | "b_race": "T", 95 | "unit_type_bits": 3, 96 | "map_type": "MMM", 97 | }, 98 | "2s3z": { 99 | "n_agents": 5, 100 | "n_enemies": 5, 101 | "limit": 120, 102 | "a_race": "P", 103 | "b_race": "P", 104 | "unit_type_bits": 2, 105 | "map_type": "stalkers_and_zealots", 106 | }, 107 | "3s5z": { 108 | "n_agents": 8, 109 | "n_enemies": 8, 110 | "limit": 150, 111 | "a_race": "P", 112 | "b_race": "P", 113 | 
"unit_type_bits": 2, 114 | "map_type": "stalkers_and_zealots", 115 | }, 116 | "3s5z_vs_3s6z": { 117 | "n_agents": 8, 118 | "n_enemies": 9, 119 | "limit": 170, 120 | "a_race": "P", 121 | "b_race": "P", 122 | "unit_type_bits": 2, 123 | "map_type": "stalkers_and_zealots", 124 | }, 125 | "3s_vs_3z": { 126 | "n_agents": 3, 127 | "n_enemies": 3, 128 | "limit": 150, 129 | "a_race": "P", 130 | "b_race": "P", 131 | "unit_type_bits": 0, 132 | "map_type": "stalkers", 133 | }, 134 | "3s_vs_4z": { 135 | "n_agents": 3, 136 | "n_enemies": 4, 137 | "limit": 200, 138 | "a_race": "P", 139 | "b_race": "P", 140 | "unit_type_bits": 0, 141 | "map_type": "stalkers", 142 | }, 143 | "3s_vs_5z": { 144 | "n_agents": 3, 145 | "n_enemies": 5, 146 | "limit": 250, 147 | "a_race": "P", 148 | "b_race": "P", 149 | "unit_type_bits": 0, 150 | "map_type": "stalkers", 151 | }, 152 | "1c3s5z": { 153 | "n_agents": 9, 154 | "n_enemies": 9, 155 | "limit": 180, 156 | "a_race": "P", 157 | "b_race": "P", 158 | "unit_type_bits": 3, 159 | "map_type": "colossi_stalkers_zealots", 160 | }, 161 | "2m_vs_1z": { 162 | "n_agents": 2, 163 | "n_enemies": 1, 164 | "limit": 150, 165 | "a_race": "T", 166 | "b_race": "P", 167 | "unit_type_bits": 0, 168 | "map_type": "marines", 169 | }, 170 | "corridor": { 171 | "n_agents": 6, 172 | "n_enemies": 24, 173 | "limit": 400, 174 | "a_race": "P", 175 | "b_race": "Z", 176 | "unit_type_bits": 0, 177 | "map_type": "zealots", 178 | }, 179 | "6h_vs_8z": { 180 | "n_agents": 6, 181 | "n_enemies": 8, 182 | "limit": 150, 183 | "a_race": "Z", 184 | "b_race": "P", 185 | "unit_type_bits": 0, 186 | "map_type": "hydralisks", 187 | }, 188 | "2s_vs_1sc": { 189 | "n_agents": 2, 190 | "n_enemies": 1, 191 | "limit": 300, 192 | "a_race": "P", 193 | "b_race": "Z", 194 | "unit_type_bits": 0, 195 | "map_type": "stalkers", 196 | }, 197 | "so_many_baneling": { 198 | "n_agents": 7, 199 | "n_enemies": 32, 200 | "limit": 100, 201 | "a_race": "P", 202 | "b_race": "Z", 203 | "unit_type_bits": 0, 204 | "map_type": "zealots", 205 | }, 206 | "bane_vs_bane": { 207 | "n_agents": 24, 208 | "n_enemies": 24, 209 | "limit": 200, 210 | "a_race": "Z", 211 | "b_race": "Z", 212 | "unit_type_bits": 2, 213 | "map_type": "bane", 214 | }, 215 | "2c_vs_64zg": { 216 | "n_agents": 2, 217 | "n_enemies": 64, 218 | "limit": 400, 219 | "a_race": "P", 220 | "b_race": "Z", 221 | "unit_type_bits": 0, 222 | "map_type": "colossus", 223 | }, 224 | } 225 | 226 | 227 | def get_smac_map_registry(): 228 | return map_param_registry 229 | 230 | 231 | for name in map_param_registry.keys(): 232 | globals()[name] = type(name, (SMACMap,), dict(filename=name)) 233 | -------------------------------------------------------------------------------- /smac/env/starcraft2/render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import subprocess 4 | import platform 5 | from absl import logging 6 | import math 7 | import time 8 | import collections 9 | import os 10 | import pygame 11 | import queue 12 | 13 | from pysc2.lib import colors 14 | from pysc2.lib import point 15 | from pysc2.lib.renderer_human import _Surface 16 | from pysc2.lib import transform 17 | from pysc2.lib import features 18 | 19 | 20 | def clamp(n, smallest, largest): 21 | return max(smallest, min(n, largest)) 22 | 23 | 24 | def _get_desktop_size(): 25 | """Get the desktop size.""" 26 | if platform.system() == "Linux": 27 | try: 28 | xrandr_query = subprocess.check_output(["xrandr", "--query"]) 29 | sizes = re.findall( 30 | r"\bconnected 
primary (\d+)x(\d+)", str(xrandr_query) 31 | ) 32 | if sizes[0]: 33 | return point.Point(int(sizes[0][0]), int(sizes[0][1])) 34 | except ValueError: 35 | logging.error("Failed to get the resolution from xrandr.") 36 | 37 | # Most general, but doesn't understand multiple monitors. 38 | display_info = pygame.display.Info() 39 | return point.Point(display_info.current_w, display_info.current_h) 40 | 41 | 42 | class StarCraft2Renderer: 43 | def __init__(self, env, mode): 44 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide" 45 | 46 | self.env = env 47 | self.mode = mode 48 | self.obs = None 49 | self._window_scale = 0.75 50 | self.game_info = game_info = self.env._controller.game_info() 51 | self.static_data = self.env._controller.data() 52 | 53 | self._obs_queue = queue.Queue() 54 | self._game_times = collections.deque( 55 | maxlen=100 56 | ) # Avg FPS over 100 frames. # pytype: disable=wrong-keyword-args 57 | self._render_times = collections.deque( 58 | maxlen=100 59 | ) # pytype: disable=wrong-keyword-args 60 | self._last_time = time.time() 61 | self._last_game_loop = 0 62 | self._name_lengths = {} 63 | 64 | self._map_size = point.Point.build(game_info.start_raw.map_size) 65 | self._playable = point.Rect( 66 | point.Point.build(game_info.start_raw.playable_area.p0), 67 | point.Point.build(game_info.start_raw.playable_area.p1), 68 | ) 69 | 70 | window_size_px = point.Point( 71 | self.env.window_size[0], self.env.window_size[1] 72 | ) 73 | window_size_px = self._map_size.scale_max_size( 74 | window_size_px * self._window_scale 75 | ).ceil() 76 | self._scale = window_size_px.y // 32 77 | 78 | self.display = pygame.Surface(window_size_px) 79 | 80 | if mode == "human": 81 | self.display = pygame.display.set_mode(window_size_px, 0, 32) 82 | pygame.display.init() 83 | 84 | pygame.display.set_caption("Starcraft Viewer") 85 | pygame.font.init() 86 | self._world_to_world_tl = transform.Linear( 87 | point.Point(1, -1), point.Point(0, self._map_size.y) 88 | ) 89 | self._world_tl_to_screen = transform.Linear(scale=window_size_px / 32) 90 | self.screen_transform = transform.Chain( 91 | self._world_to_world_tl, self._world_tl_to_screen 92 | ) 93 | 94 | surf_loc = point.Rect(point.origin, window_size_px) 95 | sub_surf = self.display.subsurface( 96 | pygame.Rect(surf_loc.tl, surf_loc.size) 97 | ) 98 | self._surf = _Surface( 99 | sub_surf, 100 | None, 101 | surf_loc, 102 | self.screen_transform, 103 | None, 104 | self.draw_screen, 105 | ) 106 | 107 | self._font_small = pygame.font.Font(None, int(self._scale * 0.5)) 108 | self._font_large = pygame.font.Font(None, self._scale) 109 | 110 | def close(self): 111 | pygame.display.quit() 112 | pygame.quit() 113 | 114 | def _get_units(self): 115 | for u in sorted( 116 | self.obs.observation.raw_data.units, 117 | key=lambda u: (u.pos.z, u.owner != 16, -u.radius, u.tag), 118 | ): 119 | yield u, point.Point.build(u.pos) 120 | 121 | def get_unit_name(self, surf, name, radius): 122 | """Get a length limited unit name for drawing units.""" 123 | key = (name, radius) 124 | if key not in self._name_lengths: 125 | max_len = surf.world_to_surf.fwd_dist(radius * 1.6) 126 | for i in range(len(name)): 127 | if self._font_small.size(name[: i + 1])[0] > max_len: 128 | self._name_lengths[key] = name[:i] 129 | break 130 | else: 131 | self._name_lengths[key] = name 132 | return self._name_lengths[key] 133 | 134 | def render(self, mode): 135 | self.obs = self.env._obs 136 | self.score = self.env.reward 137 | self.step = self.env._episode_steps 138 | 139 | now = time.time() 140 | 
self._game_times.append( 141 | ( 142 | now - self._last_time, 143 | max( 144 | 1, 145 | self.obs.observation.game_loop 146 | - self.obs.observation.game_loop, 147 | ), 148 | ) 149 | ) 150 | 151 | if mode == "human": 152 | pygame.event.pump() 153 | 154 | self._surf.draw(self._surf) 155 | 156 | observation = np.array(pygame.surfarray.pixels3d(self.display)) 157 | 158 | if mode == "human": 159 | pygame.display.flip() 160 | 161 | self._last_time = now 162 | self._last_game_loop = self.obs.observation.game_loop 163 | # self._obs_queue.put(self.obs) 164 | return ( 165 | np.transpose(observation, axes=(1, 0, 2)) 166 | if mode == "rgb_array" 167 | else None 168 | ) 169 | 170 | def draw_base_map(self, surf): 171 | """Draw the base map.""" 172 | hmap_feature = features.SCREEN_FEATURES.height_map 173 | hmap = self.env.terrain_height * 255 174 | hmap = hmap.astype(np.uint8) 175 | if ( 176 | self.env.map_name == "corridor" 177 | or self.env.map_name == "so_many_baneling" 178 | or self.env.map_name == "2s_vs_1sc" 179 | ): 180 | hmap = np.flip(hmap) 181 | else: 182 | hmap = np.rot90(hmap, axes=(1, 0)) 183 | if not hmap.any(): 184 | hmap = hmap + 100 # pylint: disable=g-no-augmented-assignment 185 | hmap_color = hmap_feature.color(hmap) 186 | out = hmap_color * 0.6 187 | 188 | surf.blit_np_array(out) 189 | 190 | def draw_units(self, surf): 191 | """Draw the units.""" 192 | unit_dict = None # Cache the units {tag: unit_proto} for orders. 193 | tau = 2 * math.pi 194 | for u, p in self._get_units(): 195 | fraction_damage = clamp( 196 | (u.health_max - u.health) / (u.health_max or 1), 0, 1 197 | ) 198 | surf.draw_circle( 199 | colors.PLAYER_ABSOLUTE_PALETTE[u.owner], p, u.radius 200 | ) 201 | 202 | if fraction_damage > 0: 203 | surf.draw_circle( 204 | colors.PLAYER_ABSOLUTE_PALETTE[u.owner] // 2, 205 | p, 206 | u.radius * fraction_damage, 207 | ) 208 | surf.draw_circle(colors.black, p, u.radius, thickness=1) 209 | 210 | if self.static_data.unit_stats[u.unit_type].movement_speed > 0: 211 | surf.draw_arc( 212 | colors.white, 213 | p, 214 | u.radius, 215 | u.facing - 0.1, 216 | u.facing + 0.1, 217 | thickness=1, 218 | ) 219 | 220 | def draw_arc_ratio( 221 | color, world_loc, radius, start, end, thickness=1 222 | ): 223 | surf.draw_arc( 224 | color, world_loc, radius, start * tau, end * tau, thickness 225 | ) 226 | 227 | if u.shield and u.shield_max: 228 | draw_arc_ratio( 229 | colors.blue, p, u.radius - 0.05, 0, u.shield / u.shield_max 230 | ) 231 | 232 | if u.energy and u.energy_max: 233 | draw_arc_ratio( 234 | colors.purple * 0.9, 235 | p, 236 | u.radius - 0.1, 237 | 0, 238 | u.energy / u.energy_max, 239 | ) 240 | elif u.orders and 0 < u.orders[0].progress < 1: 241 | draw_arc_ratio( 242 | colors.cyan, p, u.radius - 0.15, 0, u.orders[0].progress 243 | ) 244 | if u.buff_duration_remain and u.buff_duration_max: 245 | draw_arc_ratio( 246 | colors.white, 247 | p, 248 | u.radius - 0.2, 249 | 0, 250 | u.buff_duration_remain / u.buff_duration_max, 251 | ) 252 | if u.attack_upgrade_level: 253 | draw_arc_ratio( 254 | self.upgrade_colors[u.attack_upgrade_level], 255 | p, 256 | u.radius - 0.25, 257 | 0.18, 258 | 0.22, 259 | thickness=3, 260 | ) 261 | if u.armor_upgrade_level: 262 | draw_arc_ratio( 263 | self.upgrade_colors[u.armor_upgrade_level], 264 | p, 265 | u.radius - 0.25, 266 | 0.23, 267 | 0.27, 268 | thickness=3, 269 | ) 270 | if u.shield_upgrade_level: 271 | draw_arc_ratio( 272 | self.upgrade_colors[u.shield_upgrade_level], 273 | p, 274 | u.radius - 0.25, 275 | 0.28, 276 | 0.32, 277 | thickness=3, 278 | ) 279 | 
280 | def write_small(loc, s): 281 | surf.write_world(self._font_small, colors.white, loc, str(s)) 282 | 283 | name = self.get_unit_name( 284 | surf, 285 | self.static_data.units.get(u.unit_type, ""), 286 | u.radius, 287 | ) 288 | 289 | if name: 290 | write_small(p, name) 291 | 292 | start_point = p 293 | for o in u.orders: 294 | target_point = None 295 | if o.HasField("target_unit_tag"): 296 | if unit_dict is None: 297 | unit_dict = { 298 | t.tag: t 299 | for t in self.obs.observation.raw_data.units 300 | } 301 | target_unit = unit_dict.get(o.target_unit_tag) 302 | if target_unit: 303 | target_point = point.Point.build(target_unit.pos) 304 | if target_point: 305 | surf.draw_line(colors.cyan, start_point, target_point) 306 | start_point = target_point 307 | else: 308 | break 309 | 310 | def draw_overlay(self, surf): 311 | """Draw the overlay describing resources.""" 312 | obs = self.obs.observation 313 | times, steps = zip(*self._game_times) 314 | sec = obs.game_loop // 22.4 315 | surf.write_screen( 316 | self._font_large, 317 | colors.green, 318 | (-0.2, 0.2), 319 | "Score: %s, Step: %s, %.1f/s, Time: %d:%02d" 320 | % ( 321 | self.score, 322 | self.step, 323 | sum(steps) / (sum(times) or 1), 324 | sec // 60, 325 | sec % 60, 326 | ), 327 | align="right", 328 | ) 329 | surf.write_screen( 330 | self._font_large, 331 | colors.green * 0.8, 332 | (-0.2, 1.2), 333 | "APM: %d, EPM: %d, FPS: O:%.1f, R:%.1f" 334 | % ( 335 | obs.score.score_details.current_apm, 336 | obs.score.score_details.current_effective_apm, 337 | len(times) / (sum(times) or 1), 338 | len(self._render_times) / (sum(self._render_times) or 1), 339 | ), 340 | align="right", 341 | ) 342 | 343 | def draw_screen(self, surf): 344 | """Draw the screen area.""" 345 | self.draw_base_map(surf) 346 | self.draw_units(surf) 347 | self.draw_overlay(surf) 348 | -------------------------------------------------------------------------------- /smac/env/starcraft2/starcraft2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smac.env.multiagentenv import MultiAgentEnv 6 | from smac.env.starcraft2.maps import get_map_params 7 | 8 | import atexit 9 | from warnings import warn 10 | from operator import attrgetter 11 | from copy import deepcopy 12 | import numpy as np 13 | import enum 14 | import math 15 | from absl import logging 16 | 17 | from pysc2 import maps 18 | from pysc2 import run_configs 19 | from pysc2.lib import protocol 20 | 21 | from s2clientprotocol import common_pb2 as sc_common 22 | from s2clientprotocol import sc2api_pb2 as sc_pb 23 | from s2clientprotocol import raw_pb2 as r_pb 24 | from s2clientprotocol import debug_pb2 as d_pb 25 | 26 | races = { 27 | "R": sc_common.Random, 28 | "P": sc_common.Protoss, 29 | "T": sc_common.Terran, 30 | "Z": sc_common.Zerg, 31 | } 32 | 33 | difficulties = { 34 | "1": sc_pb.VeryEasy, 35 | "2": sc_pb.Easy, 36 | "3": sc_pb.Medium, 37 | "4": sc_pb.MediumHard, 38 | "5": sc_pb.Hard, 39 | "6": sc_pb.Harder, 40 | "7": sc_pb.VeryHard, 41 | "8": sc_pb.CheatVision, 42 | "9": sc_pb.CheatMoney, 43 | "A": sc_pb.CheatInsane, 44 | } 45 | 46 | actions = { 47 | "move": 16, # target: PointOrUnit 48 | "attack": 23, # target: PointOrUnit 49 | "stop": 4, # target: None 50 | "heal": 386, # Unit 51 | } 52 | 53 | 54 | class Direction(enum.IntEnum): 55 | NORTH = 0 56 | SOUTH = 1 57 | EAST = 2 58 | WEST = 3 59 | 60 | 61 | class StarCraft2Env(MultiAgentEnv): 62 
| """The StarCraft II environment for decentralised multi-agent 63 | micromanagement scenarios. 64 | """ 65 | 66 | def __init__( 67 | self, 68 | map_name="8m", 69 | step_mul=8, 70 | move_amount=2, 71 | difficulty="7", 72 | game_version=None, 73 | seed=None, 74 | continuing_episode=False, 75 | obs_all_health=True, 76 | obs_own_health=True, 77 | obs_last_action=False, 78 | obs_pathing_grid=False, 79 | obs_terrain_height=False, 80 | obs_instead_of_state=False, 81 | obs_timestep_number=False, 82 | state_last_action=True, 83 | state_timestep_number=False, 84 | reward_sparse=False, 85 | reward_only_positive=True, 86 | reward_death_value=10, 87 | reward_win=200, 88 | reward_defeat=0, 89 | reward_negative_scale=0.5, 90 | reward_scale=True, 91 | reward_scale_rate=20, 92 | replay_dir="", 93 | replay_prefix="", 94 | window_size_x=1920, 95 | window_size_y=1200, 96 | heuristic_ai=False, 97 | heuristic_rest=False, 98 | debug=False, 99 | ): 100 | """ 101 | Create a StarCraftC2Env environment. 102 | 103 | Parameters 104 | ---------- 105 | map_name : str, optional 106 | The name of the SC2 map to play (default is "8m"). The full list 107 | can be found by running bin/map_list. 108 | step_mul : int, optional 109 | How many game steps per agent step (default is 8). None 110 | indicates to use the default map step_mul. 111 | move_amount : float, optional 112 | How far away units are ordered to move per step (default is 2). 113 | difficulty : str, optional 114 | The difficulty of built-in computer AI bot (default is "7"). 115 | game_version : str, optional 116 | StarCraft II game version (default is None). None indicates the 117 | latest version. 118 | seed : int, optional 119 | Random seed used during game initialisation. This allows to 120 | continuing_episode : bool, optional 121 | Whether to consider episodes continuing or finished after time 122 | limit is reached (default is False). 123 | obs_all_health : bool, optional 124 | Agents receive the health of all units (in the sight range) as part 125 | of observations (default is True). 126 | obs_own_health : bool, optional 127 | Agents receive their own health as a part of observations (default 128 | is False). This flag is ignored when obs_all_health == True. 129 | obs_last_action : bool, optional 130 | Agents receive the last actions of all units (in the sight range) 131 | as part of observations (default is False). 132 | obs_pathing_grid : bool, optional 133 | Whether observations include pathing values surrounding the agent 134 | (default is False). 135 | obs_terrain_height : bool, optional 136 | Whether observations include terrain height values surrounding the 137 | agent (default is False). 138 | obs_instead_of_state : bool, optional 139 | Use combination of all agents' observations as the global state 140 | (default is False). 141 | obs_timestep_number : bool, optional 142 | Whether observations include the current timestep of the episode 143 | (default is False). 144 | state_last_action : bool, optional 145 | Include the last actions of all agents as part of the global state 146 | (default is True). 147 | state_timestep_number : bool, optional 148 | Whether the state include the current timestep of the episode 149 | (default is False). 150 | reward_sparse : bool, optional 151 | Receive 1/-1 reward for winning/loosing an episode (default is 152 | False). Whe rest of reward parameters are ignored if True. 153 | reward_only_positive : bool, optional 154 | Reward is always positive (default is True). 
155 |         reward_death_value : float, optional
156 |             The amount of reward received for killing an enemy unit (default
157 |             is 10). This is also the negative penalty for having an allied unit
158 |             killed if reward_only_positive == False.
159 |         reward_win : float, optional
160 |             The reward for winning an episode (default is 200).
161 |         reward_defeat : float, optional
162 |             The reward for losing an episode (default is 0). This value
163 |             should be nonpositive.
164 |         reward_negative_scale : float, optional
165 |             Scaling factor for negative rewards (default is 0.5). This
166 |             parameter is ignored when reward_only_positive == True.
167 |         reward_scale : bool, optional
168 |             Whether or not to scale the reward (default is True).
169 |         reward_scale_rate : float, optional
170 |             Reward scale rate (default is 20). When reward_scale == True, the
171 |             reward received by the agents is divided by (max_reward /
172 |             reward_scale_rate), where max_reward is the maximum possible
173 |             reward per episode without considering the shield regeneration
174 |             of Protoss units.
175 |         replay_dir : str, optional
176 |             The directory in which to save replays (default is ""). If empty,
177 |             the replay will be saved in the Replays directory where StarCraft II
178 |             is installed.
179 |         replay_prefix : str, optional
180 |             The prefix of the saved replay (default is ""). If empty, the name
181 |             of the map will be used.
182 |         window_size_x : int, optional
183 |             The width of the StarCraft II window in pixels (default is 1920).
184 |         window_size_y : int, optional
185 |             The height of the StarCraft II window in pixels (default is 1200).
186 |         heuristic_ai : bool, optional
187 |             Whether or not to use a non-learning heuristic AI (default is False).
188 |         heuristic_rest : bool, optional
189 |             At any moment, restrict the actions of the heuristic AI to those
190 |             available to the RL agents (default is False). Ignored if
191 |             heuristic_ai == False.
192 |         debug : bool, optional
193 |             Log messages about observations, state, actions and rewards for
194 |             debugging purposes (default is False).
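
        Examples
        --------
        A minimal interaction loop with uniformly random (but valid)
        actions; illustrative only, and assumes a local StarCraft II
        installation with the SMAC maps in place::

            import numpy as np
            from smac.env import StarCraft2Env

            env = StarCraft2Env(map_name="8m")
            env_info = env.get_env_info()
            n_agents = env_info["n_agents"]

            env.reset()
            terminated = False
            episode_reward = 0
            while not terminated:
                actions = []
                for agent_id in range(n_agents):
                    # sample uniformly among the currently available actions
                    avail = env.get_avail_agent_actions(agent_id)
                    actions.append(np.random.choice(np.nonzero(avail)[0]))
                reward, terminated, _ = env.step(actions)
                episode_reward += reward
            env.close()
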
195 | """ 196 | # Map arguments 197 | self.map_name = map_name 198 | map_params = get_map_params(self.map_name) 199 | self.n_agents = map_params["n_agents"] 200 | self.n_enemies = map_params["n_enemies"] 201 | self.episode_limit = map_params["limit"] 202 | self._move_amount = move_amount 203 | self._step_mul = step_mul 204 | self.difficulty = difficulty 205 | 206 | # Observations and state 207 | self.obs_own_health = obs_own_health 208 | self.obs_all_health = obs_all_health 209 | self.obs_instead_of_state = obs_instead_of_state 210 | self.obs_last_action = obs_last_action 211 | self.obs_pathing_grid = obs_pathing_grid 212 | self.obs_terrain_height = obs_terrain_height 213 | self.obs_timestep_number = obs_timestep_number 214 | self.state_last_action = state_last_action 215 | self.state_timestep_number = state_timestep_number 216 | if self.obs_all_health: 217 | self.obs_own_health = True 218 | self.n_obs_pathing = 8 219 | self.n_obs_height = 9 220 | 221 | # Rewards args 222 | self.reward_sparse = reward_sparse 223 | self.reward_only_positive = reward_only_positive 224 | self.reward_negative_scale = reward_negative_scale 225 | self.reward_death_value = reward_death_value 226 | self.reward_win = reward_win 227 | self.reward_defeat = reward_defeat 228 | self.reward_scale = reward_scale 229 | self.reward_scale_rate = reward_scale_rate 230 | 231 | # Other 232 | self.game_version = game_version 233 | self.continuing_episode = continuing_episode 234 | self._seed = seed 235 | self.heuristic_ai = heuristic_ai 236 | self.heuristic_rest = heuristic_rest 237 | self.debug = debug 238 | self.window_size = (window_size_x, window_size_y) 239 | self.replay_dir = replay_dir 240 | self.replay_prefix = replay_prefix 241 | 242 | # Actions 243 | self.n_actions_no_attack = 6 244 | self.n_actions_move = 4 245 | self.n_actions = self.n_actions_no_attack + self.n_enemies 246 | 247 | # Map info 248 | self._agent_race = map_params["a_race"] 249 | self._bot_race = map_params["b_race"] 250 | self.shield_bits_ally = 1 if self._agent_race == "P" else 0 251 | self.shield_bits_enemy = 1 if self._bot_race == "P" else 0 252 | self.unit_type_bits = map_params["unit_type_bits"] 253 | self.map_type = map_params["map_type"] 254 | self._unit_types = None 255 | 256 | self.max_reward = ( 257 | self.n_enemies * self.reward_death_value + self.reward_win 258 | ) 259 | 260 | # create lists containing the names of attributes returned in states 261 | self.ally_state_attr_names = [ 262 | "health", 263 | "energy/cooldown", 264 | "rel_x", 265 | "rel_y", 266 | ] 267 | self.enemy_state_attr_names = ["health", "rel_x", "rel_y"] 268 | 269 | if self.shield_bits_ally > 0: 270 | self.ally_state_attr_names += ["shield"] 271 | if self.shield_bits_enemy > 0: 272 | self.enemy_state_attr_names += ["shield"] 273 | 274 | if self.unit_type_bits > 0: 275 | bit_attr_names = [ 276 | "type_{}".format(bit) for bit in range(self.unit_type_bits) 277 | ] 278 | self.ally_state_attr_names += bit_attr_names 279 | self.enemy_state_attr_names += bit_attr_names 280 | 281 | self.agents = {} 282 | self.enemies = {} 283 | self._episode_count = 0 284 | self._episode_steps = 0 285 | self._total_steps = 0 286 | self._obs = None 287 | self.battles_won = 0 288 | self.battles_game = 0 289 | self.timeouts = 0 290 | self.force_restarts = 0 291 | self.last_stats = None 292 | self.death_tracker_ally = np.zeros(self.n_agents) 293 | self.death_tracker_enemy = np.zeros(self.n_enemies) 294 | self.previous_ally_units = None 295 | self.previous_enemy_units = None 296 | self.last_action = 
np.zeros((self.n_agents, self.n_actions)) 297 | self._min_unit_type = 0 298 | self.marine_id = self.marauder_id = self.medivac_id = 0 299 | self.hydralisk_id = self.zergling_id = self.baneling_id = 0 300 | self.stalker_id = self.colossus_id = self.zealot_id = 0 301 | self.max_distance_x = 0 302 | self.max_distance_y = 0 303 | self.map_x = 0 304 | self.map_y = 0 305 | self.reward = 0 306 | self.renderer = None 307 | self.terrain_height = None 308 | self.pathing_grid = None 309 | self._run_config = None 310 | self._sc2_proc = None 311 | self._controller = None 312 | 313 | # Try to avoid leaking SC2 processes on shutdown 314 | atexit.register(lambda: self.close()) 315 | 316 | def _launch(self): 317 | """Launch the StarCraft II game.""" 318 | self._run_config = run_configs.get(version=self.game_version) 319 | _map = maps.get(self.map_name) 320 | 321 | # Setting up the interface 322 | interface_options = sc_pb.InterfaceOptions(raw=True, score=False) 323 | self._sc2_proc = self._run_config.start( 324 | window_size=self.window_size, want_rgb=False 325 | ) 326 | self._controller = self._sc2_proc.controller 327 | 328 | # Request to create the game 329 | create = sc_pb.RequestCreateGame( 330 | local_map=sc_pb.LocalMap( 331 | map_path=_map.path, 332 | map_data=self._run_config.map_data(_map.path), 333 | ), 334 | realtime=False, 335 | random_seed=self._seed, 336 | ) 337 | create.player_setup.add(type=sc_pb.Participant) 338 | create.player_setup.add( 339 | type=sc_pb.Computer, 340 | race=races[self._bot_race], 341 | difficulty=difficulties[self.difficulty], 342 | ) 343 | self._controller.create_game(create) 344 | 345 | join = sc_pb.RequestJoinGame( 346 | race=races[self._agent_race], options=interface_options 347 | ) 348 | self._controller.join_game(join) 349 | 350 | game_info = self._controller.game_info() 351 | map_info = game_info.start_raw 352 | map_play_area_min = map_info.playable_area.p0 353 | map_play_area_max = map_info.playable_area.p1 354 | self.max_distance_x = map_play_area_max.x - map_play_area_min.x 355 | self.max_distance_y = map_play_area_max.y - map_play_area_min.y 356 | self.map_x = map_info.map_size.x 357 | self.map_y = map_info.map_size.y 358 | 359 | if map_info.pathing_grid.bits_per_pixel == 1: 360 | vals = np.array(list(map_info.pathing_grid.data)).reshape( 361 | self.map_x, int(self.map_y / 8) 362 | ) 363 | self.pathing_grid = np.transpose( 364 | np.array( 365 | [ 366 | [(b >> i) & 1 for b in row for i in range(7, -1, -1)] 367 | for row in vals 368 | ], 369 | dtype=bool, 370 | ) 371 | ) 372 | else: 373 | self.pathing_grid = np.invert( 374 | np.flip( 375 | np.transpose( 376 | np.array( 377 | list(map_info.pathing_grid.data), dtype=bool 378 | ).reshape(self.map_x, self.map_y) 379 | ), 380 | axis=1, 381 | ) 382 | ) 383 | 384 | self.terrain_height = ( 385 | np.flip( 386 | np.transpose( 387 | np.array(list(map_info.terrain_height.data)).reshape( 388 | self.map_x, self.map_y 389 | ) 390 | ), 391 | 1, 392 | ) 393 | / 255 394 | ) 395 | 396 | def reset(self): 397 | """Reset the environment. Required after each full episode. 398 | Returns initial observations and states. 
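
        A quick shape check (illustrative; assumes ``env`` is an existing
        StarCraft2Env instance)::

            obs, state = env.reset()
            assert len(obs) == env.n_agents
            assert state.shape[0] == env.get_state_size()
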
399 | """ 400 | self._episode_steps = 0 401 | if self._episode_count == 0: 402 | # Launch StarCraft II 403 | self._launch() 404 | else: 405 | self._restart() 406 | 407 | # Information kept for counting the reward 408 | self.death_tracker_ally = np.zeros(self.n_agents) 409 | self.death_tracker_enemy = np.zeros(self.n_enemies) 410 | self.previous_ally_units = None 411 | self.previous_enemy_units = None 412 | self.win_counted = False 413 | self.defeat_counted = False 414 | 415 | self.last_action = np.zeros((self.n_agents, self.n_actions)) 416 | 417 | if self.heuristic_ai: 418 | self.heuristic_targets = [None] * self.n_agents 419 | 420 | try: 421 | self._obs = self._controller.observe() 422 | self.init_units() 423 | except (protocol.ProtocolError, protocol.ConnectionError): 424 | self.full_restart() 425 | 426 | if self.debug: 427 | logging.debug( 428 | "Started Episode {}".format(self._episode_count).center( 429 | 60, "*" 430 | ) 431 | ) 432 | 433 | return self.get_obs(), self.get_state() 434 | 435 | def _restart(self): 436 | """Restart the environment by killing all units on the map. 437 | There is a trigger in the SC2Map file, which restarts the 438 | episode when there are no units left. 439 | """ 440 | try: 441 | self._kill_all_units() 442 | self._controller.step(2) 443 | except (protocol.ProtocolError, protocol.ConnectionError): 444 | self.full_restart() 445 | 446 | def full_restart(self): 447 | """Full restart. Closes the SC2 process and launches a new one.""" 448 | self._sc2_proc.close() 449 | self._launch() 450 | self.force_restarts += 1 451 | 452 | def step(self, actions): 453 | """A single environment step. Returns reward, terminated, info.""" 454 | actions_int = [int(a) for a in actions] 455 | 456 | self.last_action = np.eye(self.n_actions)[np.array(actions_int)] 457 | 458 | # Collect individual actions 459 | sc_actions = [] 460 | if self.debug: 461 | logging.debug("Actions".center(60, "-")) 462 | 463 | for a_id, action in enumerate(actions_int): 464 | if not self.heuristic_ai: 465 | sc_action = self.get_agent_action(a_id, action) 466 | else: 467 | sc_action, action_num = self.get_agent_action_heuristic( 468 | a_id, action 469 | ) 470 | actions[a_id] = action_num 471 | if sc_action: 472 | sc_actions.append(sc_action) 473 | 474 | # Send action request 475 | req_actions = sc_pb.RequestAction(actions=sc_actions) 476 | try: 477 | self._controller.actions(req_actions) 478 | # Make step in SC2, i.e. apply actions 479 | self._controller.step(self._step_mul) 480 | # Observe here so that we know if the episode is over. 
481 | self._obs = self._controller.observe() 482 | except (protocol.ProtocolError, protocol.ConnectionError): 483 | self.full_restart() 484 | return 0, True, {} 485 | 486 | self._total_steps += 1 487 | self._episode_steps += 1 488 | 489 | # Update units 490 | game_end_code = self.update_units() 491 | 492 | terminated = False 493 | reward = self.reward_battle() 494 | info = {"battle_won": False} 495 | 496 | # count units that are still alive 497 | dead_allies, dead_enemies = 0, 0 498 | for _al_id, al_unit in self.agents.items(): 499 | if al_unit.health == 0: 500 | dead_allies += 1 501 | for _e_id, e_unit in self.enemies.items(): 502 | if e_unit.health == 0: 503 | dead_enemies += 1 504 | 505 | info["dead_allies"] = dead_allies 506 | info["dead_enemies"] = dead_enemies 507 | 508 | if game_end_code is not None: 509 | # Battle is over 510 | terminated = True 511 | self.battles_game += 1 512 | if game_end_code == 1 and not self.win_counted: 513 | self.battles_won += 1 514 | self.win_counted = True 515 | info["battle_won"] = True 516 | if not self.reward_sparse: 517 | reward += self.reward_win 518 | else: 519 | reward = 1 520 | elif game_end_code == -1 and not self.defeat_counted: 521 | self.defeat_counted = True 522 | if not self.reward_sparse: 523 | reward += self.reward_defeat 524 | else: 525 | reward = -1 526 | 527 | elif self._episode_steps >= self.episode_limit: 528 | # Episode limit reached 529 | terminated = True 530 | if self.continuing_episode: 531 | info["episode_limit"] = True 532 | self.battles_game += 1 533 | self.timeouts += 1 534 | 535 | if self.debug: 536 | logging.debug("Reward = {}".format(reward).center(60, "-")) 537 | 538 | if terminated: 539 | self._episode_count += 1 540 | 541 | if self.reward_scale: 542 | reward /= self.max_reward / self.reward_scale_rate 543 | 544 | self.reward = reward 545 | 546 | return reward, terminated, info 547 | 548 | def get_agent_action(self, a_id, action): 549 | """Construct the action for agent a_id.""" 550 | avail_actions = self.get_avail_agent_actions(a_id) 551 | assert ( 552 | avail_actions[action] == 1 553 | ), "Agent {} cannot perform action {}".format(a_id, action) 554 | 555 | unit = self.get_unit_by_id(a_id) 556 | tag = unit.tag 557 | x = unit.pos.x 558 | y = unit.pos.y 559 | 560 | if action == 0: 561 | # no-op (valid only when dead) 562 | assert unit.health == 0, "No-op only available for dead agents." 
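            # Discrete action encoding used by this method:
            #   0     -> no-op (only valid when the agent is dead)
            #   1     -> stop
            #   2-5   -> move north / south / east / west
            #   >= 6  -> attack enemy (or heal ally, for Medivacs on MMM maps)
            #            with target_id = action - n_actions_no_attack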
563 | if self.debug: 564 | logging.debug("Agent {}: Dead".format(a_id)) 565 | return None 566 | elif action == 1: 567 | # stop 568 | cmd = r_pb.ActionRawUnitCommand( 569 | ability_id=actions["stop"], 570 | unit_tags=[tag], 571 | queue_command=False, 572 | ) 573 | if self.debug: 574 | logging.debug("Agent {}: Stop".format(a_id)) 575 | 576 | elif action == 2: 577 | # move north 578 | cmd = r_pb.ActionRawUnitCommand( 579 | ability_id=actions["move"], 580 | target_world_space_pos=sc_common.Point2D( 581 | x=x, y=y + self._move_amount 582 | ), 583 | unit_tags=[tag], 584 | queue_command=False, 585 | ) 586 | if self.debug: 587 | logging.debug("Agent {}: Move North".format(a_id)) 588 | 589 | elif action == 3: 590 | # move south 591 | cmd = r_pb.ActionRawUnitCommand( 592 | ability_id=actions["move"], 593 | target_world_space_pos=sc_common.Point2D( 594 | x=x, y=y - self._move_amount 595 | ), 596 | unit_tags=[tag], 597 | queue_command=False, 598 | ) 599 | if self.debug: 600 | logging.debug("Agent {}: Move South".format(a_id)) 601 | 602 | elif action == 4: 603 | # move east 604 | cmd = r_pb.ActionRawUnitCommand( 605 | ability_id=actions["move"], 606 | target_world_space_pos=sc_common.Point2D( 607 | x=x + self._move_amount, y=y 608 | ), 609 | unit_tags=[tag], 610 | queue_command=False, 611 | ) 612 | if self.debug: 613 | logging.debug("Agent {}: Move East".format(a_id)) 614 | 615 | elif action == 5: 616 | # move west 617 | cmd = r_pb.ActionRawUnitCommand( 618 | ability_id=actions["move"], 619 | target_world_space_pos=sc_common.Point2D( 620 | x=x - self._move_amount, y=y 621 | ), 622 | unit_tags=[tag], 623 | queue_command=False, 624 | ) 625 | if self.debug: 626 | logging.debug("Agent {}: Move West".format(a_id)) 627 | else: 628 | # attack/heal units that are in range 629 | target_id = action - self.n_actions_no_attack 630 | if self.map_type == "MMM" and unit.unit_type == self.medivac_id: 631 | target_unit = self.agents[target_id] 632 | action_name = "heal" 633 | else: 634 | target_unit = self.enemies[target_id] 635 | action_name = "attack" 636 | 637 | action_id = actions[action_name] 638 | target_tag = target_unit.tag 639 | 640 | cmd = r_pb.ActionRawUnitCommand( 641 | ability_id=action_id, 642 | target_unit_tag=target_tag, 643 | unit_tags=[tag], 644 | queue_command=False, 645 | ) 646 | 647 | if self.debug: 648 | logging.debug( 649 | "Agent {} {}s unit # {}".format( 650 | a_id, action_name, target_id 651 | ) 652 | ) 653 | 654 | sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd)) 655 | return sc_action 656 | 657 | def get_agent_action_heuristic(self, a_id, action): 658 | unit = self.get_unit_by_id(a_id) 659 | tag = unit.tag 660 | 661 | target = self.heuristic_targets[a_id] 662 | if unit.unit_type == self.medivac_id: 663 | if ( 664 | target is None 665 | or self.agents[target].health == 0 666 | or self.agents[target].health == self.agents[target].health_max 667 | ): 668 | min_dist = math.hypot(self.max_distance_x, self.max_distance_y) 669 | min_id = -1 670 | for al_id, al_unit in self.agents.items(): 671 | if al_unit.unit_type == self.medivac_id: 672 | continue 673 | if ( 674 | al_unit.health != 0 675 | and al_unit.health != al_unit.health_max 676 | ): 677 | dist = self.distance( 678 | unit.pos.x, 679 | unit.pos.y, 680 | al_unit.pos.x, 681 | al_unit.pos.y, 682 | ) 683 | if dist < min_dist: 684 | min_dist = dist 685 | min_id = al_id 686 | self.heuristic_targets[a_id] = min_id 687 | if min_id == -1: 688 | self.heuristic_targets[a_id] = None 689 | return None, 0 690 | action_id = 
actions["heal"] 691 | target_tag = self.agents[self.heuristic_targets[a_id]].tag 692 | else: 693 | if target is None or self.enemies[target].health == 0: 694 | min_dist = math.hypot(self.max_distance_x, self.max_distance_y) 695 | min_id = -1 696 | for e_id, e_unit in self.enemies.items(): 697 | if ( 698 | unit.unit_type == self.marauder_id 699 | and e_unit.unit_type == self.medivac_id 700 | ): 701 | continue 702 | if e_unit.health > 0: 703 | dist = self.distance( 704 | unit.pos.x, unit.pos.y, e_unit.pos.x, e_unit.pos.y 705 | ) 706 | if dist < min_dist: 707 | min_dist = dist 708 | min_id = e_id 709 | self.heuristic_targets[a_id] = min_id 710 | if min_id == -1: 711 | self.heuristic_targets[a_id] = None 712 | return None, 0 713 | action_id = actions["attack"] 714 | target_tag = self.enemies[self.heuristic_targets[a_id]].tag 715 | 716 | action_num = self.heuristic_targets[a_id] + self.n_actions_no_attack 717 | 718 | # Check if the action is available 719 | if ( 720 | self.heuristic_rest 721 | and self.get_avail_agent_actions(a_id)[action_num] == 0 722 | ): 723 | 724 | # Move towards the target rather than attacking/healing 725 | if unit.unit_type == self.medivac_id: 726 | target_unit = self.agents[self.heuristic_targets[a_id]] 727 | else: 728 | target_unit = self.enemies[self.heuristic_targets[a_id]] 729 | 730 | delta_x = target_unit.pos.x - unit.pos.x 731 | delta_y = target_unit.pos.y - unit.pos.y 732 | 733 | if abs(delta_x) > abs(delta_y): # east or west 734 | if delta_x > 0: # east 735 | target_pos = sc_common.Point2D( 736 | x=unit.pos.x + self._move_amount, y=unit.pos.y 737 | ) 738 | action_num = 4 739 | else: # west 740 | target_pos = sc_common.Point2D( 741 | x=unit.pos.x - self._move_amount, y=unit.pos.y 742 | ) 743 | action_num = 5 744 | else: # north or south 745 | if delta_y > 0: # north 746 | target_pos = sc_common.Point2D( 747 | x=unit.pos.x, y=unit.pos.y + self._move_amount 748 | ) 749 | action_num = 2 750 | else: # south 751 | target_pos = sc_common.Point2D( 752 | x=unit.pos.x, y=unit.pos.y - self._move_amount 753 | ) 754 | action_num = 3 755 | 756 | cmd = r_pb.ActionRawUnitCommand( 757 | ability_id=actions["move"], 758 | target_world_space_pos=target_pos, 759 | unit_tags=[tag], 760 | queue_command=False, 761 | ) 762 | else: 763 | # Attack/heal the target 764 | cmd = r_pb.ActionRawUnitCommand( 765 | ability_id=action_id, 766 | target_unit_tag=target_tag, 767 | unit_tags=[tag], 768 | queue_command=False, 769 | ) 770 | 771 | sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd)) 772 | return sc_action, action_num 773 | 774 | def reward_battle(self): 775 | """Reward function when self.reward_spare==False. 
776 | Returns accumulative hit/shield point damage dealt to the enemy 777 | + reward_death_value per enemy unit killed, and, in case 778 | self.reward_only_positive == False, - (damage dealt to ally units 779 | + reward_death_value per ally unit killed) * self.reward_negative_scale 780 | """ 781 | if self.reward_sparse: 782 | return 0 783 | 784 | reward = 0 785 | delta_deaths = 0 786 | delta_ally = 0 787 | delta_enemy = 0 788 | 789 | neg_scale = self.reward_negative_scale 790 | 791 | # update deaths 792 | for al_id, al_unit in self.agents.items(): 793 | if not self.death_tracker_ally[al_id]: 794 | # did not die so far 795 | prev_health = ( 796 | self.previous_ally_units[al_id].health 797 | + self.previous_ally_units[al_id].shield 798 | ) 799 | if al_unit.health == 0: 800 | # just died 801 | self.death_tracker_ally[al_id] = 1 802 | if not self.reward_only_positive: 803 | delta_deaths -= self.reward_death_value * neg_scale 804 | delta_ally += prev_health * neg_scale 805 | else: 806 | # still alive 807 | delta_ally += neg_scale * ( 808 | prev_health - al_unit.health - al_unit.shield 809 | ) 810 | 811 | for e_id, e_unit in self.enemies.items(): 812 | if not self.death_tracker_enemy[e_id]: 813 | prev_health = ( 814 | self.previous_enemy_units[e_id].health 815 | + self.previous_enemy_units[e_id].shield 816 | ) 817 | if e_unit.health == 0: 818 | self.death_tracker_enemy[e_id] = 1 819 | delta_deaths += self.reward_death_value 820 | delta_enemy += prev_health 821 | else: 822 | delta_enemy += prev_health - e_unit.health - e_unit.shield 823 | 824 | if self.reward_only_positive: 825 | reward = abs(delta_enemy + delta_deaths) # shield regeneration 826 | else: 827 | reward = delta_enemy + delta_deaths - delta_ally 828 | 829 | return reward 830 | 831 | def get_total_actions(self): 832 | """Returns the total number of actions an agent could ever take.""" 833 | return self.n_actions 834 | 835 | @staticmethod 836 | def distance(x1, y1, x2, y2): 837 | """Distance between two points.""" 838 | return math.hypot(x2 - x1, y2 - y1) 839 | 840 | def unit_shoot_range(self, agent_id): 841 | """Returns the shooting range for an agent.""" 842 | return 6 843 | 844 | def unit_sight_range(self, agent_id): 845 | """Returns the sight range for an agent.""" 846 | return 9 847 | 848 | def unit_max_cooldown(self, unit): 849 | """Returns the maximal cooldown for a unit.""" 850 | switcher = { 851 | self.marine_id: 15, 852 | self.marauder_id: 25, 853 | self.medivac_id: 200, # max energy 854 | self.stalker_id: 35, 855 | self.zealot_id: 22, 856 | self.colossus_id: 24, 857 | self.hydralisk_id: 10, 858 | self.zergling_id: 11, 859 | self.baneling_id: 1, 860 | } 861 | return switcher.get(unit.unit_type, 15) 862 | 863 | def save_replay(self): 864 | """Save a replay.""" 865 | prefix = self.replay_prefix or self.map_name 866 | replay_dir = self.replay_dir or "" 867 | replay_path = self._run_config.save_replay( 868 | self._controller.save_replay(), 869 | replay_dir=replay_dir, 870 | prefix=prefix, 871 | ) 872 | logging.info("Replay saved at: %s" % replay_path) 873 | 874 | def unit_max_shield(self, unit): 875 | """Returns maximal shield for a given unit.""" 876 | if unit.unit_type == 74 or unit.unit_type == self.stalker_id: 877 | return 80 # Protoss's Stalker 878 | if unit.unit_type == 73 or unit.unit_type == self.zealot_id: 879 | return 50 # Protoss's Zaelot 880 | if unit.unit_type == 4 or unit.unit_type == self.colossus_id: 881 | return 150 # Protoss's Colossus 882 | 883 | def can_move(self, unit, direction): 884 | """Whether a unit can 
move in a given direction.""" 885 | m = self._move_amount / 2 886 | 887 | if direction == Direction.NORTH: 888 | x, y = int(unit.pos.x), int(unit.pos.y + m) 889 | elif direction == Direction.SOUTH: 890 | x, y = int(unit.pos.x), int(unit.pos.y - m) 891 | elif direction == Direction.EAST: 892 | x, y = int(unit.pos.x + m), int(unit.pos.y) 893 | else: 894 | x, y = int(unit.pos.x - m), int(unit.pos.y) 895 | 896 | if self.check_bounds(x, y) and self.pathing_grid[x, y]: 897 | return True 898 | 899 | return False 900 | 901 | def get_surrounding_points(self, unit, include_self=False): 902 | """Returns the surrounding points of the unit in 8 directions.""" 903 | x = int(unit.pos.x) 904 | y = int(unit.pos.y) 905 | 906 | ma = self._move_amount 907 | 908 | points = [ 909 | (x, y + 2 * ma), 910 | (x, y - 2 * ma), 911 | (x + 2 * ma, y), 912 | (x - 2 * ma, y), 913 | (x + ma, y + ma), 914 | (x - ma, y - ma), 915 | (x + ma, y - ma), 916 | (x - ma, y + ma), 917 | ] 918 | 919 | if include_self: 920 | points.append((x, y)) 921 | 922 | return points 923 | 924 | def check_bounds(self, x, y): 925 | """Whether a point is within the map bounds.""" 926 | return 0 <= x < self.map_x and 0 <= y < self.map_y 927 | 928 | def get_surrounding_pathing(self, unit): 929 | """Returns pathing values of the grid surrounding the given unit.""" 930 | points = self.get_surrounding_points(unit, include_self=False) 931 | vals = [ 932 | self.pathing_grid[x, y] if self.check_bounds(x, y) else 1 933 | for x, y in points 934 | ] 935 | return vals 936 | 937 | def get_surrounding_height(self, unit): 938 | """Returns height values of the grid surrounding the given unit.""" 939 | points = self.get_surrounding_points(unit, include_self=True) 940 | vals = [ 941 | self.terrain_height[x, y] if self.check_bounds(x, y) else 1 942 | for x, y in points 943 | ] 944 | return vals 945 | 946 | def get_obs_agent(self, agent_id): 947 | """Returns observation for agent_id. The observation is composed of: 948 | 949 | - agent movement features (where it can move to, height information 950 | and pathing grid) 951 | - enemy features (available_to_attack, health, relative_x, relative_y, 952 | shield, unit_type) 953 | - ally features (visible, distance, relative_x, relative_y, shield, 954 | unit_type) 955 | - agent unit features (health, shield, unit_type) 956 | 957 | All of this information is flattened and concatenated into a list, 958 | in the aforementioned order. To know the sizes of each of the 959 | features inside the final list of features, take a look at the 960 | functions ``get_obs_move_feats_size()``, 961 | ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and 962 | ``get_obs_own_feats_size()``. 963 | 964 | The size of the observation vector may vary, depending on the 965 | environment configuration and type of units present in the map. 966 | For instance, non-Protoss units will not have shields, movement 967 | features may or may not include terrain height and pathing grid, 968 | unit_type is not included if there is only one type of unit in the 969 | map etc.). 970 | 971 | NOTE: Agents should have access only to their local observations 972 | during decentralised execution. 
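
        A quick consistency check with the default configuration
        (illustrative; assumes ``env`` has been reset)::

            obs = env.get_obs_agent(0)
            assert obs.shape[0] == env.get_obs_size()
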
973 | """ 974 | unit = self.get_unit_by_id(agent_id) 975 | 976 | move_feats_dim = self.get_obs_move_feats_size() 977 | enemy_feats_dim = self.get_obs_enemy_feats_size() 978 | ally_feats_dim = self.get_obs_ally_feats_size() 979 | own_feats_dim = self.get_obs_own_feats_size() 980 | 981 | move_feats = np.zeros(move_feats_dim, dtype=np.float32) 982 | enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32) 983 | ally_feats = np.zeros(ally_feats_dim, dtype=np.float32) 984 | own_feats = np.zeros(own_feats_dim, dtype=np.float32) 985 | 986 | if unit.health > 0: # otherwise dead, return all zeros 987 | x = unit.pos.x 988 | y = unit.pos.y 989 | sight_range = self.unit_sight_range(agent_id) 990 | 991 | # Movement features 992 | avail_actions = self.get_avail_agent_actions(agent_id) 993 | for m in range(self.n_actions_move): 994 | move_feats[m] = avail_actions[m + 2] 995 | 996 | ind = self.n_actions_move 997 | 998 | if self.obs_pathing_grid: 999 | move_feats[ 1000 | ind : ind + self.n_obs_pathing # noqa 1001 | ] = self.get_surrounding_pathing(unit) 1002 | ind += self.n_obs_pathing 1003 | 1004 | if self.obs_terrain_height: 1005 | move_feats[ind:] = self.get_surrounding_height(unit) 1006 | 1007 | # Enemy features 1008 | for e_id, e_unit in self.enemies.items(): 1009 | e_x = e_unit.pos.x 1010 | e_y = e_unit.pos.y 1011 | dist = self.distance(x, y, e_x, e_y) 1012 | 1013 | if ( 1014 | dist < sight_range and e_unit.health > 0 1015 | ): # visible and alive 1016 | # Sight range > shoot range 1017 | enemy_feats[e_id, 0] = avail_actions[ 1018 | self.n_actions_no_attack + e_id 1019 | ] # available 1020 | enemy_feats[e_id, 1] = dist / sight_range # distance 1021 | enemy_feats[e_id, 2] = ( 1022 | e_x - x 1023 | ) / sight_range # relative X 1024 | enemy_feats[e_id, 3] = ( 1025 | e_y - y 1026 | ) / sight_range # relative Y 1027 | 1028 | ind = 4 1029 | if self.obs_all_health: 1030 | enemy_feats[e_id, ind] = ( 1031 | e_unit.health / e_unit.health_max 1032 | ) # health 1033 | ind += 1 1034 | if self.shield_bits_enemy > 0: 1035 | max_shield = self.unit_max_shield(e_unit) 1036 | enemy_feats[e_id, ind] = ( 1037 | e_unit.shield / max_shield 1038 | ) # shield 1039 | ind += 1 1040 | 1041 | if self.unit_type_bits > 0: 1042 | type_id = self.get_unit_type_id(e_unit, False) 1043 | enemy_feats[e_id, ind + type_id] = 1 # unit type 1044 | 1045 | # Ally features 1046 | al_ids = [ 1047 | al_id for al_id in range(self.n_agents) if al_id != agent_id 1048 | ] 1049 | for i, al_id in enumerate(al_ids): 1050 | 1051 | al_unit = self.get_unit_by_id(al_id) 1052 | al_x = al_unit.pos.x 1053 | al_y = al_unit.pos.y 1054 | dist = self.distance(x, y, al_x, al_y) 1055 | 1056 | if ( 1057 | dist < sight_range and al_unit.health > 0 1058 | ): # visible and alive 1059 | ally_feats[i, 0] = 1 # visible 1060 | ally_feats[i, 1] = dist / sight_range # distance 1061 | ally_feats[i, 2] = (al_x - x) / sight_range # relative X 1062 | ally_feats[i, 3] = (al_y - y) / sight_range # relative Y 1063 | 1064 | ind = 4 1065 | if self.obs_all_health: 1066 | ally_feats[i, ind] = ( 1067 | al_unit.health / al_unit.health_max 1068 | ) # health 1069 | ind += 1 1070 | if self.shield_bits_ally > 0: 1071 | max_shield = self.unit_max_shield(al_unit) 1072 | ally_feats[i, ind] = ( 1073 | al_unit.shield / max_shield 1074 | ) # shield 1075 | ind += 1 1076 | 1077 | if self.unit_type_bits > 0: 1078 | type_id = self.get_unit_type_id(al_unit, True) 1079 | ally_feats[i, ind + type_id] = 1 1080 | ind += self.unit_type_bits 1081 | 1082 | if self.obs_last_action: 1083 | ally_feats[i, ind:] = 
self.last_action[al_id] 1084 | 1085 | # Own features 1086 | ind = 0 1087 | if self.obs_own_health: 1088 | own_feats[ind] = unit.health / unit.health_max 1089 | ind += 1 1090 | if self.shield_bits_ally > 0: 1091 | max_shield = self.unit_max_shield(unit) 1092 | own_feats[ind] = unit.shield / max_shield 1093 | ind += 1 1094 | 1095 | if self.unit_type_bits > 0: 1096 | type_id = self.get_unit_type_id(unit, True) 1097 | own_feats[ind + type_id] = 1 1098 | 1099 | agent_obs = np.concatenate( 1100 | ( 1101 | move_feats.flatten(), 1102 | enemy_feats.flatten(), 1103 | ally_feats.flatten(), 1104 | own_feats.flatten(), 1105 | ) 1106 | ) 1107 | 1108 | if self.obs_timestep_number: 1109 | agent_obs = np.append( 1110 | agent_obs, self._episode_steps / self.episode_limit 1111 | ) 1112 | 1113 | if self.debug: 1114 | logging.debug("Obs Agent: {}".format(agent_id).center(60, "-")) 1115 | logging.debug( 1116 | "Avail. actions {}".format( 1117 | self.get_avail_agent_actions(agent_id) 1118 | ) 1119 | ) 1120 | logging.debug("Move feats {}".format(move_feats)) 1121 | logging.debug("Enemy feats {}".format(enemy_feats)) 1122 | logging.debug("Ally feats {}".format(ally_feats)) 1123 | logging.debug("Own feats {}".format(own_feats)) 1124 | 1125 | return agent_obs 1126 | 1127 | def get_obs(self): 1128 | """Returns all agent observations in a list. 1129 | NOTE: Agents should have access only to their local observations 1130 | during decentralised execution. 1131 | """ 1132 | agents_obs = [self.get_obs_agent(i) for i in range(self.n_agents)] 1133 | return agents_obs 1134 | 1135 | def get_state(self): 1136 | """Returns the global state. 1137 | NOTE: This functon should not be used during decentralised execution. 1138 | """ 1139 | if self.obs_instead_of_state: 1140 | obs_concat = np.concatenate(self.get_obs(), axis=0).astype( 1141 | np.float32 1142 | ) 1143 | return obs_concat 1144 | 1145 | state_dict = self.get_state_dict() 1146 | 1147 | state = np.append( 1148 | state_dict["allies"].flatten(), state_dict["enemies"].flatten() 1149 | ) 1150 | if "last_action" in state_dict: 1151 | state = np.append(state, state_dict["last_action"].flatten()) 1152 | if "timestep" in state_dict: 1153 | state = np.append(state, state_dict["timestep"]) 1154 | 1155 | state = state.astype(dtype=np.float32) 1156 | 1157 | if self.debug: 1158 | logging.debug("STATE".center(60, "-")) 1159 | logging.debug("Ally state {}".format(state_dict["allies"])) 1160 | logging.debug("Enemy state {}".format(state_dict["enemies"])) 1161 | if self.state_last_action: 1162 | logging.debug("Last actions {}".format(self.last_action)) 1163 | 1164 | return state 1165 | 1166 | def get_ally_num_attributes(self): 1167 | return len(self.ally_state_attr_names) 1168 | 1169 | def get_enemy_num_attributes(self): 1170 | return len(self.enemy_state_attr_names) 1171 | 1172 | def get_state_dict(self): 1173 | """Returns the global state as a dictionary. 1174 | 1175 | - allies: numpy array containing agents and their attributes 1176 | - enemies: numpy array containing enemies and their attributes 1177 | - last_action: numpy array of previous actions for each agent 1178 | - timestep: current no. of steps divided by total no. of steps 1179 | 1180 | NOTE: This function should not be used during decentralised execution. 
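
        For example (illustrative; assumes ``env`` has been reset)::

            sd = env.get_state_dict()
            sd["allies"].shape   # (n_agents, number of ally attributes)
            sd["enemies"].shape  # (n_enemies, number of enemy attributes)
            # "last_action" / "timestep" are present only if the
            # corresponding state flags are enabled
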
1181 | """ 1182 | 1183 | # number of features equals the number of attribute names 1184 | nf_al = self.get_ally_num_attributes() 1185 | nf_en = self.get_enemy_num_attributes() 1186 | 1187 | ally_state = np.zeros((self.n_agents, nf_al)) 1188 | enemy_state = np.zeros((self.n_enemies, nf_en)) 1189 | 1190 | center_x = self.map_x / 2 1191 | center_y = self.map_y / 2 1192 | 1193 | for al_id, al_unit in self.agents.items(): 1194 | if al_unit.health > 0: 1195 | x = al_unit.pos.x 1196 | y = al_unit.pos.y 1197 | max_cd = self.unit_max_cooldown(al_unit) 1198 | 1199 | ally_state[al_id, 0] = ( 1200 | al_unit.health / al_unit.health_max 1201 | ) # health 1202 | if ( 1203 | self.map_type == "MMM" 1204 | and al_unit.unit_type == self.medivac_id 1205 | ): 1206 | ally_state[al_id, 1] = al_unit.energy / max_cd # energy 1207 | else: 1208 | ally_state[al_id, 1] = ( 1209 | al_unit.weapon_cooldown / max_cd 1210 | ) # cooldown 1211 | ally_state[al_id, 2] = ( 1212 | x - center_x 1213 | ) / self.max_distance_x # relative X 1214 | ally_state[al_id, 3] = ( 1215 | y - center_y 1216 | ) / self.max_distance_y # relative Y 1217 | 1218 | if self.shield_bits_ally > 0: 1219 | max_shield = self.unit_max_shield(al_unit) 1220 | ally_state[al_id, 4] = ( 1221 | al_unit.shield / max_shield 1222 | ) # shield 1223 | 1224 | if self.unit_type_bits > 0: 1225 | type_id = self.get_unit_type_id(al_unit, True) 1226 | ally_state[al_id, type_id - self.unit_type_bits] = 1 1227 | 1228 | for e_id, e_unit in self.enemies.items(): 1229 | if e_unit.health > 0: 1230 | x = e_unit.pos.x 1231 | y = e_unit.pos.y 1232 | 1233 | enemy_state[e_id, 0] = ( 1234 | e_unit.health / e_unit.health_max 1235 | ) # health 1236 | enemy_state[e_id, 1] = ( 1237 | x - center_x 1238 | ) / self.max_distance_x # relative X 1239 | enemy_state[e_id, 2] = ( 1240 | y - center_y 1241 | ) / self.max_distance_y # relative Y 1242 | 1243 | if self.shield_bits_enemy > 0: 1244 | max_shield = self.unit_max_shield(e_unit) 1245 | enemy_state[e_id, 3] = e_unit.shield / max_shield # shield 1246 | 1247 | if self.unit_type_bits > 0: 1248 | type_id = self.get_unit_type_id(e_unit, False) 1249 | enemy_state[e_id, type_id - self.unit_type_bits] = 1 1250 | 1251 | state = {"allies": ally_state, "enemies": enemy_state} 1252 | 1253 | if self.state_last_action: 1254 | state["last_action"] = self.last_action 1255 | if self.state_timestep_number: 1256 | state["timestep"] = self._episode_steps / self.episode_limit 1257 | 1258 | return state 1259 | 1260 | def get_obs_enemy_feats_size(self): 1261 | """Returns the dimensions of the matrix containing enemy features. 1262 | Size is n_enemies x n_features. 1263 | """ 1264 | nf_en = 4 + self.unit_type_bits 1265 | 1266 | if self.obs_all_health: 1267 | nf_en += 1 + self.shield_bits_enemy 1268 | 1269 | return self.n_enemies, nf_en 1270 | 1271 | def get_obs_ally_feats_size(self): 1272 | """Returns the dimensions of the matrix containing ally features. 1273 | Size is n_allies x n_features. 1274 | """ 1275 | nf_al = 4 + self.unit_type_bits 1276 | 1277 | if self.obs_all_health: 1278 | nf_al += 1 + self.shield_bits_ally 1279 | 1280 | if self.obs_last_action: 1281 | nf_al += self.n_actions 1282 | 1283 | return self.n_agents - 1, nf_al 1284 | 1285 | def get_obs_own_feats_size(self): 1286 | """ 1287 | Returns the size of the vector containing the agents' own features. 
1288 | """ 1289 | own_feats = self.unit_type_bits 1290 | if self.obs_own_health: 1291 | own_feats += 1 + self.shield_bits_ally 1292 | if self.obs_timestep_number: 1293 | own_feats += 1 1294 | 1295 | return own_feats 1296 | 1297 | def get_obs_move_feats_size(self): 1298 | """Returns the size of the vector containing the agent's movement- 1299 | related features. 1300 | """ 1301 | move_feats = self.n_actions_move 1302 | if self.obs_pathing_grid: 1303 | move_feats += self.n_obs_pathing 1304 | if self.obs_terrain_height: 1305 | move_feats += self.n_obs_height 1306 | 1307 | return move_feats 1308 | 1309 | def get_obs_size(self): 1310 | """Returns the size of the observation.""" 1311 | own_feats = self.get_obs_own_feats_size() 1312 | move_feats = self.get_obs_move_feats_size() 1313 | 1314 | n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size() 1315 | n_allies, n_ally_feats = self.get_obs_ally_feats_size() 1316 | 1317 | enemy_feats = n_enemies * n_enemy_feats 1318 | ally_feats = n_allies * n_ally_feats 1319 | 1320 | return move_feats + enemy_feats + ally_feats + own_feats 1321 | 1322 | def get_state_size(self): 1323 | """Returns the size of the global state.""" 1324 | if self.obs_instead_of_state: 1325 | return self.get_obs_size() * self.n_agents 1326 | 1327 | nf_al = 4 + self.shield_bits_ally + self.unit_type_bits 1328 | nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits 1329 | 1330 | enemy_state = self.n_enemies * nf_en 1331 | ally_state = self.n_agents * nf_al 1332 | 1333 | size = enemy_state + ally_state 1334 | 1335 | if self.state_last_action: 1336 | size += self.n_agents * self.n_actions 1337 | if self.state_timestep_number: 1338 | size += 1 1339 | 1340 | return size 1341 | 1342 | def get_visibility_matrix(self): 1343 | """Returns a boolean numpy array of dimensions 1344 | (n_agents, n_agents + n_enemies) indicating which units 1345 | are visible to each agent.
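Entry (i, j) is True when unit j is alive and within agent i's sight range; the first n_agents columns refer to allied agents and the remaining n_enemies columns to enemy units. An agent is never marked as visible to itself, and dead agents have all-False rows.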
1346 | """ 1347 | arr = np.zeros( 1348 | (self.n_agents, self.n_agents + self.n_enemies), 1349 | dtype=bool, 1350 | ) 1351 | 1352 | for agent_id in range(self.n_agents): 1353 | current_agent = self.get_unit_by_id(agent_id) 1354 | if current_agent.health > 0: # if agent is not dead 1355 | x = current_agent.pos.x 1356 | y = current_agent.pos.y 1357 | sight_range = self.unit_sight_range(agent_id) 1358 | 1359 | # Enemies 1360 | for e_id, e_unit in self.enemies.items(): 1361 | e_x = e_unit.pos.x 1362 | e_y = e_unit.pos.y 1363 | dist = self.distance(x, y, e_x, e_y) 1364 | 1365 | if dist < sight_range and e_unit.health > 0: 1366 | # visible and alive 1367 | arr[agent_id, self.n_agents + e_id] = 1 1368 | 1369 | # The matrix for allies is filled symmetrically 1370 | al_ids = [ 1371 | al_id for al_id in range(self.n_agents) if al_id > agent_id 1372 | ] 1373 | for _, al_id in enumerate(al_ids): 1374 | al_unit = self.get_unit_by_id(al_id) 1375 | al_x = al_unit.pos.x 1376 | al_y = al_unit.pos.y 1377 | dist = self.distance(x, y, al_x, al_y) 1378 | 1379 | if dist < sight_range and al_unit.health > 0: 1380 | # visible and alive 1381 | arr[agent_id, al_id] = arr[al_id, agent_id] = 1 1382 | 1383 | return arr 1384 | 1385 | def get_unit_type_id(self, unit, ally): 1386 | """Returns the ID of the unit type in the given scenario.""" 1387 | if ally: # use new SC2 unit types 1388 | type_id = unit.unit_type - self._min_unit_type 1389 | else: # use default SC2 unit types 1390 | if self.map_type == "stalkers_and_zealots": 1391 | # id(Stalker) = 74, id(Zealot) = 73 1392 | # Note that the one-hot zealot unit type is [1, 0] in enemy_obs but [0, 1] in ally_obs 1393 | # If you want to align the enemy unit type with the ally's, uncomment the following lines 1394 | # if unit.unit_type == 74: 1395 | # type_id = 0 1396 | # else: 1397 | # type_id = 1 1398 | type_id = unit.unit_type - 73 1399 | elif self.map_type == "colossi_stalkers_zealots": 1400 | # id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4 1401 | if unit.unit_type == 4: 1402 | type_id = 0 1403 | elif unit.unit_type == 74: 1404 | type_id = 1 1405 | else: 1406 | type_id = 2 1407 | elif self.map_type == "bane": 1408 | # id(Baneling) = 9 1409 | if unit.unit_type == 9: 1410 | type_id = 0 1411 | else: 1412 | type_id = 1 1413 | elif self.map_type == "MMM": 1414 | # id(Marauder) = 51, id(Marine) = 48, id(Medivac) = 54 1415 | if unit.unit_type == 51: 1416 | type_id = 0 1417 | elif unit.unit_type == 48: 1418 | type_id = 1 1419 | else: 1420 | type_id = 2 1421 | 1422 | return type_id 1423 | 1424 | def get_avail_agent_actions(self, agent_id): 1425 | """Returns the available actions for agent_id.""" 1426 | unit = self.get_unit_by_id(agent_id) 1427 | if unit.health > 0: 1428 | # cannot choose no-op when alive 1429 | avail_actions = [0] * self.n_actions 1430 | 1431 | # stop should be allowed 1432 | avail_actions[1] = 1 1433 | 1434 | # see if we can move 1435 | if self.can_move(unit, Direction.NORTH): 1436 | avail_actions[2] = 1 1437 | if self.can_move(unit, Direction.SOUTH): 1438 | avail_actions[3] = 1 1439 | if self.can_move(unit, Direction.EAST): 1440 | avail_actions[4] = 1 1441 | if self.can_move(unit, Direction.WEST): 1442 | avail_actions[5] = 1 1443 | 1444 | # Can only attack units that are alive and within shooting range 1445 | shoot_range = self.unit_shoot_range(agent_id) 1446 | 1447 | target_items = self.enemies.items() 1448 | if self.map_type == "MMM" and unit.unit_type == self.medivac_id: 1449 | # Medivacs cannot heal themselves or other flying units 1450 | target_items =
[ 1451 | (t_id, t_unit) 1452 | for (t_id, t_unit) in self.agents.items() 1453 | if t_unit.unit_type != self.medivac_id 1454 | ] 1455 | 1456 | for t_id, t_unit in target_items: 1457 | if t_unit.health > 0: 1458 | dist = self.distance( 1459 | unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y 1460 | ) 1461 | if dist <= shoot_range: 1462 | avail_actions[t_id + self.n_actions_no_attack] = 1 1463 | 1464 | return avail_actions 1465 | 1466 | else: 1467 | # only no-op allowed 1468 | return [1] + [0] * (self.n_actions - 1) 1469 | 1470 | def get_avail_actions(self): 1471 | """Returns the available actions of all agents in a list.""" 1472 | avail_actions = [] 1473 | for agent_id in range(self.n_agents): 1474 | avail_agent = self.get_avail_agent_actions(agent_id) 1475 | avail_actions.append(avail_agent) 1476 | return avail_actions 1477 | 1478 | def close(self): 1479 | """Close StarCraft II.""" 1480 | if self.renderer is not None: 1481 | self.renderer.close() 1482 | self.renderer = None 1483 | if self._sc2_proc: 1484 | self._sc2_proc.close() 1485 | 1486 | def seed(self): 1487 | """Returns the random seed used by the environment.""" 1488 | return self._seed 1489 | 1490 | def render(self, mode="human"): 1491 | if self.renderer is None: 1492 | from smac.env.starcraft2.render import StarCraft2Renderer 1493 | 1494 | self.renderer = StarCraft2Renderer(self, mode) 1495 | assert ( 1496 | mode == self.renderer.mode 1497 | ), "mode must be consistent across render calls" 1498 | return self.renderer.render(mode) 1499 | 1500 | def _kill_all_units(self): 1501 | """Kill all units on the map.""" 1502 | units_alive = [ 1503 | unit.tag for unit in self.agents.values() if unit.health > 0 1504 | ] + [unit.tag for unit in self.enemies.values() if unit.health > 0] 1505 | debug_command = [ 1506 | d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=units_alive)) 1507 | ] 1508 | self._controller.debug(debug_command) 1509 | 1510 | def init_units(self): 1511 | """Initialise the units.""" 1512 | while True: 1513 | # Sometimes not all units have yet been created by SC2 1514 | self.agents = {} 1515 | self.enemies = {} 1516 | 1517 | ally_units = [ 1518 | unit 1519 | for unit in self._obs.observation.raw_data.units 1520 | if unit.owner == 1 1521 | ] 1522 | ally_units_sorted = sorted( 1523 | ally_units, 1524 | key=attrgetter("unit_type", "pos.x", "pos.y"), 1525 | reverse=False, 1526 | ) 1527 | 1528 | for i in range(len(ally_units_sorted)): 1529 | self.agents[i] = ally_units_sorted[i] 1530 | if self.debug: 1531 | logging.debug( 1532 | "Unit {} is {}, x = {}, y = {}".format( 1533 | len(self.agents), 1534 | self.agents[i].unit_type, 1535 | self.agents[i].pos.x, 1536 | self.agents[i].pos.y, 1537 | ) 1538 | ) 1539 | 1540 | for unit in self._obs.observation.raw_data.units: 1541 | if unit.owner == 2: 1542 | self.enemies[len(self.enemies)] = unit 1543 | if self._episode_count == 0: 1544 | self.max_reward += unit.health_max + unit.shield_max 1545 | 1546 | if self._episode_count == 0: 1547 | min_unit_type = min( 1548 | unit.unit_type for unit in self.agents.values() 1549 | ) 1550 | self._init_ally_unit_types(min_unit_type) 1551 | 1552 | all_agents_created = len(self.agents) == self.n_agents 1553 | all_enemies_created = len(self.enemies) == self.n_enemies 1554 | 1555 | self._unit_types = [ 1556 | unit.unit_type for unit in ally_units_sorted 1557 | ] + [ 1558 | unit.unit_type 1559 | for unit in self._obs.observation.raw_data.units 1560 | if unit.owner == 2 1561 | ] 1562 | 1563 | if all_agents_created and all_enemies_created: # all good 1564 | 
return 1565 | 1566 | try: 1567 | self._controller.step(1) 1568 | self._obs = self._controller.observe() 1569 | except (protocol.ProtocolError, protocol.ConnectionError): 1570 | self.full_restart() 1571 | self.reset() 1572 | 1573 | def get_unit_types(self): 1574 | if self._unit_types is None: 1575 | warn( 1576 | "unit types have not been initialized yet, please call " 1577 | "env.reset() to populate this and call the method again." 1578 | ) 1579 | 1580 | return self._unit_types 1581 | 1582 | def update_units(self): 1583 | """Update units after an environment step. 1584 | This function assumes that self._obs is up-to-date. 1585 | """ 1586 | n_ally_alive = 0 1587 | n_enemy_alive = 0 1588 | 1589 | # Store previous state 1590 | self.previous_ally_units = deepcopy(self.agents) 1591 | self.previous_enemy_units = deepcopy(self.enemies) 1592 | 1593 | for al_id, al_unit in self.agents.items(): 1594 | updated = False 1595 | for unit in self._obs.observation.raw_data.units: 1596 | if al_unit.tag == unit.tag: 1597 | self.agents[al_id] = unit 1598 | updated = True 1599 | n_ally_alive += 1 1600 | break 1601 | 1602 | if not updated: # dead 1603 | al_unit.health = 0 1604 | 1605 | for e_id, e_unit in self.enemies.items(): 1606 | updated = False 1607 | for unit in self._obs.observation.raw_data.units: 1608 | if e_unit.tag == unit.tag: 1609 | self.enemies[e_id] = unit 1610 | updated = True 1611 | n_enemy_alive += 1 1612 | break 1613 | 1614 | if not updated: # dead 1615 | e_unit.health = 0 1616 | 1617 | if ( 1618 | n_ally_alive == 0 1619 | and n_enemy_alive > 0 1620 | or self.only_medivac_left(ally=True) 1621 | ): 1622 | return -1 # lost 1623 | if ( 1624 | n_ally_alive > 0 1625 | and n_enemy_alive == 0 1626 | or self.only_medivac_left(ally=False) 1627 | ): 1628 | return 1 # won 1629 | if n_ally_alive == 0 and n_enemy_alive == 0: 1630 | return 0 1631 | 1632 | return None 1633 | 1634 | def _init_ally_unit_types(self, min_unit_type): 1635 | """Initialise ally unit types. Should be called once from the 1636 | init_units function.
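Ally type IDs are assigned by offsetting the raw SC2 unit types so that the smallest type present on the map maps to ID 0 (see get_unit_type_id).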
1637 | """ 1638 | self._min_unit_type = min_unit_type 1639 | if self.map_type == "marines": 1640 | self.marine_id = min_unit_type 1641 | elif self.map_type == "stalkers_and_zealots": 1642 | self.stalker_id = min_unit_type 1643 | self.zealot_id = min_unit_type + 1 1644 | elif self.map_type == "colossi_stalkers_zealots": 1645 | self.colossus_id = min_unit_type 1646 | self.stalker_id = min_unit_type + 1 1647 | self.zealot_id = min_unit_type + 2 1648 | elif self.map_type == "MMM": 1649 | self.marauder_id = min_unit_type 1650 | self.marine_id = min_unit_type + 1 1651 | self.medivac_id = min_unit_type + 2 1652 | elif self.map_type == "zealots": 1653 | self.zealot_id = min_unit_type 1654 | elif self.map_type == "hydralisks": 1655 | self.hydralisk_id = min_unit_type 1656 | elif self.map_type == "stalkers": 1657 | self.stalker_id = min_unit_type 1658 | elif self.map_type == "colossus": 1659 | self.colossus_id = min_unit_type 1660 | elif self.map_type == "bane": 1661 | self.baneling_id = min_unit_type 1662 | self.zergling_id = min_unit_type + 1 1663 | 1664 | def only_medivac_left(self, ally): 1665 | """Check if only Medivac units are left.""" 1666 | if self.map_type != "MMM": 1667 | return False 1668 | 1669 | if ally: 1670 | units_alive = [ 1671 | a 1672 | for a in self.agents.values() 1673 | if (a.health > 0 and a.unit_type != self.medivac_id) 1674 | ] 1675 | if len(units_alive) == 0: 1676 | return True 1677 | return False 1678 | else: 1679 | units_alive = [ 1680 | a 1681 | for a in self.enemies.values() 1682 | if (a.health > 0 and a.unit_type != self.medivac_id) 1683 | ] 1684 | if len(units_alive) == 1 and units_alive[0].unit_type == 54: 1685 | return True 1686 | return False 1687 | 1688 | def get_unit_by_id(self, a_id): 1689 | """Get unit by ID.""" 1690 | return self.agents[a_id] 1691 | 1692 | def get_stats(self): 1693 | stats = { 1694 | "battles_won": self.battles_won, 1695 | "battles_game": self.battles_game, 1696 | "battles_draw": self.timeouts, 1697 | "win_rate": self.battles_won / self.battles_game, 1698 | "timeouts": self.timeouts, 1699 | "restarts": self.force_restarts, 1700 | } 1701 | return stats 1702 | 1703 | def get_env_info(self): 1704 | env_info = super().get_env_info() 1705 | env_info["agent_features"] = self.ally_state_attr_names 1706 | env_info["enemy_features"] = self.enemy_state_attr_names 1707 | return env_info 1708 | -------------------------------------------------------------------------------- /smac/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/examples/__init__.py -------------------------------------------------------------------------------- /smac/examples/pettingzoo/README.rst: -------------------------------------------------------------------------------- 1 | SMAC on PettingZoo 2 | ================== 3 | 4 | This example shows how to run SMAC environments with PettingZoo multi-agent API. 5 | 6 | Instructions 7 | ------------ 8 | 9 | To get started, first install PettingZoo with ``pip install pettingzoo``. 10 | 11 | The SMAC environment for PettingZoo, ``StarCraft2PZEnv``, can be initialized with two different API templates. 12 | * **AEC**: PettingZoo is based in the *Agent Environment Cycle* game model, more information about "AEC" can be read in the following `paper `_. 
To create a SMAC environment as an "AEC" PettingZoo game model use: :: 13 | 14 | from smac.env.pettingzoo import StarCraft2PZEnv 15 | 16 | env = StarCraft2PZEnv.env() 17 | 18 | * **Parallel**: PettingZoo also supports parallel environments where all agents have simultaneous actions and observations. This type of environment can be created as follows: :: 19 | 20 | from smac.env.pettingzoo import StarCraft2PZEnv 21 | 22 | env = StarCraft2PZEnv.parallel_env() 23 | 24 | `pettingzoo_demo.py` has an example of a SMAC environment being used as a PettingZoo "AEC" environment. With `env.render()` it is possible to output an instance of the environment as a frame in pygame. This is useful for debugging purposes. 25 | 26 | | See https://www.pettingzoo.ml/api for more documentation. 27 | -------------------------------------------------------------------------------- /smac/examples/pettingzoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/examples/pettingzoo/__init__.py -------------------------------------------------------------------------------- /smac/examples/pettingzoo/pettingzoo_demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import numpy as np 7 | from smac.env.pettingzoo import StarCraft2PZEnv 8 | 9 | 10 | def main(): 11 | """ 12 | Runs an env object with random actions. 13 | """ 14 | env = StarCraft2PZEnv.env() 15 | episodes = 10 16 | 17 | total_reward = 0 18 | done = False 19 | completed_episodes = 0 20 | 21 | while completed_episodes < episodes: 22 | env.reset() 23 | for agent in env.agent_iter(): 24 | env.render() 25 | 26 | obs, reward, terms, truncs, _ = env.last() 27 | total_reward += reward 28 | if terms or truncs: 29 | action = None 30 | elif isinstance(obs, dict) and "action_mask" in obs: 31 | action = random.choice(np.flatnonzero(obs["action_mask"])) 32 | else: 33 | action = env.action_spaces[agent].sample() 34 | env.step(action) 35 | 36 | completed_episodes += 1 37 | 38 | env.close() 39 | 40 | print("Average total reward", total_reward / episodes) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /smac/examples/random_agents.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smac.env import StarCraft2Env 6 | import numpy as np 7 | 8 | 9 | def main(): 10 | env = StarCraft2Env(map_name="8m") 11 | env_info = env.get_env_info() 12 | 13 | n_actions = env_info["n_actions"] 14 | n_agents = env_info["n_agents"] 15 | 16 | n_episodes = 10 17 | 18 | for e in range(n_episodes): 19 | env.reset() 20 | terminated = False 21 | episode_reward = 0 22 | 23 | while not terminated: 24 | obs = env.get_obs() 25 | state = env.get_state() 26 | # env.render() # Uncomment for rendering 27 | 28 | actions = [] 29 | for agent_id in range(n_agents): 30 | avail_actions = env.get_avail_agent_actions(agent_id) 31 | avail_actions_ind = np.nonzero(avail_actions)[0] 32 | action = np.random.choice(avail_actions_ind) 33 | actions.append(action) 34 | 35 | reward, terminated, _ = env.step(actions) 36 | episode_reward += reward 37 | 38 | 
print("Total reward in episode {} = {}".format(e, episode_reward)) 39 | 40 | env.close() 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /smac/examples/rllib/README.rst: -------------------------------------------------------------------------------- 1 | SMAC on RLlib 2 | ============= 3 | 4 | This example shows how to run SMAC environments with RLlib multi-agent. 5 | 6 | Instructions 7 | ------------ 8 | 9 | To get started, first install RLlib with ``pip install -U ray[rllib]``. You will also need TensorFlow installed. 10 | 11 | In ``run_ppo.py``, each agent will be controlled by an independent PPO policy (the policies share weights). This setup serves as a single-agent baseline for this task. 12 | 13 | In ``run_qmix.py``, the agents are controlled by the multi-agent QMIX policy. This setup is an example of centralized training and decentralized execution. 14 | 15 | See https://ray.readthedocs.io/en/latest/rllib.html for more documentation. 16 | -------------------------------------------------------------------------------- /smac/examples/rllib/__init__.py: -------------------------------------------------------------------------------- 1 | from smac.examples.rllib.env import RLlibStarCraft2Env 2 | from smac.examples.rllib.model import MaskedActionsModel 3 | 4 | __all__ = ["RLlibStarCraft2Env", "MaskedActionsModel"] 5 | -------------------------------------------------------------------------------- /smac/examples/rllib/env.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | import numpy as np 8 | 9 | from gym.spaces import Discrete, Box, Dict 10 | 11 | from ray import rllib 12 | 13 | from smac.env import StarCraft2Env 14 | 15 | 16 | class RLlibStarCraft2Env(rllib.MultiAgentEnv): 17 | """Wraps a smac StarCraft env to be compatible with RLlib multi-agent.""" 18 | 19 | def __init__(self, **smac_args): 20 | """Create a new multi-agent StarCraft env compatible with RLlib. 21 | 22 | Arguments: 23 | smac_args (dict): Arguments to pass to the underlying 24 | smac.env.starcraft.StarCraft2Env instance. 25 | 26 | Examples: 27 | >>> from smac.examples.rllib import RLlibStarCraft2Env 28 | >>> env = RLlibStarCraft2Env(map_name="8m") 29 | >>> print(env.reset()) 30 | """ 31 | 32 | self._env = StarCraft2Env(**smac_args) 33 | self._ready_agents = [] 34 | self.observation_space = Dict( 35 | { 36 | "obs": Box(-1, 1, shape=(self._env.get_obs_size(),)), 37 | "action_mask": Box( 38 | 0, 1, shape=(self._env.get_total_actions(),) 39 | ), 40 | } 41 | ) 42 | self.action_space = Discrete(self._env.get_total_actions()) 43 | 44 | def reset(self): 45 | """Resets the env and returns observations from ready agents. 46 | 47 | Returns: 48 | obs (dict): New observations for each ready agent. 49 | """ 50 | 51 | obs_list, state_list = self._env.reset() 52 | return_obs = {} 53 | for i, obs in enumerate(obs_list): 54 | return_obs[i] = { 55 | "action_mask": np.array(self._env.get_avail_agent_actions(i)), 56 | "obs": obs, 57 | } 58 | 59 | self._ready_agents = list(range(len(obs_list))) 60 | return return_obs 61 | 62 | def step(self, action_dict): 63 | """Returns observations from ready agents. 64 | 65 | The returns are dicts mapping from agent_id strings to values. The 66 | number of agents in the env can vary over time. 
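The action_dict argument must contain an action for every agent that was listed as ready after the previous reset() or step() call; a missing entry raises a ValueError.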
67 | 68 | Returns 69 | ------- 70 | obs (dict): New observations for each ready agent. 71 | rewards (dict): Reward values for each ready agent. If the 72 | episode is just started, the value will be None. 73 | dones (dict): Done values for each ready agent. The special key 74 | "__all__" (required) is used to indicate env termination. 75 | infos (dict): Optional info values for each agent id. 76 | """ 77 | 78 | actions = [] 79 | for i in self._ready_agents: 80 | if i not in action_dict: 81 | raise ValueError( 82 | "You must supply an action for agent: {}".format(i) 83 | ) 84 | actions.append(action_dict[i]) 85 | 86 | if len(actions) != len(self._ready_agents): 87 | raise ValueError( 88 | "Unexpected number of actions: {}".format( 89 | action_dict, 90 | ) 91 | ) 92 | 93 | rew, done, info = self._env.step(actions) 94 | obs_list = self._env.get_obs() 95 | return_obs = {} 96 | for i, obs in enumerate(obs_list): 97 | return_obs[i] = { 98 | "action_mask": self._env.get_avail_agent_actions(i), 99 | "obs": obs, 100 | } 101 | rews = {i: rew / len(obs_list) for i in range(len(obs_list))} 102 | dones = {i: done for i in range(len(obs_list))} 103 | dones["__all__"] = done 104 | infos = {i: info for i in range(len(obs_list))} 105 | 106 | self._ready_agents = list(range(len(obs_list))) 107 | return return_obs, rews, dones, infos 108 | 109 | def close(self): 110 | """Close the environment""" 111 | self._env.close() 112 | 113 | def seed(self, seed): 114 | random.seed(seed) 115 | np.random.seed(seed) 116 | -------------------------------------------------------------------------------- /smac/examples/rllib/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from ray.rllib.models import Model 8 | from ray.rllib.models.tf.misc import normc_initializer 9 | 10 | 11 | class MaskedActionsModel(Model): 12 | """Custom RLlib model that emits -inf logits for invalid actions. 13 | 14 | This is used to handle the variable-length StarCraft action space. 
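Invalid actions receive a large negative logit by adding log(action_mask), clamped to tf.float32.min, to the network output, so the policy effectively never samples them.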
15 | """ 16 | 17 | def _build_layers_v2(self, input_dict, num_outputs, options): 18 | action_mask = input_dict["obs"]["action_mask"] 19 | if num_outputs != action_mask.shape[1].value: 20 | raise ValueError( 21 | "This model assumes num outputs is equal to max avail actions", 22 | num_outputs, 23 | action_mask, 24 | ) 25 | 26 | # Standard fully connected network 27 | last_layer = input_dict["obs"]["obs"] 28 | hiddens = options.get("fcnet_hiddens") 29 | for i, size in enumerate(hiddens): 30 | label = "fc{}".format(i) 31 | last_layer = tf.layers.dense( 32 | last_layer, 33 | size, 34 | kernel_initializer=normc_initializer(1.0), 35 | activation=tf.nn.tanh, 36 | name=label, 37 | ) 38 | action_logits = tf.layers.dense( 39 | last_layer, 40 | num_outputs, 41 | kernel_initializer=normc_initializer(0.01), 42 | activation=None, 43 | name="fc_out", 44 | ) 45 | 46 | # Mask out invalid actions (use tf.float32.min for stability) 47 | inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) 48 | masked_logits = inf_mask + action_logits 49 | 50 | return masked_logits, last_layer 51 | -------------------------------------------------------------------------------- /smac/examples/rllib/run_ppo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | """Example of running StarCraft2 with RLlib PPO. 6 | 7 | In this setup, each agent will be controlled by an independent PPO policy. 8 | However the policies share weights. 9 | 10 | Increase the level of parallelism by changing --num-workers. 11 | """ 12 | 13 | import argparse 14 | 15 | import ray 16 | from ray.tune import run_experiments, register_env 17 | from ray.rllib.models import ModelCatalog 18 | 19 | from smac.examples.rllib.env import RLlibStarCraft2Env 20 | from smac.examples.rllib.model import MaskedActionsModel 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--num-iters", type=int, default=100) 26 | parser.add_argument("--num-workers", type=int, default=2) 27 | parser.add_argument("--map-name", type=str, default="8m") 28 | args = parser.parse_args() 29 | 30 | ray.init() 31 | 32 | register_env("smac", lambda smac_args: RLlibStarCraft2Env(**smac_args)) 33 | ModelCatalog.register_custom_model("mask_model", MaskedActionsModel) 34 | 35 | run_experiments( 36 | { 37 | "ppo_sc2": { 38 | "run": "PPO", 39 | "env": "smac", 40 | "stop": { 41 | "training_iteration": args.num_iters, 42 | }, 43 | "config": { 44 | "num_workers": args.num_workers, 45 | "observation_filter": "NoFilter", # breaks the action mask 46 | "vf_share_layers": True, # no separate value model 47 | "env_config": { 48 | "map_name": args.map_name, 49 | }, 50 | "model": { 51 | "custom_model": "mask_model", 52 | }, 53 | }, 54 | }, 55 | } 56 | ) 57 | -------------------------------------------------------------------------------- /smac/examples/rllib/run_qmix.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | """Example of running StarCraft2 with RLlib QMIX. 6 | 7 | This assumes all agents are homogeneous. The agents are grouped and assigned 8 | to the multi-agent QMIX policy. Note that the default hyperparameters for 9 | RLlib QMIX are different from pymarl's QMIX. 
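All agents are placed in a single group ("group_1") via with_agent_groups so that QMIX can mix their individual Q-values into a joint action-value.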
10 | """ 11 | 12 | import argparse 13 | from gym.spaces import Tuple 14 | 15 | import ray 16 | from ray.tune import run_experiments, register_env 17 | 18 | from smac.examples.rllib.env import RLlibStarCraft2Env 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--num-iters", type=int, default=100) 24 | parser.add_argument("--num-workers", type=int, default=2) 25 | parser.add_argument("--map-name", type=str, default="8m") 26 | args = parser.parse_args() 27 | 28 | def env_creator(smac_args): 29 | env = RLlibStarCraft2Env(**smac_args) 30 | agent_list = list(range(env._env.n_agents)) 31 | grouping = { 32 | "group_1": agent_list, 33 | } 34 | obs_space = Tuple([env.observation_space for i in agent_list]) 35 | act_space = Tuple([env.action_space for i in agent_list]) 36 | return env.with_agent_groups( 37 | grouping, obs_space=obs_space, act_space=act_space 38 | ) 39 | 40 | ray.init() 41 | register_env("sc2_grouped", env_creator) 42 | 43 | run_experiments( 44 | { 45 | "qmix_sc2": { 46 | "run": "QMIX", 47 | "env": "sc2_grouped", 48 | "stop": { 49 | "training_iteration": args.num_iters, 50 | }, 51 | "config": { 52 | "num_workers": args.num_workers, 53 | "env_config": { 54 | "map_name": args.map_name, 55 | }, 56 | }, 57 | }, 58 | } 59 | ) 60 | --------------------------------------------------------------------------------
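The observation- and state-sizing logic in ``starcraft2.py`` can be sanity-checked without training anything. The following minimal sketch (not part of the repository) builds the same ``8m`` environment used in the examples above and prints the derived dimensions together with the feature-name lists exposed by ``get_env_info()``; it assumes StarCraft II and the SMAC maps are installed: ::

    from smac.env import StarCraft2Env

    env = StarCraft2Env(map_name="8m")
    env_info = env.get_env_info()

    # With the default settings, 8m should report an observation size of 80
    # and a state size of 168.
    print("obs size:", env.get_obs_size())
    print("state size:", env.get_state_size())
    print("agent features:", env_info["agent_features"])
    print("enemy features:", env_info["enemy_features"])

    env.close()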