├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs
│   ├── smac-official.png
│   └── smac.md
├── pyproject.toml
├── setup.py
└── smac
    ├── __init__.py
    ├── bin
    │   ├── __init__.py
    │   └── map_list.py
    ├── env
    │   ├── __init__.py
    │   ├── multiagentenv.py
    │   ├── pettingzoo
    │   │   ├── StarCraft2PZEnv.py
    │   │   ├── __init__.py
    │   │   └── test
    │   │       ├── __init__.py
    │   │       ├── all_test.py
    │   │       └── smac_pettingzoo_test.py
    │   └── starcraft2
    │       ├── __init__.py
    │       ├── maps
    │       │   ├── SMAC_Maps
    │       │   │   ├── 10m_vs_11m.SC2Map
    │       │   │   ├── 1c3s5z.SC2Map
    │       │   │   ├── 25m.SC2Map
    │       │   │   ├── 27m_vs_30m.SC2Map
    │       │   │   ├── 2c_vs_64zg.SC2Map
    │       │   │   ├── 2m_vs_1z.SC2Map
    │       │   │   ├── 2s3z.SC2Map
    │       │   │   ├── 2s_vs_1sc.SC2Map
    │       │   │   ├── 3m.SC2Map
    │       │   │   ├── 3s5z.SC2Map
    │       │   │   ├── 3s5z_vs_3s6z.SC2Map
    │       │   │   ├── 3s_vs_3z.SC2Map
    │       │   │   ├── 3s_vs_4z.SC2Map
    │       │   │   ├── 3s_vs_5z.SC2Map
    │       │   │   ├── 5m_vs_6m.SC2Map
    │       │   │   ├── 6h_vs_8z.SC2Map
    │       │   │   ├── 8m.SC2Map
    │       │   │   ├── 8m_vs_9m.SC2Map
    │       │   │   ├── MMM.SC2Map
    │       │   │   ├── MMM2.SC2Map
    │       │   │   ├── bane_vs_bane.SC2Map
    │       │   │   ├── corridor.SC2Map
    │       │   │   └── so_many_baneling.SC2Map
    │       │   ├── __init__.py
    │       │   └── smac_maps.py
    │       ├── render.py
    │       └── starcraft2.py
    └── examples
        ├── __init__.py
        ├── pettingzoo
        │   ├── README.rst
        │   ├── __init__.py
        │   └── pettingzoo_demo.py
        ├── random_agents.py
        └── rllib
            ├── README.rst
            ├── __init__.py
            ├── env.py
            ├── model.py
            ├── run_ppo.py
            └── run_qmix.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | # env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: stable
4 | hooks:
5 | - id: black
6 | language_version: python3.8
7 | - repo: https://gitlab.com/pycqa/flake8
8 | rev: '3.8.4'
9 | hooks:
10 | - id: flake8
11 | additional_dependencies: [flake8-bugbear]
12 | args: ["--show-source"]
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 whirl
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | > __Note__
6 | > SMACv2 is out! Check it out [here](https://github.com/oxwhirl/smacv2).
7 |
8 | > __Warning__
9 | > **Please pay attention to the version of SC2 used for your experiments.** Performance is **not** always comparable between versions. The results in the [SMAC paper](https://arxiv.org/abs/1902.04043) use `SC2.4.6.2.69232` not `SC2.4.10`.
10 |
11 | # SMAC - StarCraft Multi-Agent Challenge
12 |
13 | [SMAC](https://github.com/oxwhirl/smac) is [WhiRL](http://whirl.cs.ox.ac.uk)'s environment for research in the field of cooperative multi-agent reinforcement learning (MARL) based on [Blizzard](http://blizzard.com)'s [StarCraft II](https://en.wikipedia.org/wiki/StarCraft_II:_Wings_of_Liberty) RTS game. SMAC makes use of Blizzard's [StarCraft II Machine Learning API](https://github.com/Blizzard/s2client-proto) and [DeepMind](https://deepmind.com)'s [PySC2](https://github.com/deepmind/pysc2) to provide a convenient interface for autonomous agents to interact with StarCraft II, getting observations and performing actions. Unlike [PySC2](https://github.com/deepmind/pysc2), SMAC concentrates on *decentralised micromanagement* scenarios, where each unit of the game is controlled by an individual RL agent.
14 |
15 | Please refer to the accompanying [paper](https://arxiv.org/abs/1902.04043) and [blogpost](https://blog.ucldark.com/2019/02/12/smac.html) for the outline of our motivation for using SMAC as a testbed for MARL research and the initial experimental results.
16 |
17 | ## About
18 |
19 | Together with SMAC we also release [PyMARL](https://github.com/oxwhirl/pymarl) - our [PyTorch](https://github.com/pytorch/pytorch) framework for MARL research, which includes implementations of several state-of-the-art algorithms, such as [QMIX](https://arxiv.org/abs/1803.11485) and [COMA](https://arxiv.org/abs/1705.08926).
20 |
21 | Data from the runs used in the paper is included [here](https://github.com/oxwhirl/smac/releases/download/v1/smac_run_data.json). **These runs are outdated due to subsequent changes in StarCraft II. If you ran your experiments using the current version of SMAC, you should not compare your results with the ones provided here.**
22 |
23 | # Quick Start
24 |
25 | ## Installing SMAC
26 |
27 | You can install SMAC by using the following command:
28 |
29 | ```shell
30 | pip install git+https://github.com/oxwhirl/smac.git
31 | ```
32 |
33 | Alternatively, you can clone the SMAC repository and then install `smac` with its dependencies:
34 |
35 | ```shell
36 | git clone https://github.com/oxwhirl/smac.git
37 | pip install -e smac/
38 | ```
39 |
40 | *NOTE*: If you want to extend SMAC, please install the package as follows:
41 |
42 | ```shell
43 | git clone https://github.com/oxwhirl/smac.git
44 | cd smac
45 | pip install -e ".[dev]"
46 | pre-commit install
47 | ```
48 |
49 | You may also need to upgrade pip: `pip install --upgrade pip` for the install to work.
50 |
51 | ## Installing StarCraft II
52 |
53 | SMAC is based on the full game of StarCraft II (versions >= 3.16.1). To install the game, follow the instructions below.
54 |
55 | ### Linux
56 |
57 | Please use Blizzard's [repository](https://github.com/Blizzard/s2client-proto#downloads) to download the Linux version of StarCraft II. By default, the game is expected to be in the `~/StarCraftII/` directory. This can be changed by setting the `SC2PATH` environment variable.
58 |
59 | ### MacOS/Windows
60 |
61 | Please install StarCraft II from [Battle.net](https://battle.net). The free [Starter Edition](http://battle.net/sc2/en/legacy-of-the-void/) also works. PySC2 will find the latest binary if you use the default install location. Otherwise, similar to the Linux version, you will need to set the `SC2PATH` environment variable to the location of the game.
62 |
63 | ## SMAC maps
64 |
65 | SMAC is composed of many combat scenarios with pre-configured maps. Before SMAC can be used, these maps need to be downloaded into the `Maps` directory of StarCraft II.
66 |
67 | Download the [SMAC Maps](https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip) and extract them to your `$SC2PATH/Maps` directory. If you installed SMAC via git, simply copy the `SMAC_Maps` directory from `smac/env/starcraft2/maps/` into the `$SC2PATH/Maps` directory.
68 |
69 | ### List the maps
70 |
71 | To see the list of SMAC maps, together with the number of ally and enemy units and episode limit, run:
72 |
73 | ```shell
74 | python -m smac.bin.map_list
75 | ```
76 |
77 | ### Creating new maps
78 |
79 | Users can extend SMAC by adding new maps/scenarios. To this end, one needs to:
80 |
81 | - Design a new map/scenario using StarCraft II Editor:
82 | - Please take a close look at the existing maps to understand the basics that we use (e.g. Triggers, Units, etc),
83 | - We make use of special RL units which never automatically start attacking the enemy. [Here](https://docs.google.com/document/d/1BfAM_AtZWBRhUiOBcMkb_uK4DAZW3CpvO79-vnEOKxA/edit?usp=sharing) is the step-by-step guide on how to create new RL units based on existing SC2 units,
84 | - Add the map information in [smac_maps.py](https://github.com/oxwhirl/smac/blob/master/smac/env/starcraft2/maps/smac_maps.py) (a sketch of such a registry entry is shown after this list),
85 | - The newly designed RL units have new ids which need to be handled in [starcraft2.py](https://github.com/oxwhirl/smac/blob/master/smac/env/starcraft2/starcraft2.py). Specifically, for heterogeneous maps containing more than one unit type, one needs to manually set the unit ids in the `_init_ally_unit_types()` function.
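
Purely for illustration, here is a rough sketch of what such a registry entry looks like. The map name and all values below are hypothetical; mirror an existing entry of `map_param_registry` (e.g. `5m_vs_6m`) when adding a real map:

```python
# Hypothetical entry for smac/env/starcraft2/maps/smac_maps.py -- the map
# name, unit counts and limit are made up for illustration only.
map_param_registry["4m_vs_5m"] = {
    "n_agents": 4,          # number of allied (learning) units
    "n_enemies": 5,         # number of enemy units controlled by the game AI
    "limit": 70,            # episode step limit
    "a_race": "T",          # race of the allied army
    "b_race": "T",          # race of the enemy army
    "unit_type_bits": 0,    # 0 for homogeneous maps, >0 for heterogeneous ones
    "map_type": "marines",  # controls how unit ids are resolved in starcraft2.py
}
```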
86 |
87 | ## Testing SMAC
88 |
89 | Please run the following command to make sure that `smac` and its maps are properly installed.
90 |
91 | ```bash
92 | python -m smac.examples.random_agents
93 | ```
94 |
95 | ## Saving and Watching StarCraft II Replays
96 |
97 | ### Saving a replay
98 |
99 | If you’re using our [PyMARL](https://github.com/oxwhirl/pymarl) framework for multi-agent RL, here’s what needs to be done:
100 | 1. **Saving models**: We run experiments on *Linux* servers with `save_model = True` (the `save_model_interval` parameter is also relevant) so that training checkpoints (neural network parameters) are saved (click [here](https://github.com/oxwhirl/pymarl#saving-and-loading-learnt-models) for more details).
101 | 2. **Loading models**: Learnt models can be loaded using the `checkpoint_path` parameter. If you run PyMARL on *MacOS* (or *Windows*) while also setting `save_replay=True`, this will save a .SC2Replay file for `test_nepisode` episodes run in test mode (no exploration) in the Replay directory of StarCraft II (click [here](https://github.com/oxwhirl/pymarl#watching-starcraft-ii-replays) for more details).
102 |
103 | If you want to save replays without using PyMARL, simply call the `save_replay()` function of SMAC's StarCraft2Env in your training/testing code. This will save a replay of all episodes played since the launch of the StarCraft II client.
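
A minimal sketch of what this looks like (the `8m` map is chosen arbitrarily):

```python
from smac.env import StarCraft2Env

env = StarCraft2Env(map_name="8m")
env.reset()

# ... run your evaluation episodes here ...

# Writes a .SC2Replay file covering every episode played since the
# StarCraft II client was launched.
env.save_replay()
env.close()
```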
104 |
105 | The easiest way to save and later watch a replay on Linux is to use [Wine](https://www.winehq.org/).
106 |
107 | ### Watching a replay
108 |
109 | You can watch the saved replay directly within the StarCraft II client on MacOS/Windows by *clicking on the corresponding Replay file*.
110 |
111 | You can also watch saved replays by running:
112 |
113 | ```shell
114 | python -m pysc2.bin.play --norender --replay <path-to-replay>
115 | ```
116 |
117 | This works for any replay as long as the map can be found by the game.
118 |
119 | For more information, please refer to [PySC2](https://github.com/deepmind/pysc2) documentation.
120 |
121 | # Documentation
122 |
123 | For the detailed description of the environment, read the [SMAC documentation](docs/smac.md).
124 |
125 | The initial results of our experiments using SMAC can be found in the [accompanying paper](https://arxiv.org/abs/1902.04043).
126 |
127 | # Citing SMAC
128 |
129 | If you use SMAC in your research, please cite the [SMAC paper](https://arxiv.org/abs/1902.04043).
130 |
131 | *M. Samvelyan, T. Rashid, C. Schroeder de Witt, G. Farquhar, N. Nardelli, T.G.J. Rudner, C.-M. Hung, P.H.S. Torr, J. Foerster, S. Whiteson. The StarCraft Multi-Agent Challenge, CoRR abs/1902.04043, 2019.*
132 |
133 | In BibTeX format:
134 |
135 | ```tex
136 | @article{samvelyan19smac,
137 | title = {{The} {StarCraft} {Multi}-{Agent} {Challenge}},
138 |   author = {Mikayel Samvelyan and Tabish Rashid and Christian Schroeder de Witt and Gregory Farquhar and Nantas Nardelli and Tim G. J. Rudner and Chia-Man Hung and Philip H. S. Torr and Jakob Foerster and Shimon Whiteson},
139 | journal = {CoRR},
140 | volume = {abs/1902.04043},
141 | year = {2019},
142 | }
143 | ```
144 |
145 | # Code Examples
146 |
147 | Below is a small code example which illustrates how SMAC can be used. Here, individual agents execute random policies after receiving the observations and global state from the environment.
148 |
149 | If you want to try the state-of-the-art algorithms (such as [QMIX](https://arxiv.org/abs/1803.11485) and [COMA](https://arxiv.org/abs/1705.08926)) on SMAC, make use of [PyMARL](https://github.com/oxwhirl/pymarl) - our framework for MARL research.
150 |
151 | ```python
152 | from smac.env import StarCraft2Env
153 | import numpy as np
154 |
155 |
156 | def main():
157 | env = StarCraft2Env(map_name="8m")
158 | env_info = env.get_env_info()
159 |
160 | n_actions = env_info["n_actions"]
161 | n_agents = env_info["n_agents"]
162 |
163 | n_episodes = 10
164 |
165 | for e in range(n_episodes):
166 | env.reset()
167 | terminated = False
168 | episode_reward = 0
169 |
170 | while not terminated:
171 | obs = env.get_obs()
172 | state = env.get_state()
173 | # env.render() # Uncomment for rendering
174 |
175 | actions = []
176 | for agent_id in range(n_agents):
177 | avail_actions = env.get_avail_agent_actions(agent_id)
178 | avail_actions_ind = np.nonzero(avail_actions)[0]
179 | action = np.random.choice(avail_actions_ind)
180 | actions.append(action)
181 |
182 | reward, terminated, _ = env.step(actions)
183 | episode_reward += reward
184 |
185 | print("Total reward in episode {} = {}".format(e, episode_reward))
186 |
187 | env.close()
188 |
if __name__ == "__main__":
    main()
189 | ```
190 |
191 | ## RLlib Examples
192 |
193 | You can also run SMAC environments in [RLlib](https://rllib.io), which includes scalable algorithms such as [PPO](https://ray.readthedocs.io/en/latest/rllib-algorithms.html#proximal-policy-optimization-ppo) and [IMPALA](https://ray.readthedocs.io/en/latest/rllib-algorithms.html#importance-weighted-actor-learner-architecture-impala). Check out the example code [here](https://github.com/oxwhirl/smac/tree/master/smac/examples/rllib).
194 |
195 | ## PettingZoo Example
196 |
197 | Thanks to [Rodrigo de Lazcano](https://github.com/rodrigodelazcano), SMAC now supports [PettingZoo API](https://github.com/PettingZoo-Team/PettingZoo) and PyGame environment rendering. Check out the example code [here](https://github.com/oxwhirl/smac/tree/master/smac/examples/pettingzoo).
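
A minimal random-policy sketch over the turn-based (AEC) interface is shown below. It assumes a recent PettingZoo release in which `env.last()` returns five values; each observation is a dict with `"observation"` and `"action_mask"` entries, as produced by `StarCraft2PZEnv`:

```python
import numpy as np
from smac.env.pettingzoo import StarCraft2PZEnv

env = StarCraft2PZEnv.env(map_name="8m")
env.reset()

for agent in env.agent_iter():
    obs, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None  # finished agents must step with None
    else:
        # sample uniformly among the actions allowed by the mask
        action = int(np.random.choice(np.nonzero(obs["action_mask"])[0]))
    env.step(action)

env.close()
```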
198 |
199 |
--------------------------------------------------------------------------------
/docs/smac-official.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/docs/smac-official.png
--------------------------------------------------------------------------------
/docs/smac.md:
--------------------------------------------------------------------------------
1 | ## Table of Contents
2 |
3 | - [StarCraft II](#starcraft-ii)
4 | - [Micromanagement](#micromanagement)
5 | - [SMAC](#smac)
6 | - [Scenarios](#scenarios)
7 | - [State and Observations](#state-and-observations)
8 | - [Action Space](#action-space)
9 | - [Rewards](#rewards)
10 | - [Environment Settings](#environment-settings)
11 |
12 | ## StarCraft II
13 |
14 | SMAC is based on the popular real-time strategy (RTS) game [StarCraft II](http://us.battle.net/sc2/en/game/guide/whats-sc2) written by [Blizzard](http://blizzard.com/).
15 | In a regular full game of StarCraft II, one or more humans compete against each other or against a built-in game AI to gather resources, construct buildings, and build armies of units to defeat their opponents.
16 |
17 | Akin to most RTSs, StarCraft has two main gameplay components: macromanagement and micromanagement.
18 | - _Macromanagement_ (macro) refers to high-level strategic considerations, such as economy and resource management.
19 | - _Micromanagement_ (micro) refers to fine-grained control of individual units.
20 |
21 | ### Micromanagement
22 |
23 | StarCraft has been used as a research platform for AI, and more recently, RL. Typically, the game is framed as a competitive problem: an agent takes the role of a human player, making macromanagement decisions and performing micromanagement as a puppeteer that issues orders to individual units from a centralised controller.
24 |
25 | In order to build a rich multi-agent testbed, we instead focus solely on micromanagement.
26 | Micro is a vital aspect of StarCraft gameplay with a high skill ceiling, and is practiced in isolation by amateur and professional players.
27 | For SMAC, we leverage the natural multi-agent structure of micromanagement by proposing a modified version of the problem designed specifically for decentralised control.
28 | In particular, we require that each unit be controlled by an independent agent that conditions only on local observations restricted to a limited field of view centred on that unit.
29 | Groups of these agents must be trained to solve challenging combat scenarios, battling an opposing army under the centralised control of the game's built-in scripted AI.
30 |
31 | Proper micro of units during battles will maximise the damage dealt to enemy units while minimising damage received, and requires a range of skills.
32 | For example, one important technique is _focus fire_, i.e., ordering units to jointly attack and kill enemy units one after another. When focusing fire, it is important to avoid _overkill_: inflicting more damage to units than is necessary to kill them.
33 |
34 | Other common micromanagement techniques include: assembling units into formations based on their armour types, making enemy units give chase while maintaining enough distance so that little or no damage is incurred (_kiting_), coordinating the positioning of units to attack from different directions or taking advantage of the terrain to defeat the enemy.
35 |
36 | Learning these rich cooperative behaviours under partial observability is a challenging task, which can be used to evaluate the effectiveness of multi-agent reinforcement learning (MARL) algorithms.
37 |
38 | ## SMAC
39 |
40 | SMAC uses the [StarCraft II Learning Environment](https://github.com/deepmind/pysc2) to introduce a cooperative MARL environment.
41 |
42 | ### Scenarios
43 |
44 | SMAC consists of a set of StarCraft II micro scenarios which aim to evaluate how well independent agents are able to learn coordination to solve complex tasks.
45 | These scenarios are carefully designed to necessitate the learning of one or more micromanagement techniques to defeat the enemy.
46 | Each scenario is a confrontation between two armies of units.
47 | The initial position, number, and type of units in each army varies from scenario to scenario, as does the presence or absence of elevated or impassable terrain.
48 |
49 | The first army is controlled by the learned allied agents.
50 | The second army consists of enemy units controlled by the built-in game AI, which uses carefully handcrafted non-learned heuristics.
51 | At the beginning of each episode, the game AI instructs its units to attack the allied agents using its scripted strategies.
52 | An episode ends when all units of either army have died or when a pre-specified time limit is reached (in which case the game is counted as a defeat for the allied agents).
53 | The goal for each scenario is to maximise the win rate of the learned policies, i.e., the expected ratio of games won to games played.
54 | To speed up learning, the enemy AI units are ordered to attack the agents' spawning point at the beginning of each episode.
55 |
56 | Perhaps the simplest scenarios are _symmetric_ battle scenarios.
57 | The most straightforward of these scenarios are _homogeneous_, i.e., each army is composed of only a single unit type (e.g., Marines).
58 | A winning strategy in this setting would be to focus fire, ideally without overkill.
59 | _Heterogeneous_ symmetric scenarios, in which there is more than a single unit type on each side (e.g., Stalkers and Zealots), are more difficult to solve.
60 | Such challenges are particularly interesting when some of the units are extremely effective against others (this is known as _countering_), for example, by dealing bonus damage to a particular armour type.
61 | In such a setting, allied agents must deduce this property of the game and design an intelligent strategy to protect teammates vulnerable to certain enemy attacks.
62 |
63 | SMAC also includes more challenging scenarios, for example, in which the enemy army outnumbers the allied army by one or more units. In such _asymmetric_ scenarios it is essential to consider the health of enemy units in order to effectively target the desired opponent.
64 |
65 | Lastly, SMAC offers a set of interesting _micro-trick_ challenges that require a higher-level of cooperation and a specific micromanagement trick to defeat the enemy.
66 | An example of a challenge scenario is _2m_vs_1z_ (aka Marine Double Team), where two Terran Marines need to defeat an enemy Zealot. In this setting, the Marines must design a strategy which does not allow the Zealot to reach them, otherwise they will die almost immediately.
67 | Another example is _so_many_banelings_, where 7 allied Zealots face 32 enemy Baneling units. Banelings attack by running into a target and exploding on contact, dealing damage to a small area around the target. Hence, if a large number of Banelings attack a handful of Zealots located close to each other, the Zealots will be defeated instantly. The optimal strategy, therefore, is to cooperatively spread out around the map, far from each other, so that the Banelings' damage is distributed as thinly as possible.
68 | The _corridor_ scenario, in which 6 friendly Zealots face 24 enemy Zerglings, requires agents to make effective use of the terrain features. Specifically, agents should collectively wall off the choke point (the narrow region of the map) to block enemy attacks from different directions. Some of the micro-trick challenges are inspired by [StarCraft Master](http://us.battle.net/sc2/en/blog/4544189/new-blizzard-custom-game-starcraft-master-3-1-2012) challenge missions released by Blizzard.
69 |
70 | The complete list of challenges is presented below. The difficulty of the game AI is set to _very difficult_ (7). Our experiments, however, suggest that this setting does not significantly impact the unit micromanagement of the built-in heuristics.
71 |
72 | | Name | Ally Units | Enemy Units | Type |
73 | | :---: | :---: | :---: | :---:|
74 | | 3m | 3 Marines | 3 Marines | homogeneous & symmetric |
75 | | 8m | 8 Marines | 8 Marines | homogeneous & symmetric |
76 | | 25m | 25 Marines | 25 Marines | homogeneous & symmetric |
77 | | 2s3z | 2 Stalkers & 3 Zealots | 2 Stalkers & 3 Zealots | heterogeneous & symmetric |
78 | | 3s5z | 3 Stalkers & 5 Zealots | 3 Stalkers & 5 Zealots | heterogeneous & symmetric |
79 | | MMM | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 2 Marauders & 7 Marines | heterogeneous & symmetric |
80 | | 5m_vs_6m | 5 Marines | 6 Marines | homogeneous & asymmetric |
81 | | 8m_vs_9m | 8 Marines | 9 Marines | homogeneous & asymmetric |
82 | | 10m_vs_11m | 10 Marines | 11 Marines | homogeneous & asymmetric |
83 | | 27m_vs_30m | 27 Marines | 30 Marines | homogeneous & asymmetric |
84 | | 3s5z_vs_3s6z | 3 Stalkers & 5 Zealots | 3 Stalkers & 6 Zealots | heterogeneous & asymmetric |
85 | | MMM2 | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 3 Marauders & 8 Marines | heterogeneous & asymmetric |
86 | | 2m_vs_1z | 2 Marines | 1 Zealot | micro-trick: alternating fire |
87 | | 2s_vs_1sc| 2 Stalkers | 1 Spine Crawler | micro-trick: alternating fire |
88 | | 3s_vs_3z | 3 Stalkers | 3 Zealots | micro-trick: kiting |
89 | | 3s_vs_4z | 3 Stalkers | 4 Zealots | micro-trick: kiting |
90 | | 3s_vs_5z | 3 Stalkers | 5 Zealots | micro-trick: kiting |
91 | | 6h_vs_8z | 6 Hydralisks | 8 Zealots | micro-trick: focus fire |
92 | | corridor | 6 Zealots | 24 Zerglings | micro-trick: wall off |
93 | | bane_vs_bane | 20 Zerglings & 4 Banelings | 20 Zerglings & 4 Banelings | micro-trick: positioning |
94 | | so_many_banelings| 7 Zealots | 32 Banelings | micro-trick: positioning |
95 | | 2c_vs_64zg| 2 Colossi | 64 Zerglings | micro-trick: positioning |
96 | | 1c3s5z | 1 Colossus & 3 Stalkers & 5 Zealots | 1 Colossus & 3 Stalkers & 5 Zealots | heterogeneous & symmetric |
97 |
98 | ### State and Observations
99 |
100 | At each timestep, agents receive local observations drawn within their field of view. This encompasses information about the map within a circular area around each unit and with a radius equal to the _sight range_. The sight range makes the environment partially observable from the standpoint of each agent. Agents can only observe other agents if they are both alive and located within the sight range. Hence, there is no way for agents to determine whether their teammates are far away or dead.
101 |
102 | The feature vector observed by each agent contains the following attributes for both allied and enemy units within the sight range: _distance_, _relative x_, _relative y_, _health_, _shield_, and _unit\_type_ [1](#myfootnote1). Shields serve as an additional source of protection that needs to be removed before any damage can be done to the health of units.
103 | All Protoss units have shields, which can regenerate if no new damage is dealt
104 | (units of the other two races do not have this attribute).
105 | In addition, agents have access to the last actions of allied units that are in the field of view. Lastly, agents can observe the terrain features surrounding them; particularly, the values of eight points at a fixed radius indicating height and walkability.
106 |
107 | The global state, which is only available to agents during centralised training, contains information about all units on the map. Specifically, the state vector includes the coordinates of all agents relative to the centre of the map, together with unit features present in the observations. Additionally, the state stores the _energy_ of Medivacs and _cooldown_ of the rest of allied units, which represents the minimum delay between attacks. Finally, the last actions of all agents are attached to the central state.
108 |
109 | All features, both in the state as well as in the observations of individual agents, are normalised by their maximum values. The sight range is set to 9 for all agents.
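
The dimensionality of these vectors depends on the scenario and can be queried directly from the environment. A minimal sketch (the `8m` map is chosen arbitrarily; the methods used are part of SMAC's `MultiAgentEnv` interface):

```python
from smac.env import StarCraft2Env

env = StarCraft2Env(map_name="8m")
env.reset()

print("per-agent observation size:", env.get_obs_size())
print("global state size:", env.get_state_size())

# the (normalised) feature vector described above, for agent 0
obs_0 = env.get_obs_agent(0)
env.close()
```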
110 |
111 | ### Action Space
112 |
113 | The discrete set of actions which agents are allowed to take consists of _move[direction]_ (four directions: north, south, east, or west), _attack[enemy_id]_, _stop_ and _no-op_. Dead agents can only take the _no-op_ action, while live agents cannot.
114 | As healer units, Medivacs must use _heal[agent\_id]_ actions instead of _attack[enemy\_id]_. The maximum number of actions an agent can take ranges between 7 and 70, depending on the scenario.
115 |
116 | To ensure decentralisation of the task, agents are restricted to using the _attack[enemy\_id]_ action only on enemies within their _shooting range_.
117 | This additionally constrains the unit's ability to use the built-in _attack-move_ macro-action on enemies that are far away. We set the shooting range to 6. Having a larger sight range than shooting range forces agents to make use of move commands before starting to fire.
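
At runtime, the environment exposes this availability information as a binary mask per agent; only actions whose mask entry is 1 may be passed to `env.step()`. A minimal sketch using the interface methods above:

```python
from smac.env import StarCraft2Env

env = StarCraft2Env(map_name="8m")
env.reset()

n_actions = env.get_total_actions()     # no-op, stop, 4 moves, plus one attack per enemy
avail = env.get_avail_agent_actions(0)  # availability mask of length n_actions for agent 0
legal = [a for a, ok in enumerate(avail) if ok]
env.close()
```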
118 |
119 | ### Rewards
120 |
121 | The overall goal is to have the highest win rate for each battle scenario.
122 | We provide a corresponding option for _sparse rewards_, which will cause the environment to return only a reward of +1 for winning and -1 for losing an episode.
123 | By default, however, we provide a shaped reward signal calculated from the hit-point damage dealt and received by agents, a positive (negative) reward when enemy (allied) units are killed, and a positive (negative) bonus for winning (losing) the battle.
124 | The exact values and scales of this shaped reward can be configured using a range of flags, but we strongly discourage disingenuous engineering of the reward function (e.g. tuning different reward functions for different scenarios).
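
For instance, switching to the sparse reward is a single constructor flag; a minimal sketch, assuming the `reward_sparse` argument exposed by `StarCraft2Env` in `starcraft2.py` (the shaped-reward values are controlled by sibling `reward_*` arguments in the same constructor):

```python
from smac.env import StarCraft2Env

# Sparse-reward variant: +1 for winning and -1 for losing an episode.
# Leaving reward_sparse at its default (False) keeps the shaped reward.
env = StarCraft2Env(map_name="8m", reward_sparse=True)
```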
125 |
126 | ### Environment Settings
127 |
128 | SMAC makes use of the [StarCraft II Learning Environment](https://arxiv.org/abs/1708.04782) (SC2LE) to communicate with the StarCraft II engine. SC2LE provides full control of the game by allowing users to send commands to and receive observations from the game. However, SMAC is conceptually different from the RL environment of SC2LE. The goal of SC2LE is to learn to play the full game of StarCraft II. This is a competitive task where a centralised RL agent receives RGB pixels as input and performs both macro and micro with player-level control, similar to human players. SMAC, on the other hand, represents a set of cooperative multi-agent micro challenges where each learning agent controls a single military unit.
129 |
130 | SMAC uses the _raw API_ of SC2LE. Raw API observations do not have any graphical component and include information about the units on the map such as health, location coordinates, etc. The raw API also allows sending action commands to individual units using their unit IDs. This setting differs from how humans play the actual game, but is convenient for designing decentralised multi-agent learning tasks.
131 |
132 | Since our micro-scenarios are shorter than actual StarCraft II games, restarting the game after each episode presents a computational bottleneck. To overcome this issue, we make use of the API's debug commands. Specifically, when all units of either army have been killed, we kill all remaining units by sending a debug action. Having no units left launches a trigger programmed with the StarCraft II Editor that re-spawns all units in their original location with full health, thereby restarting the scenario quickly and efficiently.
133 |
134 | Furthermore, to encourage agents to explore interesting micro-strategies themselves, we limit the influence of the StarCraft AI on our agents. Specifically we disable the automatic unit attack against enemies that attack agents or that are located nearby.
135 | To do so, we make use of new units created with the StarCraft II Editor that are exact copies of existing units with two attributes modified: _Combat: Default Acquire Level_ is set to _Passive_ (default _Offensive_) and _Behaviour: Response_ is set to _No Response_ (default _Acquire_). These fields are only modified for allied units; enemy units are unchanged.
136 |
137 | The sight and shooting range values might differ from the built-in _sight_ or _range_ attribute of some StarCraft II units. Our goal is not to master the original full StarCraft game, but rather to benchmark MARL methods for decentralised control.
138 |
139 | <a name="myfootnote1">1</a>: _health_, _shield_ and _unit\_type_ of the unit the agent controls are also included in observations
140 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 79
3 | include = '\.pyi?$'
4 | exclude = '''
5 | /(
6 | \.git
7 | | \.hg
8 | | \.mypy_cache
9 | | \.tox
10 | | \.venv
11 | | _build
12 | | buck-out
13 | | build
14 | | dist
15 | )/
16 | '''
17 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from setuptools import setup
6 |
7 | description = """SMAC - StarCraft Multi-Agent Challenge
8 |
9 | SMAC offers a diverse set of decentralised micromanagement challenges based on
10 | the StarCraft II game. In these challenges, each of the units is controlled by an
11 | independent, learning agent that has to act based only on local observations,
12 | while the opponent's units are controlled by the built-in StarCraft II AI.
13 |
14 | The accompanying paper which outlines the motivation for using SMAC as well as
15 | results using the state-of-the-art deep multi-agent reinforcement learning
16 | algorithms can be found at https://arxiv.org/abs/1902.04043
17 |
18 | Read the README at https://github.com/oxwhirl/smac for more information.
19 | """
20 |
21 | extras_deps = {
22 | "dev": [
23 | "pre-commit>=2.0.1",
24 | "black>=19.10b0",
25 | "flake8>=3.7",
26 | "flake8-bugbear>=20.1",
27 | ],
28 | }
29 |
30 |
31 | setup(
32 | name="SMAC",
33 | version="1.0.0",
34 | description="SMAC - StarCraft Multi-Agent Challenge.",
35 | long_description=description,
36 | author="WhiRL",
37 | author_email="mikayel@samvelyan.com",
38 | license="MIT License",
39 | keywords="StarCraft, Multi-Agent Reinforcement Learning",
40 | url="https://github.com/oxwhirl/smac",
41 | packages=[
42 | "smac",
43 | "smac.env",
44 | "smac.env.starcraft2",
45 | "smac.env.starcraft2.maps",
46 | "smac.env.pettingzoo",
47 | "smac.bin",
48 | "smac.examples",
49 | "smac.examples.rllib",
50 | "smac.examples.pettingzoo",
51 | ],
52 | extras_require=extras_deps,
53 | install_requires=[
54 | "protobuf<3.21",
55 | "pysc2>=3.0.0",
56 | "s2clientprotocol>=4.10.1.75800.0",
57 | "absl-py>=0.1.0",
58 | "numpy>=1.10",
59 | "pygame>=2.0.0",
60 | ],
61 | )
62 |
--------------------------------------------------------------------------------
/smac/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/__init__.py
--------------------------------------------------------------------------------
/smac/bin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/bin/__init__.py
--------------------------------------------------------------------------------
/smac/bin/map_list.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from smac.env.starcraft2.maps import smac_maps
6 |
7 | from pysc2 import maps as pysc2_maps
8 |
9 |
10 | def main():
11 | smac_map_registry = smac_maps.get_smac_map_registry()
12 | all_maps = pysc2_maps.get_maps()
13 | print("{:<15} {:7} {:7} {:7}".format("Name", "Agents", "Enemies", "Limit"))
14 | for map_name, map_params in smac_map_registry.items():
15 | map_class = all_maps[map_name]
16 | if map_class.path:
17 | print(
18 | "{:<15} {:<7} {:<7} {:<7}".format(
19 | map_name,
20 | map_params["n_agents"],
21 | map_params["n_enemies"],
22 | map_params["limit"],
23 | )
24 | )
25 |
26 |
27 | if __name__ == "__main__":
28 | main()
29 |
--------------------------------------------------------------------------------
/smac/env/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from smac.env.multiagentenv import MultiAgentEnv
6 | from smac.env.starcraft2.starcraft2 import StarCraft2Env
7 |
8 | __all__ = ["MultiAgentEnv", "StarCraft2Env"]
9 |
--------------------------------------------------------------------------------
/smac/env/multiagentenv.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | class MultiAgentEnv(object):
7 | def step(self, actions):
8 | """Returns reward, terminated, info."""
9 | raise NotImplementedError
10 |
11 | def get_obs(self):
12 | """Returns all agent observations in a list."""
13 | raise NotImplementedError
14 |
15 | def get_obs_agent(self, agent_id):
16 | """Returns observation for agent_id."""
17 | raise NotImplementedError
18 |
19 | def get_obs_size(self):
20 | """Returns the size of the observation."""
21 | raise NotImplementedError
22 |
23 | def get_state(self):
24 | """Returns the global state."""
25 | raise NotImplementedError
26 |
27 | def get_state_size(self):
28 | """Returns the size of the global state."""
29 | raise NotImplementedError
30 |
31 | def get_avail_actions(self):
32 | """Returns the available actions of all agents in a list."""
33 | raise NotImplementedError
34 |
35 | def get_avail_agent_actions(self, agent_id):
36 | """Returns the available actions for agent_id."""
37 | raise NotImplementedError
38 |
39 | def get_total_actions(self):
40 | """Returns the total number of actions an agent could ever take."""
41 | raise NotImplementedError
42 |
43 | def reset(self):
44 | """Returns initial observations and states."""
45 | raise NotImplementedError
46 |
47 | def render(self):
48 | raise NotImplementedError
49 |
50 | def close(self):
51 | raise NotImplementedError
52 |
53 | def seed(self):
54 | raise NotImplementedError
55 |
56 | def save_replay(self):
57 | """Save a replay."""
58 | raise NotImplementedError
59 |
60 | def get_env_info(self):
61 | env_info = {
62 | "state_shape": self.get_state_size(),
63 | "obs_shape": self.get_obs_size(),
64 | "n_actions": self.get_total_actions(),
65 | "n_agents": self.n_agents,
66 | "episode_limit": self.episode_limit,
67 | }
68 | return env_info
69 |
--------------------------------------------------------------------------------
/smac/env/pettingzoo/StarCraft2PZEnv.py:
--------------------------------------------------------------------------------
1 | from smac.env import StarCraft2Env
2 | from smac.env.starcraft2.maps import get_map_params
3 | from gymnasium.utils import EzPickle
4 | from gymnasium.utils import seeding
5 | from gymnasium import spaces
6 | from pettingzoo.utils.env import ParallelEnv
7 | from pettingzoo.utils.conversions import (
8 | parallel_to_aec as from_parallel_wrapper,
9 | )
10 | from pettingzoo.utils import wrappers
11 | import numpy as np
12 |
13 |
14 | def parallel_env(max_cycles=None, **smac_args):
15 | if max_cycles is None:
16 | map_name = smac_args.get("map_name", "8m")
17 | max_cycles = get_map_params(map_name)["limit"]
18 | return _parallel_env(max_cycles, **smac_args)
19 |
20 |
21 | def raw_env(max_cycles=None, **smac_args):
22 | return from_parallel_wrapper(parallel_env(max_cycles, **smac_args))
23 |
24 |
25 | def make_env(raw_env):
26 | def env_fn(**kwargs):
27 | env = raw_env(**kwargs)
28 | # env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)
29 | env = wrappers.AssertOutOfBoundsWrapper(env)
30 | env = wrappers.OrderEnforcingWrapper(env)
31 | return env
32 |
33 | return env_fn
34 |
35 |
36 | class smac_parallel_env(ParallelEnv):
37 | def __init__(self, env, max_cycles):
38 | self.max_cycles = max_cycles
39 | self.env = env
40 | self.env.reset()
41 | self.reset_flag = 0
42 | self.agents, self.action_spaces = self._init_agents()
43 | self.possible_agents = self.agents[:]
44 |
45 | observation_size = env.get_obs_size()
46 | self.observation_spaces = {
47 | name: spaces.Dict(
48 | {
49 | "observation": spaces.Box(
50 | low=-1,
51 | high=1,
52 | shape=(observation_size,),
53 | dtype="float32",
54 | ),
55 | "action_mask": spaces.Box(
56 | low=0,
57 | high=1,
58 | shape=(self.action_spaces[name].n,),
59 | dtype=np.int8,
60 | ),
61 | }
62 | )
63 | for name in self.agents
64 | }
65 | state_size = env.get_state_size()
66 | self.state_space = spaces.Box(low=-1, high=1, shape=(state_size,), dtype="float32")
67 | self._reward = 0
68 |
69 | def observation_space(self, agent):
70 | return self.observation_spaces[agent]
71 |
72 | def action_space(self, agent):
73 | return self.action_spaces[agent]
74 |
75 | def _init_agents(self):
76 | last_type = ""
77 | agents = []
78 | action_spaces = {}
79 | self.agents_id = {}
80 | i = 0
81 | for agent_id, agent_info in self.env.agents.items():
82 | unit_action_space = spaces.Discrete(
83 | self.env.get_total_actions() - 1
84 | ) # no-op in dead units is not an action
85 | if agent_info.unit_type == self.env.marine_id:
86 | agent_type = "marine"
87 | elif agent_info.unit_type == self.env.marauder_id:
88 | agent_type = "marauder"
89 | elif agent_info.unit_type == self.env.medivac_id:
90 | agent_type = "medivac"
91 | elif agent_info.unit_type == self.env.hydralisk_id:
92 | agent_type = "hydralisk"
93 | elif agent_info.unit_type == self.env.zergling_id:
94 | agent_type = "zergling"
95 | elif agent_info.unit_type == self.env.baneling_id:
96 | agent_type = "baneling"
97 | elif agent_info.unit_type == self.env.stalker_id:
98 | agent_type = "stalker"
99 | elif agent_info.unit_type == self.env.colossus_id:
100 | agent_type = "colossus"
101 | elif agent_info.unit_type == self.env.zealot_id:
102 | agent_type = "zealot"
103 | else:
104 |                 raise AssertionError(f"unit type {agent_info.unit_type} not supported")
105 |
106 | if agent_type == last_type:
107 | i += 1
108 | else:
109 | i = 0
110 |
111 | agents.append(f"{agent_type}_{i}")
112 | self.agents_id[agents[-1]] = agent_id
113 | action_spaces[agents[-1]] = unit_action_space
114 | last_type = agent_type
115 |
116 | return agents, action_spaces
117 |
118 | def seed(self, seed=None):
119 | if seed is None:
120 | self.env._seed = seeding.create_seed(seed, max_bytes=4)
121 | else:
122 | self.env._seed = seed
123 | self.env.full_restart()
124 |
125 | def render(self, mode="human"):
126 | self.env.render(mode)
127 |
128 | def close(self):
129 | self.env.close()
130 |
131 | def reset(self, seed=None, options=None):
132 | self.env._episode_count = 1
133 | self.env.reset()
134 |
135 | self.agents = self.possible_agents[:]
136 | self.frames = 0
137 | self.terminations = {agent: False for agent in self.possible_agents}
138 | self.truncations = {agent: False for agent in self.possible_agents}
139 | return self._observe_all()
140 |
141 | def get_agent_smac_id(self, agent):
142 | return self.agents_id[agent]
143 |
144 | def _all_rewards(self, reward):
145 | all_rewards = [reward] * len(self.agents)
146 | return {
147 | agent: reward for agent, reward in zip(self.agents, all_rewards)
148 | }
149 |
150 | def _observe_all(self):
151 | all_obs = []
152 | for agent in self.agents:
153 | agent_id = self.get_agent_smac_id(agent)
154 | obs = self.env.get_obs_agent(agent_id)
155 | action_mask = self.env.get_avail_agent_actions(agent_id)
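            # Drop index 0 (no-op): it is reserved for dead units and is not
            # part of the PettingZoo action space (Discrete(n_actions - 1)).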
156 | action_mask = action_mask[1:]
157 | action_mask = np.array(action_mask).astype(np.int8)
158 | obs = np.asarray(obs, dtype=np.float32)
159 | all_obs.append({"observation": obs, "action_mask": action_mask})
160 | return {agent: obs for agent, obs in zip(self.agents, all_obs)}
161 |
162 | def _all_terms_truncs(self, terminated=False, truncated=False):
163 | terminations = [True] * len(self.agents)
164 |
165 | if not terminated:
166 | for i, agent in enumerate(self.agents):
167 | agent_done = False
168 | agent_id = self.get_agent_smac_id(agent)
169 | agent_info = self.env.get_unit_by_id(agent_id)
170 | if agent_info.health == 0:
171 | agent_done = True
172 | terminations[i] = agent_done
173 |
174 | terminations = {a: bool(t) for a, t in zip(self.agents, terminations)}
175 | truncations = {a: truncated for a in self.agents}
176 |
177 | return terminations, truncations
178 |
179 | def step(self, all_actions):
180 | action_list = [0] * self.env.n_agents
181 | for agent in self.agents:
182 | agent_id = self.get_agent_smac_id(agent)
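            # A None action (terminated/truncated agent) maps to SMAC's no-op (0);
            # live-agent actions are shifted by +1 because no-op was removed from
            # the PettingZoo action space.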
183 | if agent in all_actions:
184 | if all_actions[agent] is None:
185 | action_list[agent_id] = 0
186 | else:
187 | action_list[agent_id] = all_actions[agent] + 1
188 | self._reward, terminated, smac_info = self.env.step(action_list)
189 | self.frames += 1
190 |
191 | all_infos = {agent: {} for agent in self.agents}
192 | # all_infos.update(smac_info)
193 | all_terms, all_truncs = self._all_terms_truncs(
194 | terminated=terminated, truncated=(self.frames >= self.max_cycles)
195 | )
196 | all_rewards = self._all_rewards(self._reward)
197 | all_observes = self._observe_all()
198 |
199 | self.agents = [agent for agent in self.agents if not all_terms[agent]]
200 | self.agents = [agent for agent in self.agents if not all_truncs[agent]]
201 |
202 | return all_observes, all_rewards, all_terms, all_truncs, all_infos
203 |
204 | def state(self):
205 | return self.env.get_state()
206 |
207 | def __del__(self):
208 | self.env.close()
209 |
210 |
211 | env = make_env(raw_env)
212 |
213 |
214 | class _parallel_env(smac_parallel_env, EzPickle):
215 | metadata = {"render.modes": ["human"], "name": "sc2"}
216 |
217 | def __init__(self, max_cycles, **smac_args):
218 | EzPickle.__init__(self, max_cycles, **smac_args)
219 | env = StarCraft2Env(**smac_args)
220 | super().__init__(env, max_cycles)
221 |
--------------------------------------------------------------------------------
/smac/env/pettingzoo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/pettingzoo/__init__.py
--------------------------------------------------------------------------------
/smac/env/pettingzoo/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/pettingzoo/test/__init__.py
--------------------------------------------------------------------------------
/smac/env/pettingzoo/test/all_test.py:
--------------------------------------------------------------------------------
1 | from smac.env.starcraft2.maps import smac_maps
2 | from pysc2 import maps as pysc2_maps
3 | from smac.env.pettingzoo import StarCraft2PZEnv as sc2
4 | import pytest
5 | from pettingzoo import test
6 | import pickle
7 |
8 | smac_map_registry = smac_maps.get_smac_map_registry()
9 | all_maps = pysc2_maps.get_maps()
10 | map_names = []
11 | for map_name in smac_map_registry.keys():
12 | map_class = all_maps[map_name]
13 | if map_class.path:
14 | map_names.append(map_name)
15 |
16 |
17 | @pytest.mark.parametrize(("map_name"), map_names)
18 | def test_env(map_name):
19 | env = sc2.env(map_name=map_name)
20 | test.api_test(env)
21 |     # test.parallel_api_test(sc2.parallel_env())  # does not pass due to illegal actions
22 |     # test.seed_test(sc2.env, 50)  # not required: the sc2 env only allows
23 |     # reseeding at initialization
24 | test.render_test(env)
25 |
26 | recreated_env = pickle.loads(pickle.dumps(env))
27 | test.api_test(recreated_env)
28 |
--------------------------------------------------------------------------------
/smac/env/pettingzoo/test/smac_pettingzoo_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import inspect
4 | from pettingzoo import test
5 | from smac.env.pettingzoo import StarCraft2PZEnv as sc2
6 | import pickle
7 |
8 | current_dir = os.path.dirname(
9 | os.path.abspath(inspect.getfile(inspect.currentframe()))
10 | )
11 | parent_dir = os.path.dirname(current_dir)
12 | sys.path.insert(0, parent_dir)
13 |
14 |
15 | if __name__ == "__main__":
16 | env = sc2.env(map_name="corridor")
17 | test.api_test(env)
18 |     # test.parallel_api_test(sc2.parallel_env())  # does not pass due to illegal actions
19 |     # test.seed_test(sc2.env, 50)  # not required: the sc2 env only allows
20 |     # reseeding at initialization
21 |
22 | recreated_env = pickle.loads(pickle.dumps(env))
23 | test.api_test(recreated_env)
24 |
--------------------------------------------------------------------------------
/smac/env/starcraft2/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from absl import flags
6 |
7 | FLAGS = flags.FLAGS
8 | FLAGS(["main.py"])
9 |
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map
--------------------------------------------------------------------------------
/smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map
--------------------------------------------------------------------------------
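The raw links above are the binary SMAC map files bundled with the repository. As an illustrative sketch only (not part of the repository), one of them could be fetched with a few lines of Python; the $SC2PATH/Maps/SMAC_Maps destination is an assumption about the local StarCraft II installation, not something these files prescribe.

    # Hypothetical helper: download a single SMAC map next to the local SC2 install.
    import os
    import urllib.request

    url = (
        "https://raw.githubusercontent.com/oxwhirl/smac/"
        "d6aab33f76abc3849c50463a8592a84f59a5ef84/"
        "smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map"
    )
    dest_dir = os.path.join(os.environ.get("SC2PATH", "."), "Maps", "SMAC_Maps")
    os.makedirs(dest_dir, exist_ok=True)  # create the maps folder if missing
    urllib.request.urlretrieve(url, os.path.join(dest_dir, "3m.SC2Map"))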
/smac/env/starcraft2/maps/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from smac.env.starcraft2.maps import smac_maps
6 |
7 |
8 | def get_map_params(map_name):
9 | map_param_registry = smac_maps.get_smac_map_registry()
10 | return map_param_registry[map_name]
11 |
--------------------------------------------------------------------------------
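get_map_params is a thin lookup into the registry defined in smac_maps.py (listed next). A minimal usage sketch, with values taken from that registry:

    # Minimal sketch: read the parameters of the "3m" scenario.
    from smac.env.starcraft2.maps import get_map_params

    params = get_map_params("3m")
    # From the registry below: 3 agents vs 3 enemies, 60-step episode limit.
    print(params["n_agents"], params["n_enemies"], params["limit"])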
/smac/env/starcraft2/maps/smac_maps.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from pysc2.maps import lib
6 |
7 |
8 | class SMACMap(lib.Map):
9 | directory = "SMAC_Maps"
10 | download = "https://github.com/oxwhirl/smac#smac-maps"
11 | players = 2
12 | step_mul = 8
13 | game_steps_per_episode = 0
14 |
15 |
16 | map_param_registry = {
17 | "3m": {
18 | "n_agents": 3,
19 | "n_enemies": 3,
20 | "limit": 60,
21 | "a_race": "T",
22 | "b_race": "T",
23 | "unit_type_bits": 0,
24 | "map_type": "marines",
25 | },
26 | "8m": {
27 | "n_agents": 8,
28 | "n_enemies": 8,
29 | "limit": 120,
30 | "a_race": "T",
31 | "b_race": "T",
32 | "unit_type_bits": 0,
33 | "map_type": "marines",
34 | },
35 | "25m": {
36 | "n_agents": 25,
37 | "n_enemies": 25,
38 | "limit": 150,
39 | "a_race": "T",
40 | "b_race": "T",
41 | "unit_type_bits": 0,
42 | "map_type": "marines",
43 | },
44 | "5m_vs_6m": {
45 | "n_agents": 5,
46 | "n_enemies": 6,
47 | "limit": 70,
48 | "a_race": "T",
49 | "b_race": "T",
50 | "unit_type_bits": 0,
51 | "map_type": "marines",
52 | },
53 | "8m_vs_9m": {
54 | "n_agents": 8,
55 | "n_enemies": 9,
56 | "limit": 120,
57 | "a_race": "T",
58 | "b_race": "T",
59 | "unit_type_bits": 0,
60 | "map_type": "marines",
61 | },
62 | "10m_vs_11m": {
63 | "n_agents": 10,
64 | "n_enemies": 11,
65 | "limit": 150,
66 | "a_race": "T",
67 | "b_race": "T",
68 | "unit_type_bits": 0,
69 | "map_type": "marines",
70 | },
71 | "27m_vs_30m": {
72 | "n_agents": 27,
73 | "n_enemies": 30,
74 | "limit": 180,
75 | "a_race": "T",
76 | "b_race": "T",
77 | "unit_type_bits": 0,
78 | "map_type": "marines",
79 | },
80 | "MMM": {
81 | "n_agents": 10,
82 | "n_enemies": 10,
83 | "limit": 150,
84 | "a_race": "T",
85 | "b_race": "T",
86 | "unit_type_bits": 3,
87 | "map_type": "MMM",
88 | },
89 | "MMM2": {
90 | "n_agents": 10,
91 | "n_enemies": 12,
92 | "limit": 180,
93 | "a_race": "T",
94 | "b_race": "T",
95 | "unit_type_bits": 3,
96 | "map_type": "MMM",
97 | },
98 | "2s3z": {
99 | "n_agents": 5,
100 | "n_enemies": 5,
101 | "limit": 120,
102 | "a_race": "P",
103 | "b_race": "P",
104 | "unit_type_bits": 2,
105 | "map_type": "stalkers_and_zealots",
106 | },
107 | "3s5z": {
108 | "n_agents": 8,
109 | "n_enemies": 8,
110 | "limit": 150,
111 | "a_race": "P",
112 | "b_race": "P",
113 | "unit_type_bits": 2,
114 | "map_type": "stalkers_and_zealots",
115 | },
116 | "3s5z_vs_3s6z": {
117 | "n_agents": 8,
118 | "n_enemies": 9,
119 | "limit": 170,
120 | "a_race": "P",
121 | "b_race": "P",
122 | "unit_type_bits": 2,
123 | "map_type": "stalkers_and_zealots",
124 | },
125 | "3s_vs_3z": {
126 | "n_agents": 3,
127 | "n_enemies": 3,
128 | "limit": 150,
129 | "a_race": "P",
130 | "b_race": "P",
131 | "unit_type_bits": 0,
132 | "map_type": "stalkers",
133 | },
134 | "3s_vs_4z": {
135 | "n_agents": 3,
136 | "n_enemies": 4,
137 | "limit": 200,
138 | "a_race": "P",
139 | "b_race": "P",
140 | "unit_type_bits": 0,
141 | "map_type": "stalkers",
142 | },
143 | "3s_vs_5z": {
144 | "n_agents": 3,
145 | "n_enemies": 5,
146 | "limit": 250,
147 | "a_race": "P",
148 | "b_race": "P",
149 | "unit_type_bits": 0,
150 | "map_type": "stalkers",
151 | },
152 | "1c3s5z": {
153 | "n_agents": 9,
154 | "n_enemies": 9,
155 | "limit": 180,
156 | "a_race": "P",
157 | "b_race": "P",
158 | "unit_type_bits": 3,
159 | "map_type": "colossi_stalkers_zealots",
160 | },
161 | "2m_vs_1z": {
162 | "n_agents": 2,
163 | "n_enemies": 1,
164 | "limit": 150,
165 | "a_race": "T",
166 | "b_race": "P",
167 | "unit_type_bits": 0,
168 | "map_type": "marines",
169 | },
170 | "corridor": {
171 | "n_agents": 6,
172 | "n_enemies": 24,
173 | "limit": 400,
174 | "a_race": "P",
175 | "b_race": "Z",
176 | "unit_type_bits": 0,
177 | "map_type": "zealots",
178 | },
179 | "6h_vs_8z": {
180 | "n_agents": 6,
181 | "n_enemies": 8,
182 | "limit": 150,
183 | "a_race": "Z",
184 | "b_race": "P",
185 | "unit_type_bits": 0,
186 | "map_type": "hydralisks",
187 | },
188 | "2s_vs_1sc": {
189 | "n_agents": 2,
190 | "n_enemies": 1,
191 | "limit": 300,
192 | "a_race": "P",
193 | "b_race": "Z",
194 | "unit_type_bits": 0,
195 | "map_type": "stalkers",
196 | },
197 | "so_many_baneling": {
198 | "n_agents": 7,
199 | "n_enemies": 32,
200 | "limit": 100,
201 | "a_race": "P",
202 | "b_race": "Z",
203 | "unit_type_bits": 0,
204 | "map_type": "zealots",
205 | },
206 | "bane_vs_bane": {
207 | "n_agents": 24,
208 | "n_enemies": 24,
209 | "limit": 200,
210 | "a_race": "Z",
211 | "b_race": "Z",
212 | "unit_type_bits": 2,
213 | "map_type": "bane",
214 | },
215 | "2c_vs_64zg": {
216 | "n_agents": 2,
217 | "n_enemies": 64,
218 | "limit": 400,
219 | "a_race": "P",
220 | "b_race": "Z",
221 | "unit_type_bits": 0,
222 | "map_type": "colossus",
223 | },
224 | }
225 |
226 |
227 | def get_smac_map_registry():
228 | return map_param_registry
229 |
230 |
231 | for name in map_param_registry.keys():
232 | globals()[name] = type(name, (SMACMap,), dict(filename=name))
233 |
--------------------------------------------------------------------------------
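The loop at the bottom of smac_maps.py creates one SMACMap subclass per registry entry, which is how the maps.get(map_name) call in starcraft2.py (further down) can resolve scenario names such as "3m". A minimal sketch that enumerates the registry; the output format here is illustrative:

    # Minimal sketch: list every SMAC scenario with its team sizes.
    from smac.env.starcraft2.maps import smac_maps

    for name, params in smac_maps.get_smac_map_registry().items():
        print("{:<16s} {:>2d} agents vs {:>2d} enemies".format(
            name, params["n_agents"], params["n_enemies"]))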
/smac/env/starcraft2/render.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import subprocess
4 | import platform
5 | from absl import logging
6 | import math
7 | import time
8 | import collections
9 | import os
10 | import pygame
11 | import queue
12 |
13 | from pysc2.lib import colors
14 | from pysc2.lib import point
15 | from pysc2.lib.renderer_human import _Surface
16 | from pysc2.lib import transform
17 | from pysc2.lib import features
18 |
19 |
20 | def clamp(n, smallest, largest):
21 | return max(smallest, min(n, largest))
22 |
23 |
24 | def _get_desktop_size():
25 | """Get the desktop size."""
26 | if platform.system() == "Linux":
27 | try:
28 | xrandr_query = subprocess.check_output(["xrandr", "--query"])
29 | sizes = re.findall(
30 | r"\bconnected primary (\d+)x(\d+)", str(xrandr_query)
31 | )
32 | if sizes[0]:
33 | return point.Point(int(sizes[0][0]), int(sizes[0][1]))
34 |         except Exception:  # xrandr missing, no primary output, or parse error
35 | logging.error("Failed to get the resolution from xrandr.")
36 |
37 | # Most general, but doesn't understand multiple monitors.
38 | display_info = pygame.display.Info()
39 | return point.Point(display_info.current_w, display_info.current_h)
40 |
41 |
42 | class StarCraft2Renderer:
43 | def __init__(self, env, mode):
44 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide"
45 |
46 | self.env = env
47 | self.mode = mode
48 | self.obs = None
49 | self._window_scale = 0.75
50 | self.game_info = game_info = self.env._controller.game_info()
51 | self.static_data = self.env._controller.data()
52 |
53 | self._obs_queue = queue.Queue()
54 | self._game_times = collections.deque(
55 | maxlen=100
56 | ) # Avg FPS over 100 frames. # pytype: disable=wrong-keyword-args
57 | self._render_times = collections.deque(
58 | maxlen=100
59 | ) # pytype: disable=wrong-keyword-args
60 | self._last_time = time.time()
61 | self._last_game_loop = 0
62 | self._name_lengths = {}
63 |
64 | self._map_size = point.Point.build(game_info.start_raw.map_size)
65 | self._playable = point.Rect(
66 | point.Point.build(game_info.start_raw.playable_area.p0),
67 | point.Point.build(game_info.start_raw.playable_area.p1),
68 | )
69 |
70 | window_size_px = point.Point(
71 | self.env.window_size[0], self.env.window_size[1]
72 | )
73 | window_size_px = self._map_size.scale_max_size(
74 | window_size_px * self._window_scale
75 | ).ceil()
76 | self._scale = window_size_px.y // 32
77 |
78 | self.display = pygame.Surface(window_size_px)
79 |
80 | if mode == "human":
81 | self.display = pygame.display.set_mode(window_size_px, 0, 32)
82 | pygame.display.init()
83 |
84 | pygame.display.set_caption("Starcraft Viewer")
85 | pygame.font.init()
86 | self._world_to_world_tl = transform.Linear(
87 | point.Point(1, -1), point.Point(0, self._map_size.y)
88 | )
89 | self._world_tl_to_screen = transform.Linear(scale=window_size_px / 32)
90 | self.screen_transform = transform.Chain(
91 | self._world_to_world_tl, self._world_tl_to_screen
92 | )
93 |
94 | surf_loc = point.Rect(point.origin, window_size_px)
95 | sub_surf = self.display.subsurface(
96 | pygame.Rect(surf_loc.tl, surf_loc.size)
97 | )
98 | self._surf = _Surface(
99 | sub_surf,
100 | None,
101 | surf_loc,
102 | self.screen_transform,
103 | None,
104 | self.draw_screen,
105 | )
106 |
107 | self._font_small = pygame.font.Font(None, int(self._scale * 0.5))
108 | self._font_large = pygame.font.Font(None, self._scale)
109 |
110 | def close(self):
111 | pygame.display.quit()
112 | pygame.quit()
113 |
114 | def _get_units(self):
115 | for u in sorted(
116 | self.obs.observation.raw_data.units,
117 | key=lambda u: (u.pos.z, u.owner != 16, -u.radius, u.tag),
118 | ):
119 | yield u, point.Point.build(u.pos)
120 |
121 | def get_unit_name(self, surf, name, radius):
122 | """Get a length limited unit name for drawing units."""
123 | key = (name, radius)
124 | if key not in self._name_lengths:
125 | max_len = surf.world_to_surf.fwd_dist(radius * 1.6)
126 | for i in range(len(name)):
127 | if self._font_small.size(name[: i + 1])[0] > max_len:
128 | self._name_lengths[key] = name[:i]
129 | break
130 | else:
131 | self._name_lengths[key] = name
132 | return self._name_lengths[key]
133 |
134 | def render(self, mode):
135 | self.obs = self.env._obs
136 | self.score = self.env.reward
137 | self.step = self.env._episode_steps
138 |
139 | now = time.time()
140 | self._game_times.append(
141 | (
142 | now - self._last_time,
143 | max(
144 | 1,
145 | self.obs.observation.game_loop
146 |                     - self._last_game_loop,
147 | ),
148 | )
149 | )
150 |
151 | if mode == "human":
152 | pygame.event.pump()
153 |
154 | self._surf.draw(self._surf)
155 |
156 | observation = np.array(pygame.surfarray.pixels3d(self.display))
157 |
158 | if mode == "human":
159 | pygame.display.flip()
160 |
161 | self._last_time = now
162 | self._last_game_loop = self.obs.observation.game_loop
163 | # self._obs_queue.put(self.obs)
164 | return (
165 | np.transpose(observation, axes=(1, 0, 2))
166 | if mode == "rgb_array"
167 | else None
168 | )
169 |
170 | def draw_base_map(self, surf):
171 | """Draw the base map."""
172 | hmap_feature = features.SCREEN_FEATURES.height_map
173 | hmap = self.env.terrain_height * 255
174 | hmap = hmap.astype(np.uint8)
175 | if (
176 | self.env.map_name == "corridor"
177 | or self.env.map_name == "so_many_baneling"
178 | or self.env.map_name == "2s_vs_1sc"
179 | ):
180 | hmap = np.flip(hmap)
181 | else:
182 | hmap = np.rot90(hmap, axes=(1, 0))
183 | if not hmap.any():
184 | hmap = hmap + 100 # pylint: disable=g-no-augmented-assignment
185 | hmap_color = hmap_feature.color(hmap)
186 | out = hmap_color * 0.6
187 |
188 | surf.blit_np_array(out)
189 |
190 | def draw_units(self, surf):
191 | """Draw the units."""
192 | unit_dict = None # Cache the units {tag: unit_proto} for orders.
193 | tau = 2 * math.pi
194 | for u, p in self._get_units():
195 | fraction_damage = clamp(
196 | (u.health_max - u.health) / (u.health_max or 1), 0, 1
197 | )
198 | surf.draw_circle(
199 | colors.PLAYER_ABSOLUTE_PALETTE[u.owner], p, u.radius
200 | )
201 |
202 | if fraction_damage > 0:
203 | surf.draw_circle(
204 | colors.PLAYER_ABSOLUTE_PALETTE[u.owner] // 2,
205 | p,
206 | u.radius * fraction_damage,
207 | )
208 | surf.draw_circle(colors.black, p, u.radius, thickness=1)
209 |
210 | if self.static_data.unit_stats[u.unit_type].movement_speed > 0:
211 | surf.draw_arc(
212 | colors.white,
213 | p,
214 | u.radius,
215 | u.facing - 0.1,
216 | u.facing + 0.1,
217 | thickness=1,
218 | )
219 |
220 | def draw_arc_ratio(
221 | color, world_loc, radius, start, end, thickness=1
222 | ):
223 | surf.draw_arc(
224 | color, world_loc, radius, start * tau, end * tau, thickness
225 | )
226 |
227 | if u.shield and u.shield_max:
228 | draw_arc_ratio(
229 | colors.blue, p, u.radius - 0.05, 0, u.shield / u.shield_max
230 | )
231 |
232 | if u.energy and u.energy_max:
233 | draw_arc_ratio(
234 | colors.purple * 0.9,
235 | p,
236 | u.radius - 0.1,
237 | 0,
238 | u.energy / u.energy_max,
239 | )
240 | elif u.orders and 0 < u.orders[0].progress < 1:
241 | draw_arc_ratio(
242 | colors.cyan, p, u.radius - 0.15, 0, u.orders[0].progress
243 | )
244 | if u.buff_duration_remain and u.buff_duration_max:
245 | draw_arc_ratio(
246 | colors.white,
247 | p,
248 | u.radius - 0.2,
249 | 0,
250 | u.buff_duration_remain / u.buff_duration_max,
251 | )
252 | if u.attack_upgrade_level:
253 | draw_arc_ratio(
254 | self.upgrade_colors[u.attack_upgrade_level],
255 | p,
256 | u.radius - 0.25,
257 | 0.18,
258 | 0.22,
259 | thickness=3,
260 | )
261 | if u.armor_upgrade_level:
262 | draw_arc_ratio(
263 | self.upgrade_colors[u.armor_upgrade_level],
264 | p,
265 | u.radius - 0.25,
266 | 0.23,
267 | 0.27,
268 | thickness=3,
269 | )
270 | if u.shield_upgrade_level:
271 | draw_arc_ratio(
272 | self.upgrade_colors[u.shield_upgrade_level],
273 | p,
274 | u.radius - 0.25,
275 | 0.28,
276 | 0.32,
277 | thickness=3,
278 | )
279 |
280 | def write_small(loc, s):
281 | surf.write_world(self._font_small, colors.white, loc, str(s))
282 |
283 | name = self.get_unit_name(
284 | surf,
285 | self.static_data.units.get(u.unit_type, ""),
286 | u.radius,
287 | )
288 |
289 | if name:
290 | write_small(p, name)
291 |
292 | start_point = p
293 | for o in u.orders:
294 | target_point = None
295 | if o.HasField("target_unit_tag"):
296 | if unit_dict is None:
297 | unit_dict = {
298 | t.tag: t
299 | for t in self.obs.observation.raw_data.units
300 | }
301 | target_unit = unit_dict.get(o.target_unit_tag)
302 | if target_unit:
303 | target_point = point.Point.build(target_unit.pos)
304 | if target_point:
305 | surf.draw_line(colors.cyan, start_point, target_point)
306 | start_point = target_point
307 | else:
308 | break
309 |
310 | def draw_overlay(self, surf):
311 | """Draw the overlay describing resources."""
312 | obs = self.obs.observation
313 | times, steps = zip(*self._game_times)
314 | sec = obs.game_loop // 22.4
315 | surf.write_screen(
316 | self._font_large,
317 | colors.green,
318 | (-0.2, 0.2),
319 | "Score: %s, Step: %s, %.1f/s, Time: %d:%02d"
320 | % (
321 | self.score,
322 | self.step,
323 | sum(steps) / (sum(times) or 1),
324 | sec // 60,
325 | sec % 60,
326 | ),
327 | align="right",
328 | )
329 | surf.write_screen(
330 | self._font_large,
331 | colors.green * 0.8,
332 | (-0.2, 1.2),
333 | "APM: %d, EPM: %d, FPS: O:%.1f, R:%.1f"
334 | % (
335 | obs.score.score_details.current_apm,
336 | obs.score.score_details.current_effective_apm,
337 | len(times) / (sum(times) or 1),
338 | len(self._render_times) / (sum(self._render_times) or 1),
339 | ),
340 | align="right",
341 | )
342 |
343 | def draw_screen(self, surf):
344 | """Draw the screen area."""
345 | self.draw_base_map(surf)
346 | self.draw_units(surf)
347 | self.draw_overlay(surf)
348 |
--------------------------------------------------------------------------------
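StarCraft2Renderer is driven by the environment rather than used directly: each frame it reads env._obs, env.reward and env._episode_steps. A hedged sketch of grabbing a single rgb_array frame, assuming StarCraft2Env (defined in starcraft2.py below) exposes a render(mode=...) hook that lazily constructs this renderer; that wiring sits outside this excerpt.

    # Hedged sketch: render one frame after resetting the environment.
    import numpy as np
    from smac.env import StarCraft2Env

    env = StarCraft2Env(map_name="3m")
    env.reset()
    frame = env.render(mode="rgb_array")  # assumed hook backed by StarCraft2Renderer
    if frame is not None:
        print(np.asarray(frame).shape)  # (height, width, 3) per render() above
    env.close()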
/smac/env/starcraft2/starcraft2.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from smac.env.multiagentenv import MultiAgentEnv
6 | from smac.env.starcraft2.maps import get_map_params
7 |
8 | import atexit
9 | from warnings import warn
10 | from operator import attrgetter
11 | from copy import deepcopy
12 | import numpy as np
13 | import enum
14 | import math
15 | from absl import logging
16 |
17 | from pysc2 import maps
18 | from pysc2 import run_configs
19 | from pysc2.lib import protocol
20 |
21 | from s2clientprotocol import common_pb2 as sc_common
22 | from s2clientprotocol import sc2api_pb2 as sc_pb
23 | from s2clientprotocol import raw_pb2 as r_pb
24 | from s2clientprotocol import debug_pb2 as d_pb
25 |
26 | races = {
27 | "R": sc_common.Random,
28 | "P": sc_common.Protoss,
29 | "T": sc_common.Terran,
30 | "Z": sc_common.Zerg,
31 | }
32 |
33 | difficulties = {
34 | "1": sc_pb.VeryEasy,
35 | "2": sc_pb.Easy,
36 | "3": sc_pb.Medium,
37 | "4": sc_pb.MediumHard,
38 | "5": sc_pb.Hard,
39 | "6": sc_pb.Harder,
40 | "7": sc_pb.VeryHard,
41 | "8": sc_pb.CheatVision,
42 | "9": sc_pb.CheatMoney,
43 | "A": sc_pb.CheatInsane,
44 | }
45 |
46 | actions = {
47 | "move": 16, # target: PointOrUnit
48 | "attack": 23, # target: PointOrUnit
49 | "stop": 4, # target: None
50 | "heal": 386, # Unit
51 | }
52 |
53 |
54 | class Direction(enum.IntEnum):
55 | NORTH = 0
56 | SOUTH = 1
57 | EAST = 2
58 | WEST = 3
59 |
60 |
61 | class StarCraft2Env(MultiAgentEnv):
62 | """The StarCraft II environment for decentralised multi-agent
63 | micromanagement scenarios.
64 | """
65 |
66 | def __init__(
67 | self,
68 | map_name="8m",
69 | step_mul=8,
70 | move_amount=2,
71 | difficulty="7",
72 | game_version=None,
73 | seed=None,
74 | continuing_episode=False,
75 | obs_all_health=True,
76 | obs_own_health=True,
77 | obs_last_action=False,
78 | obs_pathing_grid=False,
79 | obs_terrain_height=False,
80 | obs_instead_of_state=False,
81 | obs_timestep_number=False,
82 | state_last_action=True,
83 | state_timestep_number=False,
84 | reward_sparse=False,
85 | reward_only_positive=True,
86 | reward_death_value=10,
87 | reward_win=200,
88 | reward_defeat=0,
89 | reward_negative_scale=0.5,
90 | reward_scale=True,
91 | reward_scale_rate=20,
92 | replay_dir="",
93 | replay_prefix="",
94 | window_size_x=1920,
95 | window_size_y=1200,
96 | heuristic_ai=False,
97 | heuristic_rest=False,
98 | debug=False,
99 | ):
100 | """
101 |         Create a StarCraft2Env environment.
102 |
103 | Parameters
104 | ----------
105 | map_name : str, optional
106 | The name of the SC2 map to play (default is "8m"). The full list
107 | can be found by running bin/map_list.
108 | step_mul : int, optional
109 | How many game steps per agent step (default is 8). None
110 |             indicates that the map's default step_mul should be used.
111 | move_amount : float, optional
112 | How far away units are ordered to move per step (default is 2).
113 | difficulty : str, optional
114 | The difficulty of built-in computer AI bot (default is "7").
115 | game_version : str, optional
116 | StarCraft II game version (default is None). None indicates the
117 | latest version.
118 | seed : int, optional
119 |             Random seed used during game initialisation.
120 | continuing_episode : bool, optional
121 | Whether to consider episodes continuing or finished after time
122 | limit is reached (default is False).
123 | obs_all_health : bool, optional
124 | Agents receive the health of all units (in the sight range) as part
125 | of observations (default is True).
126 | obs_own_health : bool, optional
127 |             Agents receive their own health as a part of observations (default
128 |             is True). This flag is ignored when obs_all_health == True.
129 | obs_last_action : bool, optional
130 | Agents receive the last actions of all units (in the sight range)
131 | as part of observations (default is False).
132 | obs_pathing_grid : bool, optional
133 | Whether observations include pathing values surrounding the agent
134 | (default is False).
135 | obs_terrain_height : bool, optional
136 | Whether observations include terrain height values surrounding the
137 | agent (default is False).
138 | obs_instead_of_state : bool, optional
139 |             Use the combination of all agents' observations as the global state
140 | (default is False).
141 | obs_timestep_number : bool, optional
142 | Whether observations include the current timestep of the episode
143 | (default is False).
144 | state_last_action : bool, optional
145 | Include the last actions of all agents as part of the global state
146 | (default is True).
147 | state_timestep_number : bool, optional
148 |             Whether the state includes the current timestep of the episode
149 | (default is False).
150 | reward_sparse : bool, optional
151 |             Receive 1/-1 reward for winning/losing an episode (default is
152 |             False). The rest of the reward parameters are ignored if True.
153 | reward_only_positive : bool, optional
154 | Reward is always positive (default is True).
155 | reward_death_value : float, optional
156 | The amount of reward received for killing an enemy unit (default
157 | is 10). This is also the negative penalty for having an allied unit
158 | killed if reward_only_positive == False.
159 | reward_win : float, optional
160 | The reward for winning in an episode (default is 200).
161 | reward_defeat : float, optional
162 |             The reward for losing an episode (default is 0). This value
163 | should be nonpositive.
164 | reward_negative_scale : float, optional
165 | Scaling factor for negative rewards (default is 0.5). This
166 | parameter is ignored when reward_only_positive == True.
167 | reward_scale : bool, optional
168 | Whether or not to scale the reward (default is True).
169 | reward_scale_rate : float, optional
170 | Reward scale rate (default is 20). When reward_scale == True, the
171 | reward received by the agents is divided by (max_reward /
172 | reward_scale_rate), where max_reward is the maximum possible
173 | reward per episode without considering the shield regeneration
174 | of Protoss units.
175 | replay_dir : str, optional
176 |             The directory to save replays (default is an empty string). If
177 |             empty, the replay will be saved in the Replays directory where
178 |             StarCraft II is installed.
179 | replay_prefix : str, optional
180 |             The prefix of the replay to be saved (default is an empty
181 |             string). If empty, the name of the map will be used.
182 | window_size_x : int, optional
183 |             The width of the StarCraft II window (default is 1920).
184 |         window_size_y: int, optional
185 |             The height of the StarCraft II window (default is 1200).
186 | heuristic_ai: bool, optional
187 | Whether or not to use a non-learning heuristic AI (default False).
188 | heuristic_rest: bool, optional
189 | At any moment, restrict the actions of the heuristic AI to be
190 | chosen from actions available to RL agents (default is False).
191 | Ignored if heuristic_ai == False.
192 | debug: bool, optional
193 | Log messages about observations, state, actions and rewards for
194 | debugging purposes (default is False).
195 | """
196 | # Map arguments
197 | self.map_name = map_name
198 | map_params = get_map_params(self.map_name)
199 | self.n_agents = map_params["n_agents"]
200 | self.n_enemies = map_params["n_enemies"]
201 | self.episode_limit = map_params["limit"]
202 | self._move_amount = move_amount
203 | self._step_mul = step_mul
204 | self.difficulty = difficulty
205 |
206 | # Observations and state
207 | self.obs_own_health = obs_own_health
208 | self.obs_all_health = obs_all_health
209 | self.obs_instead_of_state = obs_instead_of_state
210 | self.obs_last_action = obs_last_action
211 | self.obs_pathing_grid = obs_pathing_grid
212 | self.obs_terrain_height = obs_terrain_height
213 | self.obs_timestep_number = obs_timestep_number
214 | self.state_last_action = state_last_action
215 | self.state_timestep_number = state_timestep_number
216 | if self.obs_all_health:
217 | self.obs_own_health = True
218 | self.n_obs_pathing = 8
219 | self.n_obs_height = 9
220 |
221 | # Rewards args
222 | self.reward_sparse = reward_sparse
223 | self.reward_only_positive = reward_only_positive
224 | self.reward_negative_scale = reward_negative_scale
225 | self.reward_death_value = reward_death_value
226 | self.reward_win = reward_win
227 | self.reward_defeat = reward_defeat
228 | self.reward_scale = reward_scale
229 | self.reward_scale_rate = reward_scale_rate
230 |
231 | # Other
232 | self.game_version = game_version
233 | self.continuing_episode = continuing_episode
234 | self._seed = seed
235 | self.heuristic_ai = heuristic_ai
236 | self.heuristic_rest = heuristic_rest
237 | self.debug = debug
238 | self.window_size = (window_size_x, window_size_y)
239 | self.replay_dir = replay_dir
240 | self.replay_prefix = replay_prefix
241 |
242 | # Actions
243 | self.n_actions_no_attack = 6
244 | self.n_actions_move = 4
245 | self.n_actions = self.n_actions_no_attack + self.n_enemies
246 |
247 | # Map info
248 | self._agent_race = map_params["a_race"]
249 | self._bot_race = map_params["b_race"]
250 | self.shield_bits_ally = 1 if self._agent_race == "P" else 0
251 | self.shield_bits_enemy = 1 if self._bot_race == "P" else 0
252 | self.unit_type_bits = map_params["unit_type_bits"]
253 | self.map_type = map_params["map_type"]
254 | self._unit_types = None
255 |
256 | self.max_reward = (
257 | self.n_enemies * self.reward_death_value + self.reward_win
258 | )
259 |
260 | # create lists containing the names of attributes returned in states
261 | self.ally_state_attr_names = [
262 | "health",
263 | "energy/cooldown",
264 | "rel_x",
265 | "rel_y",
266 | ]
267 | self.enemy_state_attr_names = ["health", "rel_x", "rel_y"]
268 |
269 | if self.shield_bits_ally > 0:
270 | self.ally_state_attr_names += ["shield"]
271 | if self.shield_bits_enemy > 0:
272 | self.enemy_state_attr_names += ["shield"]
273 |
274 | if self.unit_type_bits > 0:
275 | bit_attr_names = [
276 | "type_{}".format(bit) for bit in range(self.unit_type_bits)
277 | ]
278 | self.ally_state_attr_names += bit_attr_names
279 | self.enemy_state_attr_names += bit_attr_names
280 |
281 | self.agents = {}
282 | self.enemies = {}
283 | self._episode_count = 0
284 | self._episode_steps = 0
285 | self._total_steps = 0
286 | self._obs = None
287 | self.battles_won = 0
288 | self.battles_game = 0
289 | self.timeouts = 0
290 | self.force_restarts = 0
291 | self.last_stats = None
292 | self.death_tracker_ally = np.zeros(self.n_agents)
293 | self.death_tracker_enemy = np.zeros(self.n_enemies)
294 | self.previous_ally_units = None
295 | self.previous_enemy_units = None
296 | self.last_action = np.zeros((self.n_agents, self.n_actions))
297 | self._min_unit_type = 0
298 | self.marine_id = self.marauder_id = self.medivac_id = 0
299 | self.hydralisk_id = self.zergling_id = self.baneling_id = 0
300 | self.stalker_id = self.colossus_id = self.zealot_id = 0
301 | self.max_distance_x = 0
302 | self.max_distance_y = 0
303 | self.map_x = 0
304 | self.map_y = 0
305 | self.reward = 0
306 | self.renderer = None
307 | self.terrain_height = None
308 | self.pathing_grid = None
309 | self._run_config = None
310 | self._sc2_proc = None
311 | self._controller = None
312 |
313 | # Try to avoid leaking SC2 processes on shutdown
314 | atexit.register(lambda: self.close())
315 |
316 | def _launch(self):
317 | """Launch the StarCraft II game."""
318 | self._run_config = run_configs.get(version=self.game_version)
319 | _map = maps.get(self.map_name)
320 |
321 | # Setting up the interface
322 | interface_options = sc_pb.InterfaceOptions(raw=True, score=False)
323 | self._sc2_proc = self._run_config.start(
324 | window_size=self.window_size, want_rgb=False
325 | )
326 | self._controller = self._sc2_proc.controller
327 |
328 | # Request to create the game
329 | create = sc_pb.RequestCreateGame(
330 | local_map=sc_pb.LocalMap(
331 | map_path=_map.path,
332 | map_data=self._run_config.map_data(_map.path),
333 | ),
334 | realtime=False,
335 | random_seed=self._seed,
336 | )
337 | create.player_setup.add(type=sc_pb.Participant)
338 | create.player_setup.add(
339 | type=sc_pb.Computer,
340 | race=races[self._bot_race],
341 | difficulty=difficulties[self.difficulty],
342 | )
343 | self._controller.create_game(create)
344 |
345 | join = sc_pb.RequestJoinGame(
346 | race=races[self._agent_race], options=interface_options
347 | )
348 | self._controller.join_game(join)
349 |
350 | game_info = self._controller.game_info()
351 | map_info = game_info.start_raw
352 | map_play_area_min = map_info.playable_area.p0
353 | map_play_area_max = map_info.playable_area.p1
354 | self.max_distance_x = map_play_area_max.x - map_play_area_min.x
355 | self.max_distance_y = map_play_area_max.y - map_play_area_min.y
356 | self.map_x = map_info.map_size.x
357 | self.map_y = map_info.map_size.y
358 |
359 | if map_info.pathing_grid.bits_per_pixel == 1:
360 | vals = np.array(list(map_info.pathing_grid.data)).reshape(
361 | self.map_x, int(self.map_y / 8)
362 | )
363 | self.pathing_grid = np.transpose(
364 | np.array(
365 | [
366 | [(b >> i) & 1 for b in row for i in range(7, -1, -1)]
367 | for row in vals
368 | ],
369 | dtype=bool,
370 | )
371 | )
372 | else:
373 | self.pathing_grid = np.invert(
374 | np.flip(
375 | np.transpose(
376 | np.array(
377 | list(map_info.pathing_grid.data), dtype=bool
378 | ).reshape(self.map_x, self.map_y)
379 | ),
380 | axis=1,
381 | )
382 | )
383 |
384 | self.terrain_height = (
385 | np.flip(
386 | np.transpose(
387 | np.array(list(map_info.terrain_height.data)).reshape(
388 | self.map_x, self.map_y
389 | )
390 | ),
391 | 1,
392 | )
393 | / 255
394 | )
395 |
396 | def reset(self):
397 | """Reset the environment. Required after each full episode.
398 | Returns initial observations and states.
399 | """
400 | self._episode_steps = 0
401 | if self._episode_count == 0:
402 | # Launch StarCraft II
403 | self._launch()
404 | else:
405 | self._restart()
406 |
407 | # Information kept for counting the reward
408 | self.death_tracker_ally = np.zeros(self.n_agents)
409 | self.death_tracker_enemy = np.zeros(self.n_enemies)
410 | self.previous_ally_units = None
411 | self.previous_enemy_units = None
412 | self.win_counted = False
413 | self.defeat_counted = False
414 |
415 | self.last_action = np.zeros((self.n_agents, self.n_actions))
416 |
417 | if self.heuristic_ai:
418 | self.heuristic_targets = [None] * self.n_agents
419 |
420 | try:
421 | self._obs = self._controller.observe()
422 | self.init_units()
423 | except (protocol.ProtocolError, protocol.ConnectionError):
424 | self.full_restart()
425 |
426 | if self.debug:
427 | logging.debug(
428 | "Started Episode {}".format(self._episode_count).center(
429 | 60, "*"
430 | )
431 | )
432 |
433 | return self.get_obs(), self.get_state()
434 |
435 | def _restart(self):
436 | """Restart the environment by killing all units on the map.
437 | There is a trigger in the SC2Map file, which restarts the
438 | episode when there are no units left.
439 | """
440 | try:
441 | self._kill_all_units()
442 | self._controller.step(2)
443 | except (protocol.ProtocolError, protocol.ConnectionError):
444 | self.full_restart()
445 |
446 | def full_restart(self):
447 | """Full restart. Closes the SC2 process and launches a new one."""
448 | self._sc2_proc.close()
449 | self._launch()
450 | self.force_restarts += 1
451 |
452 | def step(self, actions):
453 | """A single environment step. Returns reward, terminated, info."""
454 | actions_int = [int(a) for a in actions]
455 |
456 | self.last_action = np.eye(self.n_actions)[np.array(actions_int)]
457 |
458 | # Collect individual actions
459 | sc_actions = []
460 | if self.debug:
461 | logging.debug("Actions".center(60, "-"))
462 |
463 | for a_id, action in enumerate(actions_int):
464 | if not self.heuristic_ai:
465 | sc_action = self.get_agent_action(a_id, action)
466 | else:
467 | sc_action, action_num = self.get_agent_action_heuristic(
468 | a_id, action
469 | )
470 | actions[a_id] = action_num
471 | if sc_action:
472 | sc_actions.append(sc_action)
473 |
474 | # Send action request
475 | req_actions = sc_pb.RequestAction(actions=sc_actions)
476 | try:
477 | self._controller.actions(req_actions)
478 | # Make step in SC2, i.e. apply actions
479 | self._controller.step(self._step_mul)
480 | # Observe here so that we know if the episode is over.
481 | self._obs = self._controller.observe()
482 | except (protocol.ProtocolError, protocol.ConnectionError):
483 | self.full_restart()
484 | return 0, True, {}
485 |
486 | self._total_steps += 1
487 | self._episode_steps += 1
488 |
489 | # Update units
490 | game_end_code = self.update_units()
491 |
492 | terminated = False
493 | reward = self.reward_battle()
494 | info = {"battle_won": False}
495 |
496 | # count units that are still alive
497 | dead_allies, dead_enemies = 0, 0
498 | for _al_id, al_unit in self.agents.items():
499 | if al_unit.health == 0:
500 | dead_allies += 1
501 | for _e_id, e_unit in self.enemies.items():
502 | if e_unit.health == 0:
503 | dead_enemies += 1
504 |
505 | info["dead_allies"] = dead_allies
506 | info["dead_enemies"] = dead_enemies
507 |
508 | if game_end_code is not None:
509 | # Battle is over
510 | terminated = True
511 | self.battles_game += 1
512 | if game_end_code == 1 and not self.win_counted:
513 | self.battles_won += 1
514 | self.win_counted = True
515 | info["battle_won"] = True
516 | if not self.reward_sparse:
517 | reward += self.reward_win
518 | else:
519 | reward = 1
520 | elif game_end_code == -1 and not self.defeat_counted:
521 | self.defeat_counted = True
522 | if not self.reward_sparse:
523 | reward += self.reward_defeat
524 | else:
525 | reward = -1
526 |
527 | elif self._episode_steps >= self.episode_limit:
528 | # Episode limit reached
529 | terminated = True
530 | if self.continuing_episode:
531 | info["episode_limit"] = True
532 | self.battles_game += 1
533 | self.timeouts += 1
534 |
535 | if self.debug:
536 | logging.debug("Reward = {}".format(reward).center(60, "-"))
537 |
538 | if terminated:
539 | self._episode_count += 1
540 |
541 | if self.reward_scale:
542 | reward /= self.max_reward / self.reward_scale_rate
543 |
544 | self.reward = reward
545 |
546 | return reward, terminated, info
547 |
548 | def get_agent_action(self, a_id, action):
549 | """Construct the action for agent a_id."""
550 | avail_actions = self.get_avail_agent_actions(a_id)
551 | assert (
552 | avail_actions[action] == 1
553 | ), "Agent {} cannot perform action {}".format(a_id, action)
554 |
555 | unit = self.get_unit_by_id(a_id)
556 | tag = unit.tag
557 | x = unit.pos.x
558 | y = unit.pos.y
559 |
560 | if action == 0:
561 | # no-op (valid only when dead)
562 | assert unit.health == 0, "No-op only available for dead agents."
563 | if self.debug:
564 | logging.debug("Agent {}: Dead".format(a_id))
565 | return None
566 | elif action == 1:
567 | # stop
568 | cmd = r_pb.ActionRawUnitCommand(
569 | ability_id=actions["stop"],
570 | unit_tags=[tag],
571 | queue_command=False,
572 | )
573 | if self.debug:
574 | logging.debug("Agent {}: Stop".format(a_id))
575 |
576 | elif action == 2:
577 | # move north
578 | cmd = r_pb.ActionRawUnitCommand(
579 | ability_id=actions["move"],
580 | target_world_space_pos=sc_common.Point2D(
581 | x=x, y=y + self._move_amount
582 | ),
583 | unit_tags=[tag],
584 | queue_command=False,
585 | )
586 | if self.debug:
587 | logging.debug("Agent {}: Move North".format(a_id))
588 |
589 | elif action == 3:
590 | # move south
591 | cmd = r_pb.ActionRawUnitCommand(
592 | ability_id=actions["move"],
593 | target_world_space_pos=sc_common.Point2D(
594 | x=x, y=y - self._move_amount
595 | ),
596 | unit_tags=[tag],
597 | queue_command=False,
598 | )
599 | if self.debug:
600 | logging.debug("Agent {}: Move South".format(a_id))
601 |
602 | elif action == 4:
603 | # move east
604 | cmd = r_pb.ActionRawUnitCommand(
605 | ability_id=actions["move"],
606 | target_world_space_pos=sc_common.Point2D(
607 | x=x + self._move_amount, y=y
608 | ),
609 | unit_tags=[tag],
610 | queue_command=False,
611 | )
612 | if self.debug:
613 | logging.debug("Agent {}: Move East".format(a_id))
614 |
615 | elif action == 5:
616 | # move west
617 | cmd = r_pb.ActionRawUnitCommand(
618 | ability_id=actions["move"],
619 | target_world_space_pos=sc_common.Point2D(
620 | x=x - self._move_amount, y=y
621 | ),
622 | unit_tags=[tag],
623 | queue_command=False,
624 | )
625 | if self.debug:
626 | logging.debug("Agent {}: Move West".format(a_id))
627 | else:
628 | # attack/heal units that are in range
629 | target_id = action - self.n_actions_no_attack
630 | if self.map_type == "MMM" and unit.unit_type == self.medivac_id:
631 | target_unit = self.agents[target_id]
632 | action_name = "heal"
633 | else:
634 | target_unit = self.enemies[target_id]
635 | action_name = "attack"
636 |
637 | action_id = actions[action_name]
638 | target_tag = target_unit.tag
639 |
640 | cmd = r_pb.ActionRawUnitCommand(
641 | ability_id=action_id,
642 | target_unit_tag=target_tag,
643 | unit_tags=[tag],
644 | queue_command=False,
645 | )
646 |
647 | if self.debug:
648 | logging.debug(
649 | "Agent {} {}s unit # {}".format(
650 | a_id, action_name, target_id
651 | )
652 | )
653 |
654 | sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
655 | return sc_action
656 |
657 | def get_agent_action_heuristic(self, a_id, action):
658 | unit = self.get_unit_by_id(a_id)
659 | tag = unit.tag
660 |
661 | target = self.heuristic_targets[a_id]
662 | if unit.unit_type == self.medivac_id:
663 | if (
664 | target is None
665 | or self.agents[target].health == 0
666 | or self.agents[target].health == self.agents[target].health_max
667 | ):
668 | min_dist = math.hypot(self.max_distance_x, self.max_distance_y)
669 | min_id = -1
670 | for al_id, al_unit in self.agents.items():
671 | if al_unit.unit_type == self.medivac_id:
672 | continue
673 | if (
674 | al_unit.health != 0
675 | and al_unit.health != al_unit.health_max
676 | ):
677 | dist = self.distance(
678 | unit.pos.x,
679 | unit.pos.y,
680 | al_unit.pos.x,
681 | al_unit.pos.y,
682 | )
683 | if dist < min_dist:
684 | min_dist = dist
685 | min_id = al_id
686 | self.heuristic_targets[a_id] = min_id
687 | if min_id == -1:
688 | self.heuristic_targets[a_id] = None
689 | return None, 0
690 | action_id = actions["heal"]
691 | target_tag = self.agents[self.heuristic_targets[a_id]].tag
692 | else:
693 | if target is None or self.enemies[target].health == 0:
694 | min_dist = math.hypot(self.max_distance_x, self.max_distance_y)
695 | min_id = -1
696 | for e_id, e_unit in self.enemies.items():
697 | if (
698 | unit.unit_type == self.marauder_id
699 | and e_unit.unit_type == self.medivac_id
700 | ):
701 | continue
702 | if e_unit.health > 0:
703 | dist = self.distance(
704 | unit.pos.x, unit.pos.y, e_unit.pos.x, e_unit.pos.y
705 | )
706 | if dist < min_dist:
707 | min_dist = dist
708 | min_id = e_id
709 | self.heuristic_targets[a_id] = min_id
710 | if min_id == -1:
711 | self.heuristic_targets[a_id] = None
712 | return None, 0
713 | action_id = actions["attack"]
714 | target_tag = self.enemies[self.heuristic_targets[a_id]].tag
715 |
716 | action_num = self.heuristic_targets[a_id] + self.n_actions_no_attack
717 |
718 | # Check if the action is available
719 | if (
720 | self.heuristic_rest
721 | and self.get_avail_agent_actions(a_id)[action_num] == 0
722 | ):
723 |
724 | # Move towards the target rather than attacking/healing
725 | if unit.unit_type == self.medivac_id:
726 | target_unit = self.agents[self.heuristic_targets[a_id]]
727 | else:
728 | target_unit = self.enemies[self.heuristic_targets[a_id]]
729 |
730 | delta_x = target_unit.pos.x - unit.pos.x
731 | delta_y = target_unit.pos.y - unit.pos.y
732 |
733 | if abs(delta_x) > abs(delta_y): # east or west
734 | if delta_x > 0: # east
735 | target_pos = sc_common.Point2D(
736 | x=unit.pos.x + self._move_amount, y=unit.pos.y
737 | )
738 | action_num = 4
739 | else: # west
740 | target_pos = sc_common.Point2D(
741 | x=unit.pos.x - self._move_amount, y=unit.pos.y
742 | )
743 | action_num = 5
744 | else: # north or south
745 | if delta_y > 0: # north
746 | target_pos = sc_common.Point2D(
747 | x=unit.pos.x, y=unit.pos.y + self._move_amount
748 | )
749 | action_num = 2
750 | else: # south
751 | target_pos = sc_common.Point2D(
752 | x=unit.pos.x, y=unit.pos.y - self._move_amount
753 | )
754 | action_num = 3
755 |
756 | cmd = r_pb.ActionRawUnitCommand(
757 | ability_id=actions["move"],
758 | target_world_space_pos=target_pos,
759 | unit_tags=[tag],
760 | queue_command=False,
761 | )
762 | else:
763 | # Attack/heal the target
764 | cmd = r_pb.ActionRawUnitCommand(
765 | ability_id=action_id,
766 | target_unit_tag=target_tag,
767 | unit_tags=[tag],
768 | queue_command=False,
769 | )
770 |
771 | sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
772 | return sc_action, action_num
773 |
774 | def reward_battle(self):
775 |         """Reward function when self.reward_sparse == False.
776 |         Returns cumulative hit/shield point damage dealt to the enemy
777 | + reward_death_value per enemy unit killed, and, in case
778 | self.reward_only_positive == False, - (damage dealt to ally units
779 | + reward_death_value per ally unit killed) * self.reward_negative_scale
780 | """
781 | if self.reward_sparse:
782 | return 0
783 |
784 | reward = 0
785 | delta_deaths = 0
786 | delta_ally = 0
787 | delta_enemy = 0
788 |
789 | neg_scale = self.reward_negative_scale
790 |
791 | # update deaths
792 | for al_id, al_unit in self.agents.items():
793 | if not self.death_tracker_ally[al_id]:
794 | # did not die so far
795 | prev_health = (
796 | self.previous_ally_units[al_id].health
797 | + self.previous_ally_units[al_id].shield
798 | )
799 | if al_unit.health == 0:
800 | # just died
801 | self.death_tracker_ally[al_id] = 1
802 | if not self.reward_only_positive:
803 | delta_deaths -= self.reward_death_value * neg_scale
804 | delta_ally += prev_health * neg_scale
805 | else:
806 | # still alive
807 | delta_ally += neg_scale * (
808 | prev_health - al_unit.health - al_unit.shield
809 | )
810 |
811 | for e_id, e_unit in self.enemies.items():
812 | if not self.death_tracker_enemy[e_id]:
813 | prev_health = (
814 | self.previous_enemy_units[e_id].health
815 | + self.previous_enemy_units[e_id].shield
816 | )
817 | if e_unit.health == 0:
818 | self.death_tracker_enemy[e_id] = 1
819 | delta_deaths += self.reward_death_value
820 | delta_enemy += prev_health
821 | else:
822 | delta_enemy += prev_health - e_unit.health - e_unit.shield
823 |
824 | if self.reward_only_positive:
825 | reward = abs(delta_enemy + delta_deaths) # shield regeneration
826 | else:
827 | reward = delta_enemy + delta_deaths - delta_ally
828 |
829 | return reward
830 |
831 | def get_total_actions(self):
832 | """Returns the total number of actions an agent could ever take."""
833 | return self.n_actions
834 |
835 | @staticmethod
836 | def distance(x1, y1, x2, y2):
837 | """Distance between two points."""
838 | return math.hypot(x2 - x1, y2 - y1)
839 |
840 | def unit_shoot_range(self, agent_id):
841 | """Returns the shooting range for an agent."""
842 | return 6
843 |
844 | def unit_sight_range(self, agent_id):
845 | """Returns the sight range for an agent."""
846 | return 9
847 |
848 | def unit_max_cooldown(self, unit):
849 | """Returns the maximal cooldown for a unit."""
850 | switcher = {
851 | self.marine_id: 15,
852 | self.marauder_id: 25,
853 | self.medivac_id: 200, # max energy
854 | self.stalker_id: 35,
855 | self.zealot_id: 22,
856 | self.colossus_id: 24,
857 | self.hydralisk_id: 10,
858 | self.zergling_id: 11,
859 | self.baneling_id: 1,
860 | }
861 | return switcher.get(unit.unit_type, 15)
862 |
863 | def save_replay(self):
864 | """Save a replay."""
865 | prefix = self.replay_prefix or self.map_name
866 | replay_dir = self.replay_dir or ""
867 | replay_path = self._run_config.save_replay(
868 | self._controller.save_replay(),
869 | replay_dir=replay_dir,
870 | prefix=prefix,
871 | )
872 | logging.info("Replay saved at: %s" % replay_path)
873 |
874 | def unit_max_shield(self, unit):
875 | """Returns maximal shield for a given unit."""
876 | if unit.unit_type == 74 or unit.unit_type == self.stalker_id:
877 | return 80 # Protoss's Stalker
878 | if unit.unit_type == 73 or unit.unit_type == self.zealot_id:
879 |             return 50 # Protoss's Zealot
880 | if unit.unit_type == 4 or unit.unit_type == self.colossus_id:
881 | return 150 # Protoss's Colossus
882 |
883 | def can_move(self, unit, direction):
884 | """Whether a unit can move in a given direction."""
885 | m = self._move_amount / 2
886 |
887 | if direction == Direction.NORTH:
888 | x, y = int(unit.pos.x), int(unit.pos.y + m)
889 | elif direction == Direction.SOUTH:
890 | x, y = int(unit.pos.x), int(unit.pos.y - m)
891 | elif direction == Direction.EAST:
892 | x, y = int(unit.pos.x + m), int(unit.pos.y)
893 | else:
894 | x, y = int(unit.pos.x - m), int(unit.pos.y)
895 |
896 | if self.check_bounds(x, y) and self.pathing_grid[x, y]:
897 | return True
898 |
899 | return False
900 |
901 | def get_surrounding_points(self, unit, include_self=False):
902 | """Returns the surrounding points of the unit in 8 directions."""
903 | x = int(unit.pos.x)
904 | y = int(unit.pos.y)
905 |
906 | ma = self._move_amount
907 |
908 | points = [
909 | (x, y + 2 * ma),
910 | (x, y - 2 * ma),
911 | (x + 2 * ma, y),
912 | (x - 2 * ma, y),
913 | (x + ma, y + ma),
914 | (x - ma, y - ma),
915 | (x + ma, y - ma),
916 | (x - ma, y + ma),
917 | ]
918 |
919 | if include_self:
920 | points.append((x, y))
921 |
922 | return points
923 |
924 | def check_bounds(self, x, y):
925 | """Whether a point is within the map bounds."""
926 | return 0 <= x < self.map_x and 0 <= y < self.map_y
927 |
928 | def get_surrounding_pathing(self, unit):
929 | """Returns pathing values of the grid surrounding the given unit."""
930 | points = self.get_surrounding_points(unit, include_self=False)
931 | vals = [
932 | self.pathing_grid[x, y] if self.check_bounds(x, y) else 1
933 | for x, y in points
934 | ]
935 | return vals
936 |
937 | def get_surrounding_height(self, unit):
938 | """Returns height values of the grid surrounding the given unit."""
939 | points = self.get_surrounding_points(unit, include_self=True)
940 | vals = [
941 | self.terrain_height[x, y] if self.check_bounds(x, y) else 1
942 | for x, y in points
943 | ]
944 | return vals
945 |
946 | def get_obs_agent(self, agent_id):
947 | """Returns observation for agent_id. The observation is composed of:
948 |
949 | - agent movement features (where it can move to, height information
950 | and pathing grid)
951 | - enemy features (available_to_attack, health, relative_x, relative_y,
952 | shield, unit_type)
953 | - ally features (visible, distance, relative_x, relative_y, shield,
954 | unit_type)
955 | - agent unit features (health, shield, unit_type)
956 |
957 | All of this information is flattened and concatenated into a list,
958 | in the aforementioned order. To know the sizes of each of the
959 | features inside the final list of features, take a look at the
960 | functions ``get_obs_move_feats_size()``,
961 | ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and
962 | ``get_obs_own_feats_size()``.
963 |
964 | The size of the observation vector may vary, depending on the
965 | environment configuration and type of units present in the map.
966 | For instance, non-Protoss units will not have shields, movement
967 | features may or may not include terrain height and pathing grid,
968 | unit_type is not included if there is only one type of unit in the
969 |         map, etc.
970 |
971 | NOTE: Agents should have access only to their local observations
972 | during decentralised execution.
973 | """
974 | unit = self.get_unit_by_id(agent_id)
975 |
976 | move_feats_dim = self.get_obs_move_feats_size()
977 | enemy_feats_dim = self.get_obs_enemy_feats_size()
978 | ally_feats_dim = self.get_obs_ally_feats_size()
979 | own_feats_dim = self.get_obs_own_feats_size()
980 |
981 | move_feats = np.zeros(move_feats_dim, dtype=np.float32)
982 | enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
983 | ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
984 | own_feats = np.zeros(own_feats_dim, dtype=np.float32)
985 |
986 | if unit.health > 0: # otherwise dead, return all zeros
987 | x = unit.pos.x
988 | y = unit.pos.y
989 | sight_range = self.unit_sight_range(agent_id)
990 |
991 | # Movement features
992 | avail_actions = self.get_avail_agent_actions(agent_id)
993 | for m in range(self.n_actions_move):
994 | move_feats[m] = avail_actions[m + 2]
995 |
996 | ind = self.n_actions_move
997 |
998 | if self.obs_pathing_grid:
999 | move_feats[
1000 | ind : ind + self.n_obs_pathing # noqa
1001 | ] = self.get_surrounding_pathing(unit)
1002 | ind += self.n_obs_pathing
1003 |
1004 | if self.obs_terrain_height:
1005 | move_feats[ind:] = self.get_surrounding_height(unit)
1006 |
1007 | # Enemy features
1008 | for e_id, e_unit in self.enemies.items():
1009 | e_x = e_unit.pos.x
1010 | e_y = e_unit.pos.y
1011 | dist = self.distance(x, y, e_x, e_y)
1012 |
1013 | if (
1014 | dist < sight_range and e_unit.health > 0
1015 | ): # visible and alive
1016 | # Sight range > shoot range
1017 | enemy_feats[e_id, 0] = avail_actions[
1018 | self.n_actions_no_attack + e_id
1019 | ] # available
1020 | enemy_feats[e_id, 1] = dist / sight_range # distance
1021 | enemy_feats[e_id, 2] = (
1022 | e_x - x
1023 | ) / sight_range # relative X
1024 | enemy_feats[e_id, 3] = (
1025 | e_y - y
1026 | ) / sight_range # relative Y
1027 |
1028 | ind = 4
1029 | if self.obs_all_health:
1030 | enemy_feats[e_id, ind] = (
1031 | e_unit.health / e_unit.health_max
1032 | ) # health
1033 | ind += 1
1034 | if self.shield_bits_enemy > 0:
1035 | max_shield = self.unit_max_shield(e_unit)
1036 | enemy_feats[e_id, ind] = (
1037 | e_unit.shield / max_shield
1038 | ) # shield
1039 | ind += 1
1040 |
1041 | if self.unit_type_bits > 0:
1042 | type_id = self.get_unit_type_id(e_unit, False)
1043 | enemy_feats[e_id, ind + type_id] = 1 # unit type
1044 |
1045 | # Ally features
1046 | al_ids = [
1047 | al_id for al_id in range(self.n_agents) if al_id != agent_id
1048 | ]
1049 | for i, al_id in enumerate(al_ids):
1050 |
1051 | al_unit = self.get_unit_by_id(al_id)
1052 | al_x = al_unit.pos.x
1053 | al_y = al_unit.pos.y
1054 | dist = self.distance(x, y, al_x, al_y)
1055 |
1056 | if (
1057 | dist < sight_range and al_unit.health > 0
1058 | ): # visible and alive
1059 | ally_feats[i, 0] = 1 # visible
1060 | ally_feats[i, 1] = dist / sight_range # distance
1061 | ally_feats[i, 2] = (al_x - x) / sight_range # relative X
1062 | ally_feats[i, 3] = (al_y - y) / sight_range # relative Y
1063 |
1064 | ind = 4
1065 | if self.obs_all_health:
1066 | ally_feats[i, ind] = (
1067 | al_unit.health / al_unit.health_max
1068 | ) # health
1069 | ind += 1
1070 | if self.shield_bits_ally > 0:
1071 | max_shield = self.unit_max_shield(al_unit)
1072 | ally_feats[i, ind] = (
1073 | al_unit.shield / max_shield
1074 | ) # shield
1075 | ind += 1
1076 |
1077 | if self.unit_type_bits > 0:
1078 | type_id = self.get_unit_type_id(al_unit, True)
1079 | ally_feats[i, ind + type_id] = 1
1080 | ind += self.unit_type_bits
1081 |
1082 | if self.obs_last_action:
1083 | ally_feats[i, ind:] = self.last_action[al_id]
1084 |
1085 | # Own features
1086 | ind = 0
1087 | if self.obs_own_health:
1088 | own_feats[ind] = unit.health / unit.health_max
1089 | ind += 1
1090 | if self.shield_bits_ally > 0:
1091 | max_shield = self.unit_max_shield(unit)
1092 | own_feats[ind] = unit.shield / max_shield
1093 | ind += 1
1094 |
1095 | if self.unit_type_bits > 0:
1096 | type_id = self.get_unit_type_id(unit, True)
1097 | own_feats[ind + type_id] = 1
1098 |
1099 | agent_obs = np.concatenate(
1100 | (
1101 | move_feats.flatten(),
1102 | enemy_feats.flatten(),
1103 | ally_feats.flatten(),
1104 | own_feats.flatten(),
1105 | )
1106 | )
1107 |
1108 | if self.obs_timestep_number:
1109 | agent_obs = np.append(
1110 | agent_obs, self._episode_steps / self.episode_limit
1111 | )
1112 |
1113 | if self.debug:
1114 | logging.debug("Obs Agent: {}".format(agent_id).center(60, "-"))
1115 | logging.debug(
1116 | "Avail. actions {}".format(
1117 | self.get_avail_agent_actions(agent_id)
1118 | )
1119 | )
1120 | logging.debug("Move feats {}".format(move_feats))
1121 | logging.debug("Enemy feats {}".format(enemy_feats))
1122 | logging.debug("Ally feats {}".format(ally_feats))
1123 | logging.debug("Own feats {}".format(own_feats))
1124 |
1125 | return agent_obs
1126 |
1127 | def get_obs(self):
1128 | """Returns all agent observations in a list.
1129 | NOTE: Agents should have access only to their local observations
1130 | during decentralised execution.
1131 | """
1132 | agents_obs = [self.get_obs_agent(i) for i in range(self.n_agents)]
1133 | return agents_obs
1134 |
1135 | def get_state(self):
1136 | """Returns the global state.
1137 |         NOTE: This function should not be used during decentralised execution.
1138 | """
1139 | if self.obs_instead_of_state:
1140 | obs_concat = np.concatenate(self.get_obs(), axis=0).astype(
1141 | np.float32
1142 | )
1143 | return obs_concat
1144 |
1145 | state_dict = self.get_state_dict()
1146 |
1147 | state = np.append(
1148 | state_dict["allies"].flatten(), state_dict["enemies"].flatten()
1149 | )
1150 | if "last_action" in state_dict:
1151 | state = np.append(state, state_dict["last_action"].flatten())
1152 | if "timestep" in state_dict:
1153 | state = np.append(state, state_dict["timestep"])
1154 |
1155 | state = state.astype(dtype=np.float32)
1156 |
1157 | if self.debug:
1158 | logging.debug("STATE".center(60, "-"))
1159 | logging.debug("Ally state {}".format(state_dict["allies"]))
1160 | logging.debug("Enemy state {}".format(state_dict["enemies"]))
1161 | if self.state_last_action:
1162 | logging.debug("Last actions {}".format(self.last_action))
1163 |
1164 | return state
1165 |
1166 | def get_ally_num_attributes(self):
1167 | return len(self.ally_state_attr_names)
1168 |
1169 | def get_enemy_num_attributes(self):
1170 | return len(self.enemy_state_attr_names)
1171 |
1172 | def get_state_dict(self):
1173 | """Returns the global state as a dictionary.
1174 |
1175 | - allies: numpy array containing agents and their attributes
1176 | - enemies: numpy array containing enemies and their attributes
1177 | - last_action: numpy array of previous actions for each agent
1178 | - timestep: current no. of steps divided by total no. of steps
1179 |
1180 | NOTE: This function should not be used during decentralised execution.
1181 | """
1182 |
1183 | # number of features equals the number of attribute names
1184 | nf_al = self.get_ally_num_attributes()
1185 | nf_en = self.get_enemy_num_attributes()
1186 |
1187 | ally_state = np.zeros((self.n_agents, nf_al))
1188 | enemy_state = np.zeros((self.n_enemies, nf_en))
1189 |
1190 | center_x = self.map_x / 2
1191 | center_y = self.map_y / 2
1192 |
1193 | for al_id, al_unit in self.agents.items():
1194 | if al_unit.health > 0:
1195 | x = al_unit.pos.x
1196 | y = al_unit.pos.y
1197 | max_cd = self.unit_max_cooldown(al_unit)
1198 |
1199 | ally_state[al_id, 0] = (
1200 | al_unit.health / al_unit.health_max
1201 | ) # health
1202 | if (
1203 | self.map_type == "MMM"
1204 | and al_unit.unit_type == self.medivac_id
1205 | ):
1206 | ally_state[al_id, 1] = al_unit.energy / max_cd # energy
1207 | else:
1208 | ally_state[al_id, 1] = (
1209 | al_unit.weapon_cooldown / max_cd
1210 | ) # cooldown
1211 | ally_state[al_id, 2] = (
1212 | x - center_x
1213 | ) / self.max_distance_x # relative X
1214 | ally_state[al_id, 3] = (
1215 | y - center_y
1216 | ) / self.max_distance_y # relative Y
1217 |
1218 | if self.shield_bits_ally > 0:
1219 | max_shield = self.unit_max_shield(al_unit)
1220 | ally_state[al_id, 4] = (
1221 | al_unit.shield / max_shield
1222 | ) # shield
1223 |
1224 | if self.unit_type_bits > 0:
1225 | type_id = self.get_unit_type_id(al_unit, True)
1226 | ally_state[al_id, type_id - self.unit_type_bits] = 1
1227 |
1228 | for e_id, e_unit in self.enemies.items():
1229 | if e_unit.health > 0:
1230 | x = e_unit.pos.x
1231 | y = e_unit.pos.y
1232 |
1233 | enemy_state[e_id, 0] = (
1234 | e_unit.health / e_unit.health_max
1235 | ) # health
1236 | enemy_state[e_id, 1] = (
1237 | x - center_x
1238 | ) / self.max_distance_x # relative X
1239 | enemy_state[e_id, 2] = (
1240 | y - center_y
1241 | ) / self.max_distance_y # relative Y
1242 |
1243 | if self.shield_bits_enemy > 0:
1244 | max_shield = self.unit_max_shield(e_unit)
1245 | enemy_state[e_id, 3] = e_unit.shield / max_shield # shield
1246 |
1247 | if self.unit_type_bits > 0:
1248 | type_id = self.get_unit_type_id(e_unit, False)
1249 | enemy_state[e_id, type_id - self.unit_type_bits] = 1
1250 |
1251 | state = {"allies": ally_state, "enemies": enemy_state}
1252 |
1253 | if self.state_last_action:
1254 | state["last_action"] = self.last_action
1255 | if self.state_timestep_number:
1256 | state["timestep"] = self._episode_steps / self.episode_limit
1257 |
1258 | return state
1259 |
1260 | def get_obs_enemy_feats_size(self):
1261 | """Returns the dimensions of the matrix containing enemy features.
1262 | Size is n_enemies x n_features.
1263 | """
1264 | nf_en = 4 + self.unit_type_bits
1265 |
1266 | if self.obs_all_health:
1267 | nf_en += 1 + self.shield_bits_enemy
1268 |
1269 | return self.n_enemies, nf_en
1270 |
1271 | def get_obs_ally_feats_size(self):
1272 | """Returns the dimensions of the matrix containing ally features.
1273 | Size is n_allies x n_features.
1274 | """
1275 | nf_al = 4 + self.unit_type_bits
1276 |
1277 | if self.obs_all_health:
1278 | nf_al += 1 + self.shield_bits_ally
1279 |
1280 | if self.obs_last_action:
1281 | nf_al += self.n_actions
1282 |
1283 | return self.n_agents - 1, nf_al
1284 |
1285 | def get_obs_own_feats_size(self):
1286 | """
1287 | Returns the size of the vector containing the agents' own features.
1288 | """
1289 | own_feats = self.unit_type_bits
1290 | if self.obs_own_health:
1291 | own_feats += 1 + self.shield_bits_ally
1292 | if self.obs_timestep_number:
1293 | own_feats += 1
1294 |
1295 | return own_feats
1296 |
1297 | def get_obs_move_feats_size(self):
1298 |         """Returns the size of the vector containing the agents' movement-
1299 | related features.
1300 | """
1301 | move_feats = self.n_actions_move
1302 | if self.obs_pathing_grid:
1303 | move_feats += self.n_obs_pathing
1304 | if self.obs_terrain_height:
1305 | move_feats += self.n_obs_height
1306 |
1307 | return move_feats
1308 |
1309 | def get_obs_size(self):
1310 | """Returns the size of the observation."""
1311 | own_feats = self.get_obs_own_feats_size()
1312 | move_feats = self.get_obs_move_feats_size()
1313 |
1314 | n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()
1315 | n_allies, n_ally_feats = self.get_obs_ally_feats_size()
1316 |
1317 | enemy_feats = n_enemies * n_enemy_feats
1318 | ally_feats = n_allies * n_ally_feats
1319 |
1320 | return move_feats + enemy_feats + ally_feats + own_feats
1321 |
1322 | def get_state_size(self):
1323 | """Returns the size of the global state."""
1324 | if self.obs_instead_of_state:
1325 | return self.get_obs_size() * self.n_agents
1326 |
1327 | nf_al = 4 + self.shield_bits_ally + self.unit_type_bits
1328 | nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits
1329 |
1330 | enemy_state = self.n_enemies * nf_en
1331 | ally_state = self.n_agents * nf_al
1332 |
1333 | size = enemy_state + ally_state
1334 |
1335 | if self.state_last_action:
1336 | size += self.n_agents * self.n_actions
1337 | if self.state_timestep_number:
1338 | size += 1
1339 |
1340 | return size
1341 |
1342 | def get_visibility_matrix(self):
1343 | """Returns a boolean numpy array of dimensions
1344 | (n_agents, n_agents + n_enemies) indicating which units
1345 | are visible to each agent.
1346 | """
1347 | arr = np.zeros(
1348 | (self.n_agents, self.n_agents + self.n_enemies),
1349 | dtype=bool,
1350 | )
1351 |
1352 | for agent_id in range(self.n_agents):
1353 | current_agent = self.get_unit_by_id(agent_id)
1354 |             if current_agent.health > 0:  # if the agent is not dead
1355 | x = current_agent.pos.x
1356 | y = current_agent.pos.y
1357 | sight_range = self.unit_sight_range(agent_id)
1358 |
1359 | # Enemies
1360 | for e_id, e_unit in self.enemies.items():
1361 | e_x = e_unit.pos.x
1362 | e_y = e_unit.pos.y
1363 | dist = self.distance(x, y, e_x, e_y)
1364 |
1365 | if dist < sight_range and e_unit.health > 0:
1366 | # visible and alive
1367 | arr[agent_id, self.n_agents + e_id] = 1
1368 |
1369 | # The matrix for allies is filled symmetrically
1370 | al_ids = [
1371 | al_id for al_id in range(self.n_agents) if al_id > agent_id
1372 | ]
1373 | for _, al_id in enumerate(al_ids):
1374 | al_unit = self.get_unit_by_id(al_id)
1375 | al_x = al_unit.pos.x
1376 | al_y = al_unit.pos.y
1377 | dist = self.distance(x, y, al_x, al_y)
1378 |
1379 | if dist < sight_range and al_unit.health > 0:
1380 | # visible and alive
1381 | arr[agent_id, al_id] = arr[al_id, agent_id] = 1
1382 |
1383 | return arr
1384 |
1385 | def get_unit_type_id(self, unit, ally):
1386 |         """Returns the ID of the unit type in the given scenario."""
1387 | if ally: # use new SC2 unit types
1388 | type_id = unit.unit_type - self._min_unit_type
1389 | else: # use default SC2 unit types
1390 | if self.map_type == "stalkers_and_zealots":
1391 | # id(Stalker) = 74, id(Zealot) = 73
1392 |                 # Note: the one-hot Zealot unit type is [1, 0] in enemy_obs but [0, 1] in ally_obs.
1393 |                 # To align the enemy unit types with the allies', uncomment the following lines:
1394 | # if unit.unit_type == 74:
1395 | # type_id = 0
1396 | # else:
1397 | # type_id = 1
1398 | type_id = unit.unit_type - 73
1399 | elif self.map_type == "colossi_stalkers_zealots":
1400 | # id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4
1401 | if unit.unit_type == 4:
1402 | type_id = 0
1403 | elif unit.unit_type == 74:
1404 | type_id = 1
1405 | else:
1406 | type_id = 2
1407 | elif self.map_type == "bane":
1408 | # id(Baneling) = 9
1409 | if unit.unit_type == 9:
1410 | type_id = 0
1411 | else:
1412 | type_id = 1
1413 | elif self.map_type == "MMM":
1414 | # id(Marauder) = 51, id(Marine) = 48, id(Medivac) = 54
1415 | if unit.unit_type == 51:
1416 | type_id = 0
1417 | elif unit.unit_type == 48:
1418 | type_id = 1
1419 | else:
1420 | type_id = 2
1421 |
1422 | return type_id
1423 |
1424 | def get_avail_agent_actions(self, agent_id):
1425 | """Returns the available actions for agent_id."""
1426 | unit = self.get_unit_by_id(agent_id)
1427 | if unit.health > 0:
1428 | # cannot choose no-op when alive
1429 | avail_actions = [0] * self.n_actions
1430 |
1431 | # stop should be allowed
1432 | avail_actions[1] = 1
1433 |
1434 | # see if we can move
1435 | if self.can_move(unit, Direction.NORTH):
1436 | avail_actions[2] = 1
1437 | if self.can_move(unit, Direction.SOUTH):
1438 | avail_actions[3] = 1
1439 | if self.can_move(unit, Direction.EAST):
1440 | avail_actions[4] = 1
1441 | if self.can_move(unit, Direction.WEST):
1442 | avail_actions[5] = 1
1443 |
1444 |             # Can attack only units that are alive and within shooting range
1445 | shoot_range = self.unit_shoot_range(agent_id)
1446 |
1447 | target_items = self.enemies.items()
1448 | if self.map_type == "MMM" and unit.unit_type == self.medivac_id:
1449 | # Medivacs cannot heal themselves or other flying units
1450 | target_items = [
1451 | (t_id, t_unit)
1452 | for (t_id, t_unit) in self.agents.items()
1453 | if t_unit.unit_type != self.medivac_id
1454 | ]
1455 |
1456 | for t_id, t_unit in target_items:
1457 | if t_unit.health > 0:
1458 | dist = self.distance(
1459 | unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y
1460 | )
1461 | if dist <= shoot_range:
1462 | avail_actions[t_id + self.n_actions_no_attack] = 1
1463 |
1464 | return avail_actions
1465 |
1466 | else:
1467 | # only no-op allowed
1468 | return [1] + [0] * (self.n_actions - 1)
1469 |
1470 | def get_avail_actions(self):
1471 | """Returns the available actions of all agents in a list."""
1472 | avail_actions = []
1473 | for agent_id in range(self.n_agents):
1474 | avail_agent = self.get_avail_agent_actions(agent_id)
1475 | avail_actions.append(avail_agent)
1476 | return avail_actions
1477 |
1478 | def close(self):
1479 | """Close StarCraft II."""
1480 | if self.renderer is not None:
1481 | self.renderer.close()
1482 | self.renderer = None
1483 | if self._sc2_proc:
1484 | self._sc2_proc.close()
1485 |
1486 | def seed(self):
1487 | """Returns the random seed used by the environment."""
1488 | return self._seed
1489 |
1490 | def render(self, mode="human"):
1491 | if self.renderer is None:
1492 | from smac.env.starcraft2.render import StarCraft2Renderer
1493 |
1494 | self.renderer = StarCraft2Renderer(self, mode)
1495 | assert (
1496 | mode == self.renderer.mode
1497 | ), "mode must be consistent across render calls"
1498 | return self.renderer.render(mode)
1499 |
1500 | def _kill_all_units(self):
1501 | """Kill all units on the map."""
1502 | units_alive = [
1503 | unit.tag for unit in self.agents.values() if unit.health > 0
1504 | ] + [unit.tag for unit in self.enemies.values() if unit.health > 0]
1505 | debug_command = [
1506 | d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=units_alive))
1507 | ]
1508 | self._controller.debug(debug_command)
1509 |
1510 | def init_units(self):
1511 | """Initialise the units."""
1512 | while True:
1513 | # Sometimes not all units have yet been created by SC2
1514 | self.agents = {}
1515 | self.enemies = {}
1516 |
1517 | ally_units = [
1518 | unit
1519 | for unit in self._obs.observation.raw_data.units
1520 | if unit.owner == 1
1521 | ]
1522 | ally_units_sorted = sorted(
1523 | ally_units,
1524 | key=attrgetter("unit_type", "pos.x", "pos.y"),
1525 | reverse=False,
1526 | )
1527 |
1528 | for i in range(len(ally_units_sorted)):
1529 | self.agents[i] = ally_units_sorted[i]
1530 | if self.debug:
1531 | logging.debug(
1532 | "Unit {} is {}, x = {}, y = {}".format(
1533 | len(self.agents),
1534 | self.agents[i].unit_type,
1535 | self.agents[i].pos.x,
1536 | self.agents[i].pos.y,
1537 | )
1538 | )
1539 |
1540 | for unit in self._obs.observation.raw_data.units:
1541 | if unit.owner == 2:
1542 | self.enemies[len(self.enemies)] = unit
1543 | if self._episode_count == 0:
1544 | self.max_reward += unit.health_max + unit.shield_max
1545 |
1546 | if self._episode_count == 0:
1547 | min_unit_type = min(
1548 | unit.unit_type for unit in self.agents.values()
1549 | )
1550 | self._init_ally_unit_types(min_unit_type)
1551 |
1552 | all_agents_created = len(self.agents) == self.n_agents
1553 | all_enemies_created = len(self.enemies) == self.n_enemies
1554 |
1555 | self._unit_types = [
1556 | unit.unit_type for unit in ally_units_sorted
1557 | ] + [
1558 | unit.unit_type
1559 | for unit in self._obs.observation.raw_data.units
1560 | if unit.owner == 2
1561 | ]
1562 |
1563 | if all_agents_created and all_enemies_created: # all good
1564 | return
1565 |
1566 | try:
1567 | self._controller.step(1)
1568 | self._obs = self._controller.observe()
1569 | except (protocol.ProtocolError, protocol.ConnectionError):
1570 | self.full_restart()
1571 | self.reset()
1572 |
1573 | def get_unit_types(self):
1574 | if self._unit_types is None:
1575 | warn(
1576 |                 "unit types have not been initialized yet, please call "
1577 |                 "env.reset() to populate this and call the method again."
1578 | )
1579 |
1580 | return self._unit_types
1581 |
1582 | def update_units(self):
1583 | """Update units after an environment step.
1584 | This function assumes that self._obs is up-to-date.
1585 | """
1586 | n_ally_alive = 0
1587 | n_enemy_alive = 0
1588 |
1589 | # Store previous state
1590 | self.previous_ally_units = deepcopy(self.agents)
1591 | self.previous_enemy_units = deepcopy(self.enemies)
1592 |
1593 | for al_id, al_unit in self.agents.items():
1594 | updated = False
1595 | for unit in self._obs.observation.raw_data.units:
1596 | if al_unit.tag == unit.tag:
1597 | self.agents[al_id] = unit
1598 | updated = True
1599 | n_ally_alive += 1
1600 | break
1601 |
1602 | if not updated: # dead
1603 | al_unit.health = 0
1604 |
1605 | for e_id, e_unit in self.enemies.items():
1606 | updated = False
1607 | for unit in self._obs.observation.raw_data.units:
1608 | if e_unit.tag == unit.tag:
1609 | self.enemies[e_id] = unit
1610 | updated = True
1611 | n_enemy_alive += 1
1612 | break
1613 |
1614 | if not updated: # dead
1615 | e_unit.health = 0
1616 |
1617 | if (
1618 | n_ally_alive == 0
1619 | and n_enemy_alive > 0
1620 | or self.only_medivac_left(ally=True)
1621 | ):
1622 | return -1 # lost
1623 | if (
1624 | n_ally_alive > 0
1625 | and n_enemy_alive == 0
1626 | or self.only_medivac_left(ally=False)
1627 | ):
1628 | return 1 # won
1629 | if n_ally_alive == 0 and n_enemy_alive == 0:
1630 | return 0
1631 |
1632 | return None
1633 |
1634 | def _init_ally_unit_types(self, min_unit_type):
1635 | """Initialise ally unit types. Should be called once from the
1636 | init_units function.
1637 | """
1638 | self._min_unit_type = min_unit_type
1639 | if self.map_type == "marines":
1640 | self.marine_id = min_unit_type
1641 | elif self.map_type == "stalkers_and_zealots":
1642 | self.stalker_id = min_unit_type
1643 | self.zealot_id = min_unit_type + 1
1644 | elif self.map_type == "colossi_stalkers_zealots":
1645 | self.colossus_id = min_unit_type
1646 | self.stalker_id = min_unit_type + 1
1647 | self.zealot_id = min_unit_type + 2
1648 | elif self.map_type == "MMM":
1649 | self.marauder_id = min_unit_type
1650 | self.marine_id = min_unit_type + 1
1651 | self.medivac_id = min_unit_type + 2
1652 | elif self.map_type == "zealots":
1653 | self.zealot_id = min_unit_type
1654 | elif self.map_type == "hydralisks":
1655 | self.hydralisk_id = min_unit_type
1656 | elif self.map_type == "stalkers":
1657 | self.stalker_id = min_unit_type
1658 | elif self.map_type == "colossus":
1659 | self.colossus_id = min_unit_type
1660 | elif self.map_type == "bane":
1661 | self.baneling_id = min_unit_type
1662 | self.zergling_id = min_unit_type + 1
1663 |
1664 | def only_medivac_left(self, ally):
1665 | """Check if only Medivac units are left."""
1666 | if self.map_type != "MMM":
1667 | return False
1668 |
1669 | if ally:
1670 | units_alive = [
1671 | a
1672 | for a in self.agents.values()
1673 | if (a.health > 0 and a.unit_type != self.medivac_id)
1674 | ]
1675 | if len(units_alive) == 0:
1676 | return True
1677 | return False
1678 | else:
1679 | units_alive = [
1680 | a
1681 | for a in self.enemies.values()
1682 | if (a.health > 0 and a.unit_type != self.medivac_id)
1683 | ]
1684 | if len(units_alive) == 1 and units_alive[0].unit_type == 54:
1685 | return True
1686 | return False
1687 |
1688 | def get_unit_by_id(self, a_id):
1689 | """Get unit by ID."""
1690 | return self.agents[a_id]
1691 |
1692 | def get_stats(self):
1693 | stats = {
1694 | "battles_won": self.battles_won,
1695 | "battles_game": self.battles_game,
1696 | "battles_draw": self.timeouts,
1697 | "win_rate": self.battles_won / self.battles_game,
1698 | "timeouts": self.timeouts,
1699 | "restarts": self.force_restarts,
1700 | }
1701 | return stats
1702 |
1703 | def get_env_info(self):
1704 | env_info = super().get_env_info()
1705 | env_info["agent_features"] = self.ally_state_attr_names
1706 | env_info["enemy_features"] = self.enemy_state_attr_names
1707 | return env_info
1708 |
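
The observation- and state-size helpers above compose the flat per-agent observation and the global state from fixed feature blocks. As a sanity check, here is a hypothetical snippet (not part of the repository; it assumes StarCraft II and the ``3m`` map are installed) that verifies the returned vectors match the advertised sizes: ::

    from smac.env import StarCraft2Env

    env = StarCraft2Env(map_name="3m")
    env.reset()

    obs = env.get_obs()
    state = env.get_state()
    assert all(len(o) == env.get_obs_size() for o in obs)
    assert len(state) == env.get_state_size()

    # The flat observation is the concatenation of four feature blocks.
    n_enemies, nf_en = env.get_obs_enemy_feats_size()
    n_allies, nf_al = env.get_obs_ally_feats_size()
    assert env.get_obs_size() == (
        env.get_obs_move_feats_size()
        + n_enemies * nf_en
        + n_allies * nf_al
        + env.get_obs_own_feats_size()
    )
    env.close()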
--------------------------------------------------------------------------------
/smac/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/examples/__init__.py
--------------------------------------------------------------------------------
/smac/examples/pettingzoo/README.rst:
--------------------------------------------------------------------------------
1 | SMAC on PettingZoo
2 | ==================
3 |
4 | This example shows how to run SMAC environments with PettingZoo multi-agent API.
5 |
6 | Instructions
7 | ------------
8 |
9 | To get started, first install PettingZoo with ``pip install pettingzoo``.
10 |
11 | The SMAC environment for PettingZoo, ``StarCraft2PZEnv``, can be initialized with two different API templates.
12 | * **AEC**: PettingZoo is based on the *Agent Environment Cycle* (AEC) game model; more information about AEC can be found in the following `paper `_. To create a SMAC environment with the AEC API, use: ::
13 |
14 | from smac.env.pettingzoo import StarCraft2PZEnv
15 |
16 | env = StarCraft2PZEnv.env()
17 |
18 | * **Parallel**: PettingZoo also supports parallel environments, in which all agents act and observe simultaneously (a minimal stepping loop is sketched at the end of this file). This type of environment can be created as follows: ::
19 |
20 | from smac.env.pettingzoo import StarCraft2PZEnv
21 |
22 | env = StarCraft2PZEnv.parallel_env()
23 |
24 | ``pettingzoo_demo.py`` shows a SMAC environment being used as a PettingZoo AEC environment. Calling ``env.render()`` displays the current environment state as a pygame frame, which is useful for debugging.
25 |
26 | | See https://www.pettingzoo.ml/api for more documentation.
27 |
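
For the parallel API mentioned above, a minimal random-action loop might look roughly like the sketch below (not part of the repository). It follows the demo's convention of sampling only masked-valid actions; depending on the installed PettingZoo version, ``reset()`` may additionally return infos and ``step()`` may return four values instead of five: ::

    import random
    import numpy as np
    from smac.env.pettingzoo import StarCraft2PZEnv

    env = StarCraft2PZEnv.parallel_env()
    observations = env.reset()

    while env.agents:  # the list of live agents shrinks as units die
        actions = {}
        for agent in env.agents:
            obs = observations[agent]
            if isinstance(obs, dict) and "action_mask" in obs:
                # sample only among currently valid actions
                actions[agent] = random.choice(np.flatnonzero(obs["action_mask"]))
            else:
                actions[agent] = env.action_spaces[agent].sample()
        observations, rewards, terminations, truncations, infos = env.step(actions)

    env.close()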
--------------------------------------------------------------------------------
/smac/examples/pettingzoo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxwhirl/smac/d6aab33f76abc3849c50463a8592a84f59a5ef84/smac/examples/pettingzoo/__init__.py
--------------------------------------------------------------------------------
/smac/examples/pettingzoo/pettingzoo_demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 | import numpy as np
7 | from smac.env.pettingzoo import StarCraft2PZEnv
8 |
9 |
10 | def main():
11 | """
12 | Runs an env object with random actions.
13 | """
14 | env = StarCraft2PZEnv.env()
15 | episodes = 10
16 |
17 | total_reward = 0
18 | done = False
19 | completed_episodes = 0
20 |
21 | while completed_episodes < episodes:
22 | env.reset()
23 | for agent in env.agent_iter():
24 | env.render()
25 |
26 | obs, reward, terms, truncs, _ = env.last()
27 | total_reward += reward
28 | if terms or truncs:
29 | action = None
30 | elif isinstance(obs, dict) and "action_mask" in obs:
31 | action = random.choice(np.flatnonzero(obs["action_mask"]))
32 | else:
33 | action = env.action_spaces[agent].sample()
34 | env.step(action)
35 |
36 | completed_episodes += 1
37 |
38 | env.close()
39 |
40 | print("Average total reward", total_reward / episodes)
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/smac/examples/random_agents.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from smac.env import StarCraft2Env
6 | import numpy as np
7 |
8 |
9 | def main():
10 | env = StarCraft2Env(map_name="8m")
11 | env_info = env.get_env_info()
12 |
13 | n_actions = env_info["n_actions"]
14 | n_agents = env_info["n_agents"]
15 |
16 | n_episodes = 10
17 |
18 | for e in range(n_episodes):
19 | env.reset()
20 | terminated = False
21 | episode_reward = 0
22 |
23 | while not terminated:
24 | obs = env.get_obs()
25 | state = env.get_state()
26 | # env.render() # Uncomment for rendering
27 |
28 | actions = []
29 | for agent_id in range(n_agents):
30 | avail_actions = env.get_avail_agent_actions(agent_id)
31 | avail_actions_ind = np.nonzero(avail_actions)[0]
32 | action = np.random.choice(avail_actions_ind)
33 | actions.append(action)
34 |
35 | reward, terminated, _ = env.step(actions)
36 | episode_reward += reward
37 |
38 | print("Total reward in episode {} = {}".format(e, episode_reward))
39 |
40 | env.close()
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/smac/examples/rllib/README.rst:
--------------------------------------------------------------------------------
1 | SMAC on RLlib
2 | =============
3 |
4 | This example shows how to run SMAC environments with RLlib multi-agent.
5 |
6 | Instructions
7 | ------------
8 |
9 | To get started, first install RLlib with ``pip install -U ray[rllib]``. You will also need TensorFlow installed.
10 |
11 | In ``run_ppo.py``, each agent will be controlled by an independent PPO policy (the policies share weights). This setup serves as a single-agent baseline for this task.
12 |
13 | In ``run_qmix.py``, the agents are controlled by the multi-agent QMIX policy. This setup is an example of centralized training and decentralized execution.
14 |
15 | See https://ray.readthedocs.io/en/latest/rllib.html for more documentation.
16 |
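
Assuming SMAC, RLlib, and TensorFlow are installed and the examples package is importable, the two scripts can be launched roughly as follows (the flags match the argument parsers defined in the scripts; adjust to your setup): ::

    python -m smac.examples.rllib.run_ppo --map-name 8m --num-workers 2 --num-iters 100
    python -m smac.examples.rllib.run_qmix --map-name 8m --num-workers 2 --num-iters 100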
--------------------------------------------------------------------------------
/smac/examples/rllib/__init__.py:
--------------------------------------------------------------------------------
1 | from smac.examples.rllib.env import RLlibStarCraft2Env
2 | from smac.examples.rllib.model import MaskedActionsModel
3 |
4 | __all__ = ["RLlibStarCraft2Env", "MaskedActionsModel"]
5 |
--------------------------------------------------------------------------------
/smac/examples/rllib/env.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 |
7 | import numpy as np
8 |
9 | from gym.spaces import Discrete, Box, Dict
10 |
11 | from ray import rllib
12 |
13 | from smac.env import StarCraft2Env
14 |
15 |
16 | class RLlibStarCraft2Env(rllib.MultiAgentEnv):
17 | """Wraps a smac StarCraft env to be compatible with RLlib multi-agent."""
18 |
19 | def __init__(self, **smac_args):
20 | """Create a new multi-agent StarCraft env compatible with RLlib.
21 |
22 | Arguments:
23 | smac_args (dict): Arguments to pass to the underlying
24 | smac.env.starcraft.StarCraft2Env instance.
25 |
26 | Examples:
27 | >>> from smac.examples.rllib import RLlibStarCraft2Env
28 | >>> env = RLlibStarCraft2Env(map_name="8m")
29 | >>> print(env.reset())
30 | """
31 |
32 | self._env = StarCraft2Env(**smac_args)
33 | self._ready_agents = []
34 | self.observation_space = Dict(
35 | {
36 | "obs": Box(-1, 1, shape=(self._env.get_obs_size(),)),
37 | "action_mask": Box(
38 | 0, 1, shape=(self._env.get_total_actions(),)
39 | ),
40 | }
41 | )
42 | self.action_space = Discrete(self._env.get_total_actions())
43 |
44 | def reset(self):
45 | """Resets the env and returns observations from ready agents.
46 |
47 | Returns:
48 | obs (dict): New observations for each ready agent.
49 | """
50 |
51 | obs_list, state_list = self._env.reset()
52 | return_obs = {}
53 | for i, obs in enumerate(obs_list):
54 | return_obs[i] = {
55 | "action_mask": np.array(self._env.get_avail_agent_actions(i)),
56 | "obs": obs,
57 | }
58 |
59 | self._ready_agents = list(range(len(obs_list)))
60 | return return_obs
61 |
62 | def step(self, action_dict):
63 | """Returns observations from ready agents.
64 |
65 |         The returns are dicts mapping from agent ids to values. The
66 | number of agents in the env can vary over time.
67 |
68 | Returns
69 | -------
70 | obs (dict): New observations for each ready agent.
71 | rewards (dict): Reward values for each ready agent. If the
72 | episode is just started, the value will be None.
73 | dones (dict): Done values for each ready agent. The special key
74 | "__all__" (required) is used to indicate env termination.
75 | infos (dict): Optional info values for each agent id.
76 | """
77 |
78 | actions = []
79 | for i in self._ready_agents:
80 | if i not in action_dict:
81 | raise ValueError(
82 | "You must supply an action for agent: {}".format(i)
83 | )
84 | actions.append(action_dict[i])
85 |
86 | if len(actions) != len(self._ready_agents):
87 | raise ValueError(
88 | "Unexpected number of actions: {}".format(
89 | action_dict,
90 | )
91 | )
92 |
93 | rew, done, info = self._env.step(actions)
94 | obs_list = self._env.get_obs()
95 | return_obs = {}
96 | for i, obs in enumerate(obs_list):
97 | return_obs[i] = {
98 | "action_mask": self._env.get_avail_agent_actions(i),
99 | "obs": obs,
100 | }
101 | rews = {i: rew / len(obs_list) for i in range(len(obs_list))}
102 | dones = {i: done for i in range(len(obs_list))}
103 | dones["__all__"] = done
104 | infos = {i: info for i in range(len(obs_list))}
105 |
106 | self._ready_agents = list(range(len(obs_list)))
107 | return return_obs, rews, dones, infos
108 |
109 | def close(self):
110 | """Close the environment"""
111 | self._env.close()
112 |
113 | def seed(self, seed):
114 | random.seed(seed)
115 | np.random.seed(seed)
116 |
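
To make the dict-based step API above concrete, here is a hypothetical random-action driver (not part of the repository; it assumes StarCraft II and the ``8m`` map are installed, and picks each agent's action from its ``action_mask``): ::

    import numpy as np
    from smac.examples.rllib.env import RLlibStarCraft2Env

    env = RLlibStarCraft2Env(map_name="8m")
    obs = env.reset()
    done = {"__all__": False}

    while not done["__all__"]:
        # one random valid action per ready agent, keyed by agent id
        actions = {
            i: int(np.random.choice(np.flatnonzero(o["action_mask"])))
            for i, o in obs.items()
        }
        obs, rewards, done, infos = env.step(actions)

    env.close()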
--------------------------------------------------------------------------------
/smac/examples/rllib/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 | from ray.rllib.models import Model
8 | from ray.rllib.models.tf.misc import normc_initializer
9 |
10 |
11 | class MaskedActionsModel(Model):
12 | """Custom RLlib model that emits -inf logits for invalid actions.
13 |
14 | This is used to handle the variable-length StarCraft action space.
15 | """
16 |
17 | def _build_layers_v2(self, input_dict, num_outputs, options):
18 | action_mask = input_dict["obs"]["action_mask"]
19 | if num_outputs != action_mask.shape[1].value:
20 | raise ValueError(
21 | "This model assumes num outputs is equal to max avail actions",
22 | num_outputs,
23 | action_mask,
24 | )
25 |
26 | # Standard fully connected network
27 | last_layer = input_dict["obs"]["obs"]
28 | hiddens = options.get("fcnet_hiddens")
29 | for i, size in enumerate(hiddens):
30 | label = "fc{}".format(i)
31 | last_layer = tf.layers.dense(
32 | last_layer,
33 | size,
34 | kernel_initializer=normc_initializer(1.0),
35 | activation=tf.nn.tanh,
36 | name=label,
37 | )
38 | action_logits = tf.layers.dense(
39 | last_layer,
40 | num_outputs,
41 | kernel_initializer=normc_initializer(0.01),
42 | activation=None,
43 | name="fc_out",
44 | )
45 |
46 | # Mask out invalid actions (use tf.float32.min for stability)
47 | inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
48 | masked_logits = inf_mask + action_logits
49 |
50 | return masked_logits, last_layer
51 |
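
The masking trick above relies on ``log(1) = 0`` (valid actions keep their logits) and ``log(0) = -inf`` being clamped to the most negative float32, which softmax then maps to an effectively zero probability. A small numpy illustration of the same arithmetic (hypothetical, not part of the model): ::

    import numpy as np

    action_mask = np.array([1.0, 0.0, 1.0], dtype=np.float32)
    logits = np.array([0.5, 2.0, -1.0], dtype=np.float32)

    with np.errstate(divide="ignore"):  # log(0) -> -inf is intentional here
        inf_mask = np.maximum(np.log(action_mask), np.finfo(np.float32).min)
    masked = logits + inf_mask

    probs = np.exp(masked - masked.max())
    probs /= probs.sum()
    print(probs)  # the masked (second) action gets ~0 probability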
--------------------------------------------------------------------------------
/smac/examples/rllib/run_ppo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | """Example of running StarCraft2 with RLlib PPO.
6 |
7 | In this setup, each agent will be controlled by an independent PPO policy.
8 | However, the policies share weights.
9 |
10 | Increase the level of parallelism by changing --num-workers.
11 | """
12 |
13 | import argparse
14 |
15 | import ray
16 | from ray.tune import run_experiments, register_env
17 | from ray.rllib.models import ModelCatalog
18 |
19 | from smac.examples.rllib.env import RLlibStarCraft2Env
20 | from smac.examples.rllib.model import MaskedActionsModel
21 |
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--num-iters", type=int, default=100)
26 | parser.add_argument("--num-workers", type=int, default=2)
27 | parser.add_argument("--map-name", type=str, default="8m")
28 | args = parser.parse_args()
29 |
30 | ray.init()
31 |
32 | register_env("smac", lambda smac_args: RLlibStarCraft2Env(**smac_args))
33 | ModelCatalog.register_custom_model("mask_model", MaskedActionsModel)
34 |
35 | run_experiments(
36 | {
37 | "ppo_sc2": {
38 | "run": "PPO",
39 | "env": "smac",
40 | "stop": {
41 | "training_iteration": args.num_iters,
42 | },
43 | "config": {
44 | "num_workers": args.num_workers,
45 | "observation_filter": "NoFilter", # breaks the action mask
46 | "vf_share_layers": True, # no separate value model
47 | "env_config": {
48 | "map_name": args.map_name,
49 | },
50 | "model": {
51 | "custom_model": "mask_model",
52 | },
53 | },
54 | },
55 | }
56 | )
57 |
--------------------------------------------------------------------------------
/smac/examples/rllib/run_qmix.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | """Example of running StarCraft2 with RLlib QMIX.
6 |
7 | This assumes all agents are homogeneous. The agents are grouped and assigned
8 | to the multi-agent QMIX policy. Note that the default hyperparameters for
9 | RLlib QMIX are different from pymarl's QMIX.
10 | """
11 |
12 | import argparse
13 | from gym.spaces import Tuple
14 |
15 | import ray
16 | from ray.tune import run_experiments, register_env
17 |
18 | from smac.examples.rllib.env import RLlibStarCraft2Env
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--num-iters", type=int, default=100)
24 | parser.add_argument("--num-workers", type=int, default=2)
25 | parser.add_argument("--map-name", type=str, default="8m")
26 | args = parser.parse_args()
27 |
28 | def env_creator(smac_args):
29 | env = RLlibStarCraft2Env(**smac_args)
30 | agent_list = list(range(env._env.n_agents))
31 | grouping = {
32 | "group_1": agent_list,
33 | }
34 | obs_space = Tuple([env.observation_space for i in agent_list])
35 | act_space = Tuple([env.action_space for i in agent_list])
36 | return env.with_agent_groups(
37 | grouping, obs_space=obs_space, act_space=act_space
38 | )
39 |
40 | ray.init()
41 | register_env("sc2_grouped", env_creator)
42 |
43 | run_experiments(
44 | {
45 | "qmix_sc2": {
46 | "run": "QMIX",
47 | "env": "sc2_grouped",
48 | "stop": {
49 | "training_iteration": args.num_iters,
50 | },
51 | "config": {
52 | "num_workers": args.num_workers,
53 | "env_config": {
54 | "map_name": args.map_name,
55 | },
56 | },
57 | },
58 | }
59 | )
60 |
--------------------------------------------------------------------------------