├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── RUNNING_EXPERIMENTS.md ├── docs ├── imgs │ ├── reflect.png │ └── surrounded.png └── smac.md ├── pyproject.toml ├── setup.py └── smacv2 ├── __init__.py ├── bin ├── __init__.py └── map_list.py ├── env ├── __init__.py ├── multiagentenv.py ├── pettingzoo │ ├── StarCraft2PZEnv.py │ ├── __init__.py │ └── test │ │ ├── __init__.py │ │ ├── all_test.py │ │ └── smac_pettingzoo_test.py └── starcraft2 │ ├── __init__.py │ ├── distributions.py │ ├── maps │ ├── SMAC_Maps │ │ ├── 10gen_empty.SC2Map │ │ ├── 10gen_protoss.SC2Map │ │ ├── 10gen_terran.SC2Map │ │ ├── 10gen_zerg.SC2Map │ │ ├── 10m_vs_11m.SC2Map │ │ ├── 1c3s5z.SC2Map │ │ ├── 25m.SC2Map │ │ ├── 27m_vs_30m.SC2Map │ │ ├── 2c_vs_64zg.SC2Map │ │ ├── 2m_vs_1z.SC2Map │ │ ├── 2s3z.SC2Map │ │ ├── 2s_vs_1sc.SC2Map │ │ ├── 32x32_flat.SC2Map │ │ ├── 32x32_flat_test.SC2Map │ │ ├── 32x32_small.SC2Map │ │ ├── 3m.SC2Map │ │ ├── 3s5z.SC2Map │ │ ├── 3s5z_vs_3s6z.SC2Map │ │ ├── 3s_vs_3z.SC2Map │ │ ├── 3s_vs_4z.SC2Map │ │ ├── 3s_vs_5z.SC2Map │ │ ├── 5m_vs_6m.SC2Map │ │ ├── 6h_vs_8z.SC2Map │ │ ├── 8m.SC2Map │ │ ├── 8m_vs_9m.SC2Map │ │ ├── MMM.SC2Map │ │ ├── MMM2.SC2Map │ │ ├── SMAC_Maps.zip │ │ ├── bane_vs_bane.SC2Map │ │ ├── corridor.SC2Map │ │ └── so_many_baneling.SC2Map │ ├── __init__.py │ └── smac_maps.py │ ├── render.py │ ├── starcraft2.py │ └── wrapper.py └── examples ├── __init__.py ├── configs ├── sc2_gen_protoss.yaml ├── sc2_gen_protoss_epo.yaml ├── sc2_gen_terran.yaml ├── sc2_gen_terran_epo.yaml ├── sc2_gen_zerg.yaml └── sc2_gen_zerg_epo.yaml ├── pettingzoo ├── README.rst ├── __init__.py └── pettingzoo_demo.py ├── random_agents.py ├── results └── smac2_training_results.pkl └── rllib ├── README.rst ├── __init__.py ├── env.py ├── model.py ├── run_ppo.py └── run_qmix.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | cover/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | # For a library or package, you might want to ignore these files since the code is 86 | # intended to run in multiple environments; otherwise, check them in: 87 | # .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | # env/ 110 | venv/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # pytype static type analyzer 133 | .pytype/ 134 | 135 | # Cython debug symbols 136 | cython_debug/ 137 | 138 | .DS_Store 139 | .idea/ 140 | 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: '3.8.4' 9 | hooks: 10 | - id: flake8 11 | additional_dependencies: [flake8-bugbear] 12 | args: ["--show-source"] 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 whirl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SMACv2 Documentation 2 | 3 | # Introduction 4 | 5 | SMACv2 is an update to [Whirl's](https://whirl.cs.ox.ac.uk/) [Starcraft Multi-Agent Challenge](https://github.com/oxwhirl/smac), which is a benchmark for research in the field of cooperative multi-agent reinforcement learning. SMAC and SMACv2 both focus on decentralised micromanagement scenarios in [StarCraft II](https://starcraft2.com/en-gb/), rather than the full game. They make use of Blizzard's StarCraft II Machine Learning API as well as DeepMind's PySC2. We hope that you will enjoy using SMACv2! More details about SMAC can be found in the [SMAC README](https://github.com/oxwhirl/smac/blob/master/README.md) as well as the [SMAC paper](https://arxiv.org/abs/1902.04043). **SMACv2 retains exactly the same API as SMAC, so you should not need to change your algorithm code other than adjusting to the new observation and state sizes**. 6 | 7 | If you encounter difficulties using SMACv2, or have suggestions, please raise an issue, or better yet, open a pull request! 8 | 9 | The aim of this README is to answer some basic technical questions and to get people started with SMACv2. For a more scientific account of the work of developing the benchmark, please read the [SMACv2 paper](https://arxiv.org/abs/2212.07489). Videos of learned policies are available on [our website](https://sites.google.com/view/smacv2). 10 | 11 | # Differences To SMAC 12 | 13 | SMACv2 makes three major changes to SMAC: randomising start positions, randomising unit types, and changing the unit sight and attack ranges. The first two changes were motivated by the discovery that many maps in SMAC lack enough randomness to challenge contemporary MARL algorithms. The final change increases diversity among the different agents and brings the sight range in line with the true values in StarCraft. For more details on the motivation behind these changes, please check the accompanying paper, where they are discussed in much more detail! 14 | 15 | ## Capability Config 16 | 17 | All the procedurally generated content in SMACv2 is managed through the **Capability Config.** This describes what units are generated and in what positions. The presence of a key in this config tells SMACv2 whether or not a certain environment component is generated. As an example, consider the below config: 18 | 19 | ```yaml 20 | capability_config: 21 | n_units: 5 22 | team_gen: 23 | dist_type: "weighted_teams" 24 | unit_types: 25 | - "marine" 26 | - "marauder" 27 | - "medivac" 28 | weights: 29 | - 0.45 30 | - 0.45 31 | - 0.1 32 | exception_unit_types: 33 | - "medivac" 34 | observe: True 35 | 36 | start_positions: 37 | dist_type: "surrounded_and_reflect" 38 | p: 0.5 39 | n_enemies: 5 40 | map_x: 32 41 | map_y: 32 42 | ``` 43 | 44 | This config is the default config for the SMACv2 Terran scenarios.
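As a quick illustration of how the `team_gen` entry above is consumed, the minimal sketch below samples ally and enemy teams from the `weighted_teams` distribution directly. This is purely for illustration: the `env_key` field is normally filled in by the environment wrapper rather than written by hand.

```python
from smacv2.env.starcraft2.distributions import get_distribution

# Minimal sketch: sample teams the way the `team_gen` config entry does.
# The "env_key" field is normally supplied by the environment wrapper and
# is hard-coded here only for illustration.
team_gen_config = {
    "env_key": "team_gen",
    "n_units": 5,
    "n_enemies": 5,
    "unit_types": ["marine", "marauder", "medivac"],
    "weights": [0.45, 0.45, 0.1],
    "exception_unit_types": ["medivac"],
}

team_gen = get_distribution("weighted_teams")(team_gen_config)
# e.g. {'team_gen': {'ally_team': ['marine', 'marauder', ...],
#                    'enemy_team': [...], 'id': 0}}
print(team_gen.generate())
```

In normal training you never construct distributions by hand; the `StarCraftCapabilityEnvWrapper` used later in this README builds them from the capability config for you.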
The `start_positions` key tells SMACv2 to randomly generate start positions. Similarly, the `team_gen` key tells SMACv2 to randomly generate teams. The `dist_type` key tells SMACv2 **how** to generate that content. For example, team generation uses the `weighted_teams` distribution, where each unit type is spawned with a certain weight: in the config above, a Marine is spawned with probability `0.45`. Don't worry too much about the other options for now; they are distribution-specific. 45 | 46 | All the distributions are implemented in the [distributions.py](https://github.com/oxwhirl/smacv2/blob/main/smac/env/starcraft2/distributions.py) file. We encourage users to contribute their own keys and distributions for procedurally generated content! 47 | 48 | ## Random Start Positions 49 | 50 | Random start positions come in two different types. First, there is the `surround` type, where the allied units are spawned in the middle of the map and surrounded by enemy units. An example is shown below. 51 | 52 | ![Allied units spawned in the centre of the map under the surround start position type](docs/imgs/surrounded.png) 53 | 54 |
55 | 56 | This challenges the allied units to overcome the enemies' approach from multiple angles at once. Secondly, there are the `reflect` scenarios. These randomly select positions for the allied units, and then reflect those positions across the midpoint of the map to get the enemy spawn positions. For an example, see the image below. 57 | 58 | ![Ally and enemy spawn positions mirrored across the map midpoint under the reflect start position type](docs/imgs/reflect.png) 59 | 60 |
61 | 62 | 63 | The probability of one type of scenario or the other is controlled with the `p` setting in the capability config. The sight cones are not visible in the above screenshot because the units have not spawned in yet. 64 | 65 | ## Random Unit Types 66 | 67 | Battles in SMACv2 do not always feature units of the same type each time, as they did in SMAC. Instead, units are spawned randomly according to certain pre-set probabilities. Units in StarCraft II are split up into different *races*. Units from different races cannot be on the same team. For each of the three races (Protoss, Terran, and Zerg), SMACv2 uses three unit types. 68 | 69 | | Race | Unit | Generation Probability | 70 | | --- | --- | --- | 71 | | Terran | Marine | 0.45 | 72 | | | Marauder | 0.45 | 73 | | | Medivac | 0.1 | 74 | | Protoss | Stalker | 0.45 | 75 | | | Zealot | 0.45 | 76 | | | Colossus | 0.1 | 77 | | Zerg | Zergling | 0.45 | 78 | | | Hydralisk | 0.45 | 79 | | | Baneling | 0.1 | 80 | 81 | Each race has one unit type that is generated less often than the others, each for a different reason. Medivacs are healing-only units, and so an abundance of them leads to strange, very long scenarios. Colossi are very powerful units, and over-generating them leads to battles being solely determined by colossus use. Banelings are units that explode. If they are too prevalent, the algorithm learns to hide in the corner and hope the enemies all explode! 82 | 83 | These weights are all controllable via the `capability_config`. However, if you do decide to change them, we recommend that you run some tests to check that the scenarios you have made are sensible! Weight changes can sometimes have unexpected consequences. 84 | 85 | # Getting Started 86 | 87 | This section will take you through the basic set-up of SMACv2. The set-up process has changed very little from the process for SMAC, so if you are familiar with that, follow the steps as you usually would. Make sure you have the `32x32_flat.SC2Map` map file in your `SMAC_Maps` folder. You can download the `SMAC_Maps` folder [here](https://github.com/oxwhirl/smacv2/releases/tag/maps#:~:text=3-,SMAC_Maps.zip,-503%20KB). 88 | 89 | First, you will need to install StarCraft II. On Windows or macOS, follow the instructions on the [StarCraft website](https://starcraft2.com/en-gb/). For Linux, you can use the bash script [here](https://github.com/benellis3/mappo/blob/main/install_sc2.sh). Then copy the `SMAC_Maps` folder into the `Maps` directory of your StarCraft II installation. 90 | 91 | Then simply install SMACv2 as a package: 92 | 93 | ```bash 94 | pip install git+https://github.com/oxwhirl/smacv2.git 95 | ``` 96 | 97 | [NOTE]: If you want to extend SMACv2, you must install it like this: 98 | 99 | ```bash 100 | git clone https://github.com/oxwhirl/smacv2.git 101 | cd smacv2 102 | pip install -e ".[dev]" 103 | pre-commit install 104 | ``` 105 | 106 | If you tried these instructions and couldn't get SMACv2 to work, please let us know by raising an issue. 107 | 108 | We also added example configs for the protoss, terran and zerg scenarios to the [examples folder](https://github.com/oxwhirl/smacv2/tree/main/smacv2/examples/configs). Note that you will have to change the `n_units` and `n_enemies` config values to access the different scenarios. 109 | For clarity, the correct settings are in the table below, but the first number in the scenario name is the number of allies (`n_units`) 110 | and the second is the number of enemies (`n_enemies`).
111 | 112 | | Scenario | Config File | `n_units` | `n_enemies` | 113 | |--------------------|----------------------|------------|-------------| 114 | | `protoss_5_vs_5` | sc2_gen_protoss.yaml | 5 | 5 | 115 | | `zerg_5_vs_5` | sc2_gen_zerg.yaml | 5 | 5 | 116 | | `terran_5_vs_5` | sc2_gen_terran.yaml | 5 | 5 | 117 | | `protoss_10_vs_10` | sc2_gen_protoss.yaml | 10 | 10 | 118 | | `zerg_10_vs_10` | sc2_gen_zerg.yaml | 10 | 10 | 119 | | `terran_10_vs_10` | sc2_gen_terran.yaml | 10 | 10 | 120 | | `protoss_20_vs_20` | sc2_gen_protoss.yaml | 20 | 20 | 121 | | `zerg_20_vs_20` | sc2_gen_zerg.yaml | 20 | 20 | 122 | | `terran_20_vs_20` | sc2_gen_terran.yaml | 20 | 20 | 123 | | `protoss_10_vs_11` | sc2_gen_protoss.yaml | 10 | 11 | 124 | | `zerg_10_vs_11` | sc2_gen_zerg.yaml | 10 | 11 | 125 | | `terran_10_vs_11` | sc2_gen_terran.yaml | 10 | 11 | 126 | | `protoss_20_vs_23` | sc2_gen_protoss.yaml | 20 | 23 | 127 | | `zerg_20_vs_23` | sc2_gen_zerg.yaml | 20 | 23 | 128 | | `terran_20_vs_23` | sc2_gen_terran.yaml | 20 | 23 | 129 | 130 | # Training Results 131 | 132 | The smacv2 repo contains the [results](https://github.com/oxwhirl/smacv2/tree/main/smacv2/examples/results) of the MAPPO and QMIX baselines for you to compare against. Please 133 | ensure that you are using the correct version of StarCraft, as otherwise your results will not be 134 | comparable. Using the `install_sc2.sh` script in the [mappo](https://github.com/benellis3/mappo/blob/main/install_sc2.sh) repo, for example, will ensure this. 135 | 136 | # Modifying SMACv2 137 | 138 | SMACv2 procedurally generates some content. We encourage everyone to modify and expand upon the procedurally generated content in SMACv2. 139 | 140 | Procedurally generated content conceptually has two parts: a distribution and an implementation. The implementation part lives in the [starcraft2.py](https://github.com/oxwhirl/smacv2/blob/main/smac/env/starcraft2/starcraft2.py) file and should handle actually generating whatever content is required (e.g. spawning units at the correct start positions) using the StarCraft APIs, given a config passed to the `reset` function at the start of the episode. 141 | 142 | The second part is the distribution. These live in [distributions.py](https://github.com/oxwhirl/smacv2/blob/main/smac/env/starcraft2/distributions.py) and specify the distribution the content is generated according to. For example, start positions might be generated randomly across the whole map. The `distributions.py` file contains a few examples of distributions for the already implemented generated content in SMACv2. 143 | 144 | # Code Example 145 | 146 | SMACv2 follows the same API as SMAC and so can be used in exactly the same way. As an example, the below code allows individual agents to execute random policies. The config corresponds to the 5-unit Terran scenario from SMACv2.
147 | 148 | ```python 149 | from __future__ import absolute_import 150 | from __future__ import division 151 | from __future__ import print_function 152 | 153 | 154 | 155 | import numpy as np 156 | from absl import logging 157 | import time 158 | 159 | from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper 160 | 161 | logging.set_verbosity(logging.DEBUG) 162 | 163 | def main(): 164 | 165 | distribution_config = {  # capability config for the 5-unit Terran scenario 166 | "n_units": 5, 167 | "n_enemies": 5, 168 | "team_gen": { 169 | "dist_type": "weighted_teams", 170 | "unit_types": ["marine", "marauder", "medivac"], 171 | "exception_unit_types": ["medivac"], 172 | "weights": [0.45, 0.45, 0.1], 173 | "observe": True, 174 | }, 175 | "start_positions": { 176 | "dist_type": "surrounded_and_reflect", 177 | "p": 0.5, 178 | "n_enemies": 5, 179 | "map_x": 32, 180 | "map_y": 32, 181 | }, 182 | } 183 | env = StarCraftCapabilityEnvWrapper( 184 | capability_config=distribution_config, 185 | map_name="10gen_terran", 186 | debug=True, 187 | conic_fov=False, 188 | obs_own_pos=True, 189 | use_unit_ranges=True, 190 | min_attack_range=2, 191 | ) 192 | 193 | env_info = env.get_env_info() 194 | 195 | n_actions = env_info["n_actions"] 196 | n_agents = env_info["n_agents"] 197 | 198 | n_episodes = 10 199 | 200 | print("Training episodes") 201 | for e in range(n_episodes): 202 | env.reset() 203 | terminated = False 204 | episode_reward = 0 205 | 206 | while not terminated: 207 | obs = env.get_obs()  # unused by the random policy below 208 | state = env.get_state()  # unused by the random policy below 209 | # env.render() # Uncomment for rendering 210 | 211 | actions = [] 212 | for agent_id in range(n_agents): 213 | avail_actions = env.get_avail_agent_actions(agent_id) 214 | avail_actions_ind = np.nonzero(avail_actions)[0]  # legal actions 215 | action = np.random.choice(avail_actions_ind) 216 | actions.append(action) 217 | 218 | reward, terminated, _ = env.step(actions) 219 | time.sleep(0.15)  # slow the loop down so the episode is watchable 220 | episode_reward += reward 221 | print("Total reward in episode {} = {}".format(e, episode_reward)) 222 | 223 | if __name__ == "__main__": 224 | main() 225 | ``` 226 | 227 | # Citation 228 | If you use SMACv2 in your work, please cite: 229 | 230 | ``` 231 | @inproceedings{ellis2023smacv2, 232 | title={{SMAC}v2: An Improved Benchmark for Cooperative Multi-Agent Reinforcement Learning}, 233 | author={Benjamin Ellis and Jonathan Cook and Skander Moalla and Mikayel Samvelyan and Mingfei Sun and Anuj Mahajan and Jakob Nicolaus Foerster and Shimon Whiteson}, 234 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track}, 235 | year={2023}, 236 | url={https://openreview.net/forum?id=5OjLGiJW3u} 237 | } 238 | ``` 239 | 240 | # FAQ 241 | 242 | ### Why do SMAC maps not work in SMACv2? 243 | 244 | For now, SMACv2 is not backwards compatible with old SMAC maps, although we will implement this if there is enough demand. 245 | 246 | # Questions/Comments 247 | 248 | If you have any questions or suggestions, either raise an issue in this repo or email [Ben Ellis](mailto:benellis@robots.ox.ac.uk) and we will try our 249 | best to answer your query. 250 | -------------------------------------------------------------------------------- /RUNNING_EXPERIMENTS.md: -------------------------------------------------------------------------------- 1 | # Reproducing SMACv2 Experiments 2 | 3 | Logging of the experiments is via [wandb.ai](https://wandb.ai/). Before trying to replicate the experiments here, you will have to create a WandB account.
If you have any questions about these instructions, please [raise an issue](https://github.com/oxwhirl/smacv2/issues/new/choose). 4 | 5 | To make WandB work, you will need to copy your WandB API key. You can do this by going to your profile image on the top right > user settings > danger zone > API keys. Copy this and put it into a file. Set the location of this file in an environment variable called `WANDB_API_KEY_FILE`. Use, for example, `export WANDB_API_KEY_FILE=$HOME/.wandb_api_key`. 6 | 7 | # Running SMACv2 Baselines 8 | 9 | ## QMIX 10 | 11 | 0. Make sure you have correctly set the `WANDB_API_KEY_FILE` environment variable mentioned in the introduction. 12 | 1. Clone the pymarl2 repository: 13 | ```git clone https://github.com/benellis3/pymarl2.git``` 14 | 2. Build the Docker container by running `docker build -t pymarl2:ben_smac -f docker/Dockerfile --build-arg UID=$UID .` from the pymarl2 directory. 15 | 3. Install StarCraft by running `./install_sc2.sh` in the pymarl2 directory. 16 | 4. Navigate to `src/config/default.yaml` and set `project` and `entity` to your desired project name (can be anything) and your wandb username respectively. 17 | 5. Set `td_lambdas` in `run_exp.sh` (line 21) to be `0.4` and `eps_anneal` to `100000`. 18 | 6. Run `./run_exp.sh qmix <tag>` where `<tag>` is a word to help you identify the experiments. If you want to run the `10_vs_11` or `20_vs_23` scenarios, you will have to ensure the `./run_docker.sh` command on line 46 has `n_units` and `n_enemies` set correctly. For example, for the `20_vs_23` scenario you would set `n_units=20` and `n_enemies=23`. 19 | 20 | ## MAPPO 21 | 22 | 0. Make sure you have correctly set the `WANDB_API_KEY_FILE` environment variable mentioned in the introduction. 23 | 1. Clone the MAPPO repository: 24 | ```git clone https://github.com/benellis3/mappo.git``` 25 | 2. Build the Docker container by running `./build.sh` in the `docker` directory. 26 | 3. Install StarCraft by running `./install_sc2.sh` in the mappo directory. 27 | 4. Navigate to `src/config/default.yaml` and set `project` and `entity` to your desired project name (can be anything) and your wandb username respectively. 28 | 5. Set `lr` in `run.sh` to `0.0005` and `clip_range` to `0.1`. If you want to run the closed-loop baseline, change `maps` to only contain maps *without* `open_loop` in their name. For the open-loop baseline, do the opposite, i.e. keep all the map names with `open_loop` in them and delete the rest. 29 | 6. Run `./run.sh clipping_rnn_central_V <tag>` where `<tag>` is a word to help you identify the experiments. If you want to run the `10_vs_11` or `20_vs_23` scenarios, you will have to set `offset` on line 21 of the script. This controls how many more enemies there are than allies. 30 | 31 | # Running EPO Baselines 32 | 33 | ## QMIX 34 | 35 | 1. Complete steps 0-4 of Running SMACv2 Baselines (QMIX), making sure to use `run_exp_epo.sh` where `run_exp.sh` is mentioned. 36 | 2. Run `./run_exp_epo.sh qmix <tag>` where `<tag>` is a word to help you identify the experiments. 37 | 38 | ## MAPPO 39 | 40 | 1. Complete steps 0-4 of Running SMACv2 Baselines (MAPPO), making sure to use `run_exp_epo.sh` where `run_exp.sh` is mentioned. 41 | 2. Run `./run_exp_epo.sh mappo <tag>` where `<tag>` is a word to help you identify the experiments. 42 | 43 | # Running Open-Loop SMAC baselines 44 | 45 | ## MAPPO 46 | 47 | 1. Follow steps 0 and 1 of Running SMACv2 Baselines (MAPPO). 48 | 2. Checkout the `stochastic-experiment` branch: 49 | ```git checkout stochastic-experiment``` 50 | 3.
Install StarCraft by running `./install_sc2.sh` in the mappo directory. 51 | 4. Build the Docker container by running `./build.sh` in the `docker` directory. 52 | 5. Navigate to `src/config/default.yaml` and set `project` and `entity` to your desired project name (can be anything) and your wandb username respectively. 53 | 6. Run `./run_exp.sh clipping_rnn_central_V <tag>` where `<tag>` is a word to help you identify experiments. 54 | 55 | ## QMIX 56 | 57 | 1. Follow steps 0 and 1 of Running SMACv2 Baselines (QMIX). 58 | 2. Checkout the `stochastic_test` branch: 59 | ```git checkout stochastic_test``` 60 | 3. Install StarCraft by running `./install_sc2.sh` in the pymarl2 directory. 61 | 4. Build the Docker container by running `./build.sh` in the 62 | `docker` directory. 63 | 5. Navigate to `src/config/default.yaml` and set `project` and `entity` to your desired project name (can be anything) and your wandb username respectively. 64 | 6. Run `./run.sh qmix <tag>` where `<tag>` is a word to help you identify experiments. 65 | 66 | # Running Q-Regression Experiments 67 | 68 | See the README in the [pymarl2 repo](https://github.com/benellis3/pymarl2/tree/smacv2-feature-inferrability) 69 | -------------------------------------------------------------------------------- /docs/imgs/reflect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/docs/imgs/reflect.png -------------------------------------------------------------------------------- /docs/imgs/surrounded.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/docs/imgs/surrounded.png -------------------------------------------------------------------------------- /docs/smac.md: -------------------------------------------------------------------------------- 1 | ## Table of Contents 2 | 3 | - [StarCraft II](#starcraft-ii) 4 | - [Micromanagement](#micromanagement) 5 | - [SMAC](#smac) 6 | - [Scenarios](#scenarios) 7 | - [State and Observations](#state-and-observations) 8 | - [Action Space](#action-space) 9 | - [Rewards](#rewards) 10 | - [Environment Settings](#environment-settings) 11 | 12 | ## StarCraft II 13 | 14 | SMAC is based on the popular real-time strategy (RTS) game [StarCraft II](http://us.battle.net/sc2/en/game/guide/whats-sc2) written by [Blizzard](http://blizzard.com/). 15 | In a regular full game of StarCraft II, one or more humans compete against each other or against a built-in game AI to gather resources, construct buildings, and build armies of units to defeat their opponents. 16 | 17 | Akin to most RTSs, StarCraft has two main gameplay components: macromanagement and micromanagement. 18 | - _Macromanagement_ (macro) refers to high-level strategic considerations, such as economy and resource management. 19 | - _Micromanagement_ (micro) refers to fine-grained control of individual units. 20 | 21 | ### Micromanagement 22 | 23 | StarCraft has been used as a research platform for AI, and more recently, RL. Typically, the game is framed as a competitive problem: an agent takes the role of a human player, making macromanagement decisions and performing micromanagement as a puppeteer that issues orders to individual units from a centralised controller. 24 | 25 | In order to build a rich multi-agent testbed, we instead focus solely on micromanagement.
26 | Micro is a vital aspect of StarCraft gameplay with a high skill ceiling, and is practiced in isolation by amateur and professional players. 27 | For SMAC, we leverage the natural multi-agent structure of micromanagement by proposing a modified version of the problem designed specifically for decentralised control. 28 | In particular, we require that each unit be controlled by an independent agent that conditions only on local observations restricted to a limited field of view centred on that unit. 29 | Groups of these agents must be trained to solve challenging combat scenarios, battling an opposing army under the centralised control of the game's built-in scripted AI. 30 | 31 | Proper micro of units during battles will maximise the damage dealt to enemy units while minimising damage received, and requires a range of skills. 32 | For example, one important technique is _focus fire_, i.e., ordering units to jointly attack and kill enemy units one after another. When focusing fire, it is important to avoid _overkill_: inflicting more damage to units than is necessary to kill them. 33 | 34 | Other common micromanagement techniques include: assembling units into formations based on their armour types, making enemy units give chase while maintaining enough distance so that little or no damage is incurred (_kiting_), coordinating the positioning of units to attack from different directions or taking advantage of the terrain to defeat the enemy. 35 | 36 | Learning these rich cooperative behaviours under partial observability is a challenging task, which can be used to evaluate the effectiveness of multi-agent reinforcement learning (MARL) algorithms. 37 | 38 | ## SMAC 39 | 40 | SMAC uses the [StarCraft II Learning Environment](https://github.com/deepmind/pysc2) to introduce a cooperative MARL environment. 41 | 42 | ### Scenarios 43 | 44 | SMAC consists of a set of StarCraft II micro scenarios which aim to evaluate how well independent agents are able to learn coordination to solve complex tasks. 45 | These scenarios are carefully designed to necessitate the learning of one or more micromanagement techniques to defeat the enemy. 46 | Each scenario is a confrontation between two armies of units. 47 | The initial position, number, and type of units in each army varies from scenario to scenario, as does the presence or absence of elevated or impassable terrain. 48 | 49 | The first army is controlled by the learned allied agents. 50 | The second army consists of enemy units controlled by the built-in game AI, which uses carefully handcrafted non-learned heuristics. 51 | At the beginning of each episode, the game AI instructs its units to attack the allied agents using its scripted strategies. 52 | An episode ends when all units of either army have died or when a pre-specified time limit is reached (in which case the game is counted as a defeat for the allied agents). 53 | The goal for each scenario is to maximise the win rate of the learned policies, i.e., the expected ratio of games won to games played. 54 | To speed up learning, the enemy AI units are ordered to attack the agents' spawning point in the beginning of each episode. 55 | 56 | Perhaps the simplest scenarios are _symmetric_ battle scenarios. 57 | The most straightforward of these scenarios are _homogeneous_, i.e., each army is composed of only a single unit type (e.g., Marines). 58 | A winning strategy in this setting would be to focus fire, ideally without overkill.
59 | _Heterogeneous_ symmetric scenarios, in which there is more than a single unit type on each side (e.g., Stalkers and Zealots), are more difficult to solve. 60 | Such challenges are particularly interesting when some of the units are extremely effective against others (this is known as _countering_), for example, by dealing bonus damage to a particular armour type. 61 | In such a setting, allied agents must deduce this property of the game and design an intelligent strategy to protect teammates vulnerable to certain enemy attacks. 62 | 63 | SMAC also includes more challenging scenarios, for example, in which the enemy army outnumbers the allied army by one or more units. In such _asymmetric_ scenarios it is essential to consider the health of enemy units in order to effectively target the desired opponent. 64 | 65 | Lastly, SMAC offers a set of interesting _micro-trick_ challenges that require a higher level of cooperation and a specific micromanagement trick to defeat the enemy. 66 | An example of a challenge scenario is _2m_vs_1z_ (aka Marine Double Team), where two Terran Marines need to defeat an enemy Zealot. In this setting, the Marines must design a strategy which does not allow the Zealot to reach them, otherwise they will die almost immediately. 67 | Another example is _so_many_banelings_ where 7 allied Zealots face 32 enemy Baneling units. Banelings attack by running against a target and exploding when reaching them, causing damage to a certain area around the target. Hence, if a large number of Banelings attack a handful of Zealots located close to each other, the Zealots will be defeated instantly. The optimal strategy, therefore, is to cooperatively spread out around the map far from each other so that the Banelings' damage is distributed as thinly as possible. 68 | The _corridor_ scenario, in which 6 friendly Zealots face 24 enemy Zerglings, requires agents to make effective use of the terrain features. Specifically, agents should collectively wall off the choke point (the narrow region of the map) to block enemy attacks from different directions. Some of the micro-trick challenges are inspired by [StarCraft Master](http://us.battle.net/sc2/en/blog/4544189/new-blizzard-custom-game-starcraft-master-3-1-2012) challenge missions released by Blizzard. 69 | 70 | The complete list of challenges is presented below. The difficulty of the game AI is set to _very difficult_ (7). Our experiments, however, suggest that this setting does not significantly impact the unit micromanagement of the built-in heuristics.
71 | 72 | | Name | Ally Units | Enemy Units | Type | 73 | | :---: | :---: | :---: | :---:| 74 | | 3m | 3 Marines | 3 Marines | homogeneous & symmetric | 75 | | 8m | 8 Marines | 8 Marines | homogeneous & symmetric | 76 | | 25m | 25 Marines | 25 Marines | homogeneous & symmetric | 77 | | 2s3z | 2 Stalkers & 3 Zealots | 2 Stalkers & 3 Zealots | heterogeneous & symmetric | 78 | | 3s5z | 3 Stalkers & 5 Zealots | 3 Stalkers & 5 Zealots | heterogeneous & symmetric | 79 | | MMM | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 2 Marauders & 7 Marines | heterogeneous & symmetric | 80 | | 5m_vs_6m | 5 Marines | 6 Marines | homogeneous & asymmetric | 81 | | 8m_vs_9m | 8 Marines | 9 Marines | homogeneous & asymmetric | 82 | | 10m_vs_11m | 10 Marines | 11 Marines | homogeneous & asymmetric | 83 | | 27m_vs_30m | 27 Marines | 30 Marines | homogeneous & asymmetric | 84 | | 3s5z_vs_3s6z | 3 Stalkers & 5 Zealots | 3 Stalkers & 6 Zealots | heterogeneous & asymmetric | 85 | | MMM2 | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 3 Marauders & 8 Marines | heterogeneous & asymmetric | 86 | | 2m_vs_1z | 2 Marines | 1 Zealot | micro-trick: alternating fire | 87 | | 2s_vs_1sc | 2 Stalkers | 1 Spine Crawler | micro-trick: alternating fire | 88 | | 3s_vs_3z | 3 Stalkers | 3 Zealots | micro-trick: kiting | 89 | | 3s_vs_4z | 3 Stalkers | 4 Zealots | micro-trick: kiting | 90 | | 3s_vs_5z | 3 Stalkers | 5 Zealots | micro-trick: kiting | 91 | | 6h_vs_8z | 6 Hydralisks | 8 Zealots | micro-trick: focus fire | 92 | | corridor | 6 Zealots | 24 Zerglings | micro-trick: wall off | 93 | | bane_vs_bane | 20 Zerglings & 4 Banelings | 20 Zerglings & 4 Banelings | micro-trick: positioning | 94 | | so_many_banelings | 7 Zealots | 32 Banelings | micro-trick: positioning | 95 | | 2c_vs_64zg | 2 Colossi | 64 Zerglings | micro-trick: positioning | 96 | | 1c3s5z | 1 Colossus & 3 Stalkers & 5 Zealots | 1 Colossus & 3 Stalkers & 5 Zealots | heterogeneous & symmetric | 97 | 98 | ### State and Observations 99 | 100 | At each timestep, agents receive local observations drawn within their field of view. This encompasses information about the map within a circular area around each unit and with a radius equal to the _sight range_. The sight range makes the environment partially observable from the standpoint of each agent. Agents can only observe other agents if they are both alive and located within the sight range. Hence, there is no way for agents to determine whether their teammates are far away or dead. 101 | 102 | The feature vector observed by each agent contains the following attributes for both allied and enemy units within the sight range: _distance_, _relative x_, _relative y_, _health_, _shield_, and _unit\_type_ [1](#myfootnote1). Shields serve as an additional source of protection that needs to be removed before any damage can be done to the health of units. 103 | All Protoss units have shields, which can regenerate if no new damage is dealt 104 | (units of the other two races do not have this attribute). 105 | In addition, agents have access to the last actions of allied units that are in the field of view. Lastly, agents can observe the terrain features surrounding them; particularly, the values of eight points at a fixed radius indicating height and walkability. 106 | 107 | The global state, which is only available to agents during centralised training, contains information about all units on the map.
Specifically, the state vector includes the coordinates of all agents relative to the centre of the map, together with unit features present in the observations. Additionally, the state stores the _energy_ of Medivacs and _cooldown_ of the rest of the allied units, which represents the minimum delay between attacks. Finally, the last actions of all agents are attached to the central state. 108 | 109 | All features, both in the state as well as in the observations of individual agents, are normalised by their maximum values. The sight range is set to 9 for all agents. 110 | 111 | ### Action Space 112 | 113 | The discrete set of actions which agents are allowed to take consists of _move[direction]_ (four directions: north, south, east, or west), _attack[enemy_id]_, _stop_ and _no-op_. Dead agents can only take the _no-op_ action, while live agents cannot. 114 | As healer units, Medivacs must use _heal[agent\_id]_ actions instead of _attack[enemy\_id]_. The maximum number of actions an agent can take ranges between 7 and 70, depending on the scenario. 115 | 116 | To ensure decentralisation of the task, agents are restricted to use the _attack[enemy\_id]_ action only towards enemies in their _shooting range_. 117 | This additionally constrains the unit's ability to use the built-in _attack-move_ macro-actions on enemies that are far away. We set the shooting range equal to 6. Having a larger sight range than shooting range forces agents to make use of move commands before starting to fire. 118 | 119 | ### Rewards 120 | 121 | The overall goal is to have the highest win rate for each battle scenario. 122 | We provide a corresponding option for _sparse rewards_, which will cause the environment to return only a reward of +1 for winning and -1 for losing an episode. 123 | However, we also provide a default setting for a shaped reward signal calculated from the hit-point damage dealt and received by agents, some positive (negative) reward after having enemy (allied) units killed and/or a positive (negative) bonus for winning (losing) the battle. 124 | The exact values and scales of this shaped reward can be configured using a range of flags, but we strongly discourage disingenuous engineering of the reward function (e.g. tuning different reward functions for different scenarios). 125 | 126 | ### Environment Settings 127 | 128 | SMAC makes use of the [StarCraft II Learning Environment](https://arxiv.org/abs/1708.04782) (SC2LE) to communicate with the StarCraft II engine. SC2LE provides full control of the game by allowing users to send commands and receive observations from the game. However, SMAC is conceptually different from the RL environment of SC2LE. The goal of SC2LE is to learn to play the full game of StarCraft II. This is a competitive task where a centralised RL agent receives RGB pixels as input and performs both macro and micro with player-level control similar to human players. SMAC, on the other hand, represents a set of cooperative multi-agent micro challenges where each learning agent controls a single military unit. 129 | 130 | SMAC uses the _raw API_ of SC2LE. Raw API observations do not have any graphical component and include information about the units on the map such as health, location coordinates, etc. The raw API also allows sending action commands to individual units using their unit IDs. This setting differs from how humans play the actual game, but is convenient for designing decentralised multi-agent learning tasks.
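To make the per-agent interface described above concrete, here is a minimal sketch (not part of the original docs) that queries the action space and the availability mask for one agent. It assumes StarCraft II and the SMAC maps are installed, and uses the `3m` scenario with default settings:

```python
from smacv2.env import StarCraft2Env

# Minimal sketch: with no-op and stop plus four move directions and one
# attack action per enemy, 3m (3 enemies) has 2 + 4 + 3 = 9 actions.
env = StarCraft2Env(map_name="3m")
env.reset()

env_info = env.get_env_info()
print("n_actions:", env_info["n_actions"])  # 9 on 3m

# Availability mask for agent 0: 1 where an action is currently legal.
# Index 0 (no-op) is only available to dead agents, so a live agent
# will not see it set.
avail = env.get_avail_agent_actions(0)
print("available action ids:", [i for i, a in enumerate(avail) if a])

env.close()
```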
131 | 132 | Since our micro-scenarios are shorter than actual StarCraft II games, restarting the game after each episode presents a computational bottleneck. To overcome this issue, we make use of the API's debug commands. Specifically, when all units of either army have been killed, we kill all remaining units by sending a debug action. Having no units left launches a trigger programmed with the StarCraft II Editor that re-spawns all units in their original location with full health, thereby restarting the scenario quickly and efficiently. 133 | 134 | Furthermore, to encourage agents to explore interesting micro-strategies themselves, we limit the influence of the StarCraft AI on our agents. Specifically, we disable the automatic unit attack against enemies that attack agents or that are located nearby. 135 | To do so, we make use of new units created with the StarCraft II Editor that are exact copies of existing units with two attributes modified: _Combat: Default Acquire Level_ is set to _Passive_ (default _Offensive_) and _Behaviour: Response_ is set to _No Response_ (default _Acquire_). These fields are only modified for allied units; enemy units are unchanged. 136 | 137 | The sight and shooting range values might differ from the built-in _sight_ or _range_ attribute of some StarCraft II units. Our goal is not to master the original full StarCraft game, but rather to benchmark MARL methods for decentralised control. 138 | 139 | 1: the _health_, _shield_ and _unit\_type_ of the unit the agent controls are also included in observations 140 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | )/ 16 | ''' 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from setuptools import setup 6 | 7 | description = """SMACv2 - StarCraft Multi-Agent Challenge 8 | 9 | SMACv2 is an update to Whirl's Starcraft Multi-Agent Challenge, 10 | which is a benchmark for research in the field of cooperative 11 | multi-agent reinforcement learning. SMAC and SMACv2 both focus 12 | on decentralised micromanagement scenarios in StarCraft II, 13 | rather than the full game. 14 | 15 | The accompanying paper, which outlines the motivation for SMACv2 as well as 16 | results using state-of-the-art deep multi-agent reinforcement learning 17 | algorithms, can be found at https://arxiv.org/abs/2212.07489 18 | 19 | Read the README at https://github.com/oxwhirl/smacv2 for more information.
20 | """ 21 | 22 | extras_deps = { 23 | "dev": [ 24 | "pre-commit>=2.0.1", 25 | "black>=19.10b0", 26 | "flake8>=3.7", 27 | "flake8-bugbear>=20.1", 28 | ], 29 | } 30 | 31 | 32 | setup( 33 | name="SMACv2", 34 | version="1.0.0", 35 | description="SMACv2 - StarCraft Multi-Agent Challenge.", 36 | long_description=description, 37 | author="WhiRL", 38 | author_email="benellis@robots.ox.ac.uk", 39 | license="MIT License", 40 | keywords="StarCraft, Multi-Agent Reinforcement Learning", 41 | url="https://github.com/oxwhirl/smacv2", 42 | packages=[ 43 | "smacv2", 44 | "smacv2.env", 45 | "smacv2.env.starcraft2", 46 | "smacv2.env.starcraft2.maps", 47 | "smacv2.env.pettingzoo", 48 | "smacv2.bin", 49 | "smacv2.examples", 50 | "smacv2.examples.rllib", 51 | "smacv2.examples.pettingzoo", 52 | ], 53 | extras_require=extras_deps, 54 | install_requires=[ 55 | "pysc2>=3.0.0", 56 | "protobuf<3.21", 57 | "s2clientprotocol>=4.10.1.75800.0", 58 | "absl-py>=0.1.0", 59 | "numpy>=1.10", 60 | "pygame>=2.0.0", 61 | ], 62 | ) 63 | -------------------------------------------------------------------------------- /smacv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/__init__.py -------------------------------------------------------------------------------- /smacv2/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/bin/__init__.py -------------------------------------------------------------------------------- /smacv2/bin/map_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smacv2.env.starcraft2.maps import smac_maps 6 | 7 | from pysc2 import maps as pysc2_maps 8 | 9 | 10 | def main(): 11 | smac_map_registry = smac_maps.get_smac_map_registry() 12 | all_maps = pysc2_maps.get_maps() 13 | print("{:<15} {:7} {:7} {:7}".format("Name", "Agents", "Enemies", "Limit")) 14 | for map_name, map_params in smac_map_registry.items(): 15 | map_class = all_maps[map_name] 16 | if map_class.path: 17 | print( 18 | "{:<15} {:<7} {:<7} {:<7}".format( 19 | map_name, 20 | map_params["n_agents"], 21 | map_params["n_enemies"], 22 | map_params["limit"], 23 | ) 24 | ) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /smacv2/env/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smacv2.env.multiagentenv import MultiAgentEnv 6 | from smacv2.env.starcraft2.starcraft2 import StarCraft2Env 7 | from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper 8 | 9 | __all__ = ["MultiAgentEnv", "StarCraft2Env", "StarCraftCapabilityEnvWrapper"] 10 | -------------------------------------------------------------------------------- /smacv2/env/multiagentenv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | class MultiAgentEnv(object): 7 | def step(self, actions): 8 | 
"""Returns reward, terminated, info.""" 9 | raise NotImplementedError 10 | 11 | def get_obs(self): 12 | """Returns all agent observations in a list.""" 13 | raise NotImplementedError 14 | 15 | def get_obs_agent(self, agent_id): 16 | """Returns observation for agent_id.""" 17 | raise NotImplementedError 18 | 19 | def get_capabilities(self): 20 | """Returns the capabilities of all agents in a list.""" 21 | raise NotImplementedError 22 | 23 | def get_capabilities_agent(self, agent_id): 24 | """Returns the capabilities of a single agent.""" 25 | raise NotImplementedError 26 | 27 | def get_obs_size(self): 28 | """Returns the size of the observation.""" 29 | raise NotImplementedError 30 | 31 | def get_state(self): 32 | """Returns the global state.""" 33 | raise NotImplementedError 34 | 35 | def get_state_size(self): 36 | """Returns the size of the global state.""" 37 | raise NotImplementedError 38 | 39 | def get_cap_size(self): 40 | """Returns the size of the own capabilities of the agent.""" 41 | raise NotImplementedError 42 | 43 | def get_avail_actions(self): 44 | """Returns the available actions of all agents in a list.""" 45 | raise NotImplementedError 46 | 47 | def get_avail_agent_actions(self, agent_id): 48 | """Returns the available actions for agent_id.""" 49 | raise NotImplementedError 50 | 51 | def get_total_actions(self): 52 | """Returns the total number of actions an agent could ever take.""" 53 | raise NotImplementedError 54 | 55 | def reset(self): 56 | """Returns initial observations and states.""" 57 | raise NotImplementedError 58 | 59 | def render(self): 60 | raise NotImplementedError 61 | 62 | def close(self): 63 | raise NotImplementedError 64 | 65 | def seed(self): 66 | raise NotImplementedError 67 | 68 | def save_replay(self): 69 | """Save a replay.""" 70 | raise NotImplementedError 71 | 72 | def get_env_info(self): 73 | env_info = { 74 | "state_shape": self.get_state_size(), 75 | "obs_shape": self.get_obs_size(), 76 | "cap_shape": self.get_cap_size(), 77 | "n_actions": self.get_total_actions(), 78 | "n_agents": self.n_agents, 79 | "episode_limit": self.episode_limit, 80 | } 81 | return env_info 82 | -------------------------------------------------------------------------------- /smacv2/env/pettingzoo/StarCraft2PZEnv.py: -------------------------------------------------------------------------------- 1 | from smacv2.env import StarCraft2Env 2 | from gym.utils import EzPickle 3 | from gym.utils import seeding 4 | from gym import spaces 5 | from pettingzoo.utils.env import ParallelEnv 6 | from pettingzoo.utils.conversions import from_parallel_wrapper 7 | from pettingzoo.utils import wrappers 8 | import numpy as np 9 | 10 | max_cycles_default = 1000 11 | 12 | 13 | def parallel_env(max_cycles=max_cycles_default, **smac_args): 14 | return _parallel_env(max_cycles, **smac_args) 15 | 16 | 17 | def raw_env(max_cycles=max_cycles_default, **smac_args): 18 | return from_parallel_wrapper(parallel_env(max_cycles, **smac_args)) 19 | 20 | 21 | def make_env(raw_env): 22 | def env_fn(**kwargs): 23 | env = raw_env(**kwargs) 24 | # env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) 25 | env = wrappers.AssertOutOfBoundsWrapper(env) 26 | env = wrappers.OrderEnforcingWrapper(env) 27 | return env 28 | 29 | return env_fn 30 | 31 | 32 | class smac_parallel_env(ParallelEnv): 33 | def __init__(self, env, max_cycles): 34 | self.max_cycles = max_cycles 35 | self.env = env 36 | self.env.reset() 37 | self.reset_flag = 0 38 | self.agents, self.action_spaces = self._init_agents() 39 | 
self.possible_agents = self.agents[:] 40 | 41 | observation_size = env.get_obs_size() 42 | self.observation_spaces = { 43 | name: spaces.Dict( 44 | { 45 | "observation": spaces.Box( 46 | low=-1, 47 | high=1, 48 | shape=(observation_size,), 49 | dtype="float32", 50 | ), 51 | "action_mask": spaces.Box( 52 | low=0, 53 | high=1, 54 | shape=(self.action_spaces[name].n,), 55 | dtype=np.int8, 56 | ), 57 | } 58 | ) 59 | for name in self.agents 60 | } 61 | self._reward = 0 62 | 63 | def _init_agents(self): 64 | last_type = "" 65 | agents = [] 66 | action_spaces = {} 67 | self.agents_id = {} 68 | i = 0 69 | for agent_id, agent_info in self.env.agents.items(): 70 | unit_action_space = spaces.Discrete( 71 | self.env.get_total_actions() - 1 72 | ) # no-op in dead units is not an action 73 | if agent_info.unit_type == self.env.marine_id: 74 | agent_type = "marine" 75 | elif agent_info.unit_type == self.env.marauder_id: 76 | agent_type = "marauder" 77 | elif agent_info.unit_type == self.env.medivac_id: 78 | agent_type = "medivac" 79 | elif agent_info.unit_type == self.env.hydralisk_id: 80 | agent_type = "hydralisk" 81 | elif agent_info.unit_type == self.env.zergling_id: 82 | agent_type = "zergling" 83 | elif agent_info.unit_type == self.env.baneling_id: 84 | agent_type = "baneling" 85 | elif agent_info.unit_type == self.env.stalker_id: 86 | agent_type = "stalker" 87 | elif agent_info.unit_type == self.env.colossus_id: 88 | agent_type = "colossus" 89 | elif agent_info.unit_type == self.env.zealot_id: 90 | agent_type = "zealot" 91 | else: 92 | # agent_type is unset in this branch, so report the raw unit type 93 | raise AssertionError( 94 | f"agent type {agent_info.unit_type} not supported" 95 | ) 96 | 97 | if agent_type == last_type: 98 | i += 1 99 | else: 100 | i = 0 101 | 102 | agents.append(f"{agent_type}_{i}") 103 | self.agents_id[agents[-1]] = agent_id 104 | action_spaces[agents[-1]] = unit_action_space 105 | last_type = agent_type 106 | 107 | return agents, action_spaces 108 | 109 | def seed(self, seed=None): 110 | if seed is None: 111 | self.env._seed = seeding.create_seed(seed, max_bytes=4) 112 | else: 113 | self.env._seed = seed 114 | self.env.full_restart() 115 | 116 | def render(self, mode="human"): 117 | self.env.render(mode) 118 | 119 | def close(self): 120 | self.env.close() 121 | 122 | def reset(self): 123 | self.env._episode_count = 1 124 | self.env.reset() 125 | 126 | self.agents = self.possible_agents[:] 127 | self.frames = 0 128 | self.all_dones = {agent: False for agent in self.possible_agents} 129 | return self._observe_all() 130 | 131 | def get_agent_smac_id(self, agent): 132 | return self.agents_id[agent] 133 | 134 | def _all_rewards(self, reward): 135 | all_rewards = [reward] * len(self.agents) 136 | return { 137 | agent: reward for agent, reward in zip(self.agents, all_rewards) 138 | } 139 | 140 | def _observe_all(self): 141 | all_obs = [] 142 | for agent in self.agents: 143 | agent_id = self.get_agent_smac_id(agent) 144 | obs = self.env.get_obs_agent(agent_id) 145 | action_mask = self.env.get_avail_agent_actions(agent_id) 146 | action_mask = action_mask[1:] 147 | action_mask = np.array(action_mask).astype(np.int8) 148 | obs = np.asarray(obs, dtype=np.float32) 149 | all_obs.append({"observation": obs, "action_mask": action_mask}) 150 | return {agent: obs for agent, obs in zip(self.agents, all_obs)} 151 | 152 | def _all_dones(self, step_done=False): 153 | dones = [True] * len(self.agents) 154 | if not step_done: 155 | for i, agent in enumerate(self.agents): 156 | agent_done = False 157 | agent_id = self.get_agent_smac_id(agent) 158 | agent_info =
self.env.get_unit_by_id(agent_id) 156 | if agent_info.health == 0: 157 | agent_done = True 158 | dones[i] = agent_done 159 | return {agent: bool(done) for agent, done in zip(self.agents, dones)} 160 | 161 | def step(self, all_actions): 162 | action_list = [0] * self.env.n_agents 163 | for agent in self.agents: 164 | agent_id = self.get_agent_smac_id(agent) 165 | if agent in all_actions: 166 | if all_actions[agent] is None: 167 | action_list[agent_id] = 0 168 | else: 169 | action_list[agent_id] = all_actions[agent] + 1 170 | self._reward, terminated, smac_info = self.env.step(action_list) 171 | self.frames += 1 172 | done = terminated or self.frames >= self.max_cycles 173 | 174 | all_infos = {agent: {} for agent in self.agents} 175 | # all_infos.update(smac_info) 176 | all_dones = self._all_dones(done) 177 | all_rewards = self._all_rewards(self._reward) 178 | all_observes = self._observe_all() 179 | 180 | self.agents = [agent for agent in self.agents if not all_dones[agent]] 181 | 182 | return all_observes, all_rewards, all_dones, all_infos 183 | 184 | def __del__(self): 185 | self.env.close() 186 | 187 | 188 | env = make_env(raw_env) 189 | 190 | 191 | class _parallel_env(smac_parallel_env, EzPickle): 192 | metadata = {"render.modes": ["human"], "name": "sc2"} 193 | 194 | def __init__(self, max_cycles, **smac_args): 195 | EzPickle.__init__(self, max_cycles, **smac_args) 196 | env = StarCraft2Env(**smac_args) 197 | super().__init__(env, max_cycles) 198 | -------------------------------------------------------------------------------- /smacv2/env/pettingzoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/pettingzoo/__init__.py -------------------------------------------------------------------------------- /smacv2/env/pettingzoo/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/pettingzoo/test/__init__.py -------------------------------------------------------------------------------- /smacv2/env/pettingzoo/test/all_test.py: -------------------------------------------------------------------------------- 1 | from smacv2.env.starcraft2.maps import smac_maps 2 | from pysc2 import maps as pysc2_maps 3 | from smacv2.env.pettingzoo import StarCraft2PZEnv as sc2 4 | import pytest 5 | from pettingzoo import test 6 | import pickle 7 | 8 | smac_map_registry = smac_maps.get_smac_map_registry() 9 | all_maps = pysc2_maps.get_maps() 10 | map_names = [] 11 | for map_name in smac_map_registry.keys(): 12 | map_class = all_maps[map_name] 13 | if map_class.path: 14 | map_names.append(map_name) 15 | 16 | 17 | @pytest.mark.parametrize(("map_name"), map_names) 18 | def test_env(map_name): 19 | env = sc2.env(map_name=map_name) 20 | test.api_test(env) 21 | # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to 22 | # illegal actions test.seed_test(sc2.env, 50) # not required, sc2 env only 23 | # allows reseeding at initialization 24 | test.render_test(env) 25 | 26 | recreated_env = pickle.loads(pickle.dumps(env)) 27 | test.api_test(recreated_env) 28 | -------------------------------------------------------------------------------- /smacv2/env/pettingzoo/test/smac_pettingzoo_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import 
inspect 4 | from pettingzoo import test 5 | from smacv2.env.pettingzoo import StarCraft2PZEnv as sc2 6 | import pickle 7 | 8 | current_dir = os.path.dirname( 9 | os.path.abspath(inspect.getfile(inspect.currentframe())) 10 | ) 11 | parent_dir = os.path.dirname(current_dir) 12 | sys.path.insert(0, parent_dir) 13 | 14 | 15 | if __name__ == "__main__": 16 | env = sc2.env(map_name="corridor") 17 | test.api_test(env) 18 | # test.parallel_api_test(sc2.parallel_env())  # skipped: fails due to illegal actions 19 | # test.seed_test(sc2.env, 50)  # not required: the SC2 env 20 | # only allows reseeding at initialization 21 | 22 | recreated_env = pickle.loads(pickle.dumps(env)) 23 | test.api_test(recreated_env) 24 | -------------------------------------------------------------------------------- /smacv2/env/starcraft2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import flags 6 | 7 | FLAGS = flags.FLAGS 8 | FLAGS(["main.py"]) 9 | -------------------------------------------------------------------------------- /smacv2/env/starcraft2/distributions.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from copy import deepcopy 3 | from typing import Any, Dict 4 | from itertools import combinations_with_replacement 5 | from random import choice, shuffle 6 | from math import inf 7 | from numpy.random import default_rng 8 | import numpy as np 9 | 10 | 11 | class Distribution(ABC): 12 | @abstractmethod 13 | def generate(self) -> Dict[str, Any]: 14 | pass 15 | 16 | @property 17 | @abstractmethod 18 | def n_tasks(self) -> int: 19 | pass 20 | 21 | 22 | DISTRIBUTION_MAP = {} 23 | 24 | 25 | def get_distribution(key): 26 | return DISTRIBUTION_MAP[key] 27 | 28 | 29 | def register_distribution(key, cls): 30 | DISTRIBUTION_MAP[key] = cls 31 | 32 | 33 | class FixedDistribution(Distribution): 34 | """A generic distribution that draws from a fixed list. 35 | May operate in test mode, where items are drawn sequentially, 36 | or train mode, where items are drawn uniformly at random. Example uses of this 37 | are for team generation or per-agent accuracy generation in SMAC by 38 | drawing from separate fixed lists at test and train time. 39 | """ 40 | 41 | def __init__(self, config): 42 | """ 43 | Args: 44 | config (dict): Must contain `env_key`, `test_mode` and `items` 45 | entries. `env_key` is the key to pass to the environment so that it 46 | recognises what to do with the list. `test_mode` controls the sampling 47 | behaviour (sequential if true, uniform at random if false), `items` 48 | is the list of items (team configurations/accuracies etc.) to sample from. 
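Example (illustrative values, not taken from a shipped config):

    config = {
        "env_key": "team_gen",
        "test_mode": False,
        "items": [["marine", "marine", "marauder"]],
    }
    FixedDistribution(config).generate()
    # e.g. {"team_gen": {"item": ["marine", "marauder", "marine"], "id": 0}}

Note that the drawn item is shuffled in place before being returned.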
49 | """ 50 | self.config = config 51 | self.env_key = config["env_key"] 52 | self.test_mode = config["test_mode"] 53 | self.teams = config["items"] 54 | self.index = 0 55 | 56 | def generate(self) -> Dict[str, Dict[str, Any]]: 57 | """Returns: 58 | Dict: Returns a dict of the form 59 | {self.env_key: {"item": , "id": }} 60 | """ 61 | if self.test_mode: 62 | team = self.teams[self.index] 63 | team_id = self.index 64 | self.index = (self.index + 1) % len(self.teams) 65 | shuffle(team) 66 | return {self.env_key: {"item": team, "id": team_id}} 67 | else: 68 | team = choice(self.teams) 69 | team_id = self.teams.index(team) 70 | shuffle(team) 71 | return {self.env_key: {"item": team, "id": team_id}} 72 | 73 | @property 74 | def n_tasks(self): 75 | return len(self.teams) 76 | 77 | 78 | register_distribution("fixed", FixedDistribution) 79 | 80 | 81 | class AllTeamsDistribution(Distribution): 82 | def __init__(self, config): 83 | self.config = config 84 | self.units = config["unit_types"] 85 | self.n_units = config["n_units"] 86 | self.exceptions = config.get("exception_unit_types", []) 87 | self.env_key = config["env_key"] 88 | self.combinations = list( 89 | combinations_with_replacement(self.units, self.n_units) 90 | ) 91 | 92 | def generate(self) -> Dict[str, Dict[str, Any]]: 93 | team = [] 94 | while not team or all(member in self.exceptions for member in team): 95 | team = list(choice(self.combinations)) 96 | team_id = self.combinations.index(tuple(team)) 97 | shuffle(team) 98 | return { 99 | self.env_key: { 100 | "ally_team": team, 101 | "enemy_team": team, 102 | "id": team_id, 103 | } 104 | } 105 | 106 | @property 107 | def n_tasks(self): 108 | # TODO adjust so that this can handle exceptions 109 | assert not self.exceptions 110 | return len(self.combinations) 111 | 112 | 113 | register_distribution("all_teams", AllTeamsDistribution) 114 | 115 | 116 | class WeightedTeamsDistribution(Distribution): 117 | def __init__(self, config): 118 | self.config = config 119 | self.units = np.array(config["unit_types"]) 120 | self.n_units = config["n_units"] 121 | self.n_enemies = config["n_enemies"] 122 | # assert ( 123 | # self.n_enemies >= self.n_units 124 | # ), "Only handle larger number of enemies than allies" 125 | self.weights = np.array(config["weights"]) 126 | # unit types that cannot make up the whole team 127 | self.exceptions = config.get("exception_unit_types", set()) 128 | self.rng = default_rng() 129 | self.env_key = config["env_key"] 130 | 131 | def _gen_team(self, n_units: int, use_exceptions: bool): 132 | team = [] 133 | while not team or ( 134 | all(member in self.exceptions for member in team) 135 | and use_exceptions 136 | ): 137 | team = list( 138 | self.rng.choice(self.units, size=(n_units,), p=self.weights) 139 | ) 140 | shuffle(team) 141 | return team 142 | 143 | def generate(self) -> Dict[str, Dict[str, Any]]: 144 | team = self._gen_team(self.n_units, use_exceptions=True) 145 | enemy_team = team.copy() 146 | if self.n_enemies > self.n_units: 147 | extra_enemies = self._gen_team( 148 | self.n_enemies - self.n_units, use_exceptions=True 149 | ) 150 | enemy_team.extend(extra_enemies) 151 | elif self.n_enemies < self.n_units: 152 | enemy_team = enemy_team[:self.n_enemies] 153 | 154 | return { 155 | self.env_key: { 156 | "ally_team": team, 157 | "enemy_team": enemy_team, 158 | "id": 0, 159 | } 160 | } 161 | 162 | @property 163 | def n_tasks(self): 164 | return inf 165 | 166 | 167 | register_distribution("weighted_teams", WeightedTeamsDistribution) 168 | 169 | 170 | class 
PerAgentUniformDistribution(Distribution): 171 | """A generic distribution that generates a per-agent vector of values drawn 172 | from a uniform distribution over a specified range. 173 | """ 174 | 175 | def __init__(self, config): 176 | self.config = config 177 | self.lower_bound = config["lower_bound"] 178 | self.upper_bound = config["upper_bound"] 179 | self.env_key = config["env_key"] 180 | self.n_units = config["n_units"] 181 | self.rng = default_rng() 182 | 183 | def generate(self) -> Dict[str, Dict[str, Any]]: 184 | probs = self.rng.uniform( 185 | low=self.lower_bound, 186 | high=self.upper_bound, 187 | size=(self.n_units, len(self.lower_bound)), 188 | ) 189 | return {self.env_key: {"item": probs, "id": 0}} 190 | 191 | @property 192 | def n_tasks(self): 193 | return inf 194 | 195 | 196 | register_distribution("per_agent_uniform", PerAgentUniformDistribution) 197 | 198 | 199 | class MaskDistribution(Distribution): 200 | def __init__(self, config: Dict[str, Any]): 201 | self.config = config 202 | self.mask_probability = config["mask_probability"] 203 | self.n_units = config["n_units"] 204 | self.n_enemies = config["n_enemies"] 205 | self.rng = default_rng() 206 | 207 | def generate(self) -> Dict[str, Dict[str, Any]]: 208 | mask = self.rng.choice( 209 | [0, 1], 210 | size=(self.n_units, self.n_enemies), 211 | p=[ 212 | self.mask_probability, 213 | 1.0 - self.mask_probability, 214 | ], 215 | ) 216 | return {"enemy_mask": {"item": mask, "id": 0}} 217 | 218 | @property 219 | def n_tasks(self): 220 | return inf 221 | 222 | 223 | register_distribution("mask", MaskDistribution) 224 | 225 | 226 | class ReflectPositionDistribution(Distribution): 227 | """Distribution that generates ally and enemy 228 | positions. Ally positions are drawn uniformly at 229 | random on the left half of the map and reflected in the vertical line 230 | half-way across the map to give the enemy positions. 231 | Handles unequal team sizes: surplus enemies are placed uniformly at random on the enemy half, and surplus ally reflections are dropped. 
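For example, with map_x = 32, an ally drawn at (3, 10) yields an enemy at (29, 10). The ally x upper bound is map_x / 2 - 1, so the two teams always start at least two units apart.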
232 | """ 233 | 234 | def __init__(self, config): 235 | self.config = config 236 | self.n_units = config["n_units"] 237 | self.n_enemies = config["n_enemies"] 238 | # assert ( 239 | # self.n_enemies >= self.n_units 240 | # ), "Number of enemies must be >= number of units" 241 | self.map_x = config["map_x"] 242 | self.map_y = config["map_y"] 243 | config_copy = deepcopy(config) 244 | config_copy["env_key"] = "ally_start_positions" 245 | config_copy["lower_bound"] = (0, 0) 246 | # Subtract one from the ally x upper bound because SC2 misbehaves 247 | # when ally and enemy units spawn on top of one another; 248 | # the -1 leaves a 'buffer zone' of width 2 around the midline 249 | config_copy["upper_bound"] = (self.map_x / 2 - 1, self.map_y) 250 | self.pos_generator = PerAgentUniformDistribution(config_copy) 251 | if self.n_enemies > self.n_units: 252 | enemy_config_copy = deepcopy(config) 253 | enemy_config_copy["env_key"] = "enemy_start_positions" 254 | enemy_config_copy["lower_bound"] = (self.map_x / 2, 0) 255 | enemy_config_copy["upper_bound"] = (self.map_x, self.map_y) 256 | enemy_config_copy["n_units"] = self.n_enemies - self.n_units 257 | self.enemy_pos_generator = PerAgentUniformDistribution( 258 | enemy_config_copy 259 | ) 260 | 261 | def generate(self) -> Dict[str, Dict[str, Any]]: 262 | ally_positions_dict = self.pos_generator.generate() 263 | ally_positions = ally_positions_dict["ally_start_positions"]["item"] 264 | enemy_positions = np.zeros((self.n_enemies, 2)) 265 | if self.n_enemies >= self.n_units: 266 | enemy_positions[: self.n_units, 0] = self.map_x - ally_positions[:, 0] 267 | enemy_positions[: self.n_units, 1] = ally_positions[:, 1] 268 | if self.n_enemies > self.n_units: 269 | gen_enemy_positions = self.enemy_pos_generator.generate() 270 | gen_enemy_positions = gen_enemy_positions["enemy_start_positions"][ 271 | "item" 272 | ] 273 | enemy_positions[self.n_units:, :] = gen_enemy_positions 274 | else: 275 | enemy_positions[:, 0] = self.map_x - ally_positions[: self.n_enemies, 0] 276 | enemy_positions[:, 1] = ally_positions[: self.n_enemies, 1] 277 | return { 278 | "ally_start_positions": {"item": ally_positions, "id": 0}, 279 | "enemy_start_positions": {"item": enemy_positions, "id": 0}, 280 | } 281 | 282 | @property 283 | def n_tasks(self) -> int: 284 | return inf 285 | 286 | 287 | register_distribution("reflect_position", ReflectPositionDistribution) 288 | 289 | 290 | class SurroundedPositionDistribution(Distribution): 291 | """Distribution that places all ally units at 292 | the centre of the map, and then scatters the enemies 293 | in randomly sized groups along the four diagonal directions at a 294 | random distance. 
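The enemies are split into between one and four groups; each group is assigned a distinct diagonal and a shared interpolation parameter t, placing it between the map corner (t = 0) and a point just off the map centre (t = 1).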
295 | """ 296 | 297 | def __init__(self, config): 298 | self.config = config 299 | self.n_units = config["n_units"] 300 | self.n_enemies = config["n_enemies"] 301 | self.map_x = config["map_x"] 302 | self.map_y = config["map_y"] 303 | self.rng = default_rng() 304 | 305 | def generate(self) -> Dict[str, Dict[str, Any]]: 306 | # need multiple centre points because SC2 does not cope with 307 | # spawning ally and enemy units on top of one another in some 308 | # cases 309 | offset = 2 310 | centre_point = np.array([self.map_x / 2, self.map_y / 2]) 311 | diagonal_to_centre_point = { 312 | 0: np.array([self.map_x / 2 - offset, self.map_y / 2 - offset]), 313 | 1: np.array([self.map_x / 2 - offset, self.map_y / 2 + offset]), 314 | 2: np.array([self.map_x / 2 + offset, self.map_y / 2 + offset]), 315 | 3: np.array([self.map_x / 2 + offset, self.map_y / 2 - offset]), 316 | } 317 | ally_position = np.tile(centre_point, (self.n_units, 1)) 318 | enemy_position = np.zeros((self.n_enemies, 2)) 319 | # decide on the number of groups (between 1 and 4) 320 | n_groups = self.rng.integers(1, 5) 321 | # generate the number of enemies in each group 322 | group_membership = self.rng.multinomial( 323 | self.n_enemies, np.ones(n_groups) / n_groups 324 | ) 325 | # decide on the distance along the diagonal for each group 326 | group_position = self.rng.uniform(size=(n_groups,)) 327 | group_diagonals = self.rng.choice( 328 | np.array(range(4)), size=(n_groups,), replace=False 329 | ) 330 | 331 | diagonal_to_point_map = { 332 | 0: np.array([0, 0]), 333 | 1: np.array([0, self.map_y]), 334 | 2: np.array([self.map_x, self.map_y]), 335 | 3: np.array([self.map_x, 0]), 336 | } 337 | unit_index = 0 338 | for i in range(n_groups): 339 | t = group_position[i] 340 | enemy_position[ 341 | unit_index : unit_index + group_membership[i], : 342 | ] = diagonal_to_centre_point[ 343 | group_diagonals[i] 344 | ] * t + diagonal_to_point_map[ 345 | group_diagonals[i] 346 | ] * ( 347 | 1 - t 348 | ) 349 | unit_index += group_membership[i] 350 | 351 | return { 352 | "ally_start_positions": {"item": ally_position, "id": 0}, 353 | "enemy_start_positions": {"item": enemy_position, "id": 0}, 354 | } 355 | 356 | @property 357 | def n_tasks(self): 358 | return inf 359 | 360 | 361 | register_distribution("surrounded", SurroundedPositionDistribution) 362 | 363 | # If composing distributions like this becomes common, it should be 364 | # replaced with a more general composition mechanism 365 | class SurroundedAndReflectPositionDistribution(Distribution): 366 | def __init__(self, config): 367 | self.p_threshold = config["p"]  # probability of sampling the reflect distribution 368 | self.surrounded_distribution = SurroundedPositionDistribution(config) 369 | self.reflect_distribution = ReflectPositionDistribution(config) 370 | self.rng = default_rng() 371 | 372 | def generate(self) -> Dict[str, Dict[str, Any]]: 373 | p = self.rng.random() 374 | if p > self.p_threshold: 375 | return self.surrounded_distribution.generate() 376 | else: 377 | return self.reflect_distribution.generate() 378 | 379 | @property 380 | def n_tasks(self): 381 | return inf 382 | 383 | 384 | register_distribution( 385 | "surrounded_and_reflect", SurroundedAndReflectPositionDistribution 386 | ) 387 | -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/10gen_empty.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/10gen_empty.SC2Map 
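A note on the distributions module above: every distribution is driven the same way. A class is looked up in the registry by its dist_type key, constructed with a config dict, and asked to generate() once per episode reset. A minimal sketch of that flow (the config values here are illustrative, not taken from a shipped config):

    from smacv2.env.starcraft2.distributions import get_distribution

    config = {
        "n_units": 5,
        "n_enemies": 5,
        "map_x": 32,
        "map_y": 32,
        "p": 0.5,  # probability of the reflect case
    }
    dist = get_distribution("surrounded_and_reflect")(config)
    positions = dist.generate()
    # positions["ally_start_positions"]["item"] -> (n_units, 2) array
    # positions["enemy_start_positions"]["item"] -> (n_enemies, 2) array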
-------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/10gen_protoss.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/10gen_protoss.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/10gen_terran.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/10gen_terran.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/10gen_zerg.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/10gen_zerg.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/25m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/25m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/32x32_flat.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/32x32_flat.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/32x32_flat_test.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/32x32_flat_test.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/32x32_small.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/32x32_small.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/3m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/3m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/8m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/8m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/SMAC_Maps.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/SMAC_Maps.zip -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map -------------------------------------------------------------------------------- 
/smacv2/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smacv2.env.starcraft2.maps import smac_maps 6 | 7 | 8 | def get_map_params(map_name): 9 | map_param_registry = smac_maps.get_smac_map_registry() 10 | return map_param_registry[map_name] 11 | -------------------------------------------------------------------------------- /smacv2/env/starcraft2/maps/smac_maps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.maps import lib 6 | 7 | 8 | class SMACMap(lib.Map): 9 | directory = "SMAC_Maps" 10 | download = "https://github.com/oxwhirl/smac#smac-maps" 11 | players = 2 12 | step_mul = 8 13 | game_steps_per_episode = 0 14 | 15 | 16 | map_param_registry = { 17 | "10gen_terran": { 18 | "n_agents": 10, 19 | "n_enemies": 10, 20 | "limit": 200, 21 | "a_race": "T", 22 | "b_race": "T", 23 | "unit_type_bits": 3, 24 | "map_type": "terran_gen", 25 | "map_name": "32x32_flat", 26 | }, 27 | "10gen_zerg": { 28 | "n_agents": 10, 29 | "n_enemies": 10, 30 | "limit": 200, 31 | "a_race": "Z", 32 | "b_race": "Z", 33 | "unit_type_bits": 3, 34 | "map_type": "zerg_gen", 35 | "map_name": "32x32_flat", 36 | }, 37 | "10gen_protoss": { 38 | "n_agents": 10, 39 | "n_enemies": 10, 40 | "limit": 200, 41 | "a_race": "P", 42 | "b_race": "P", 43 | "unit_type_bits": 3, 44 | "map_type": "protoss_gen", 45 | "map_name": "32x32_flat", 46 | }, 47 | "3m": { 48 | "n_agents": 3, 49 | "n_enemies": 3, 50 | "limit": 60, 51 | "a_race": "T", 52 | "b_race": "T", 53 | "unit_type_bits": 0, 54 | "map_type": "marines", 55 | "map_name": "3m", 56 | }, 57 | "8m": { 58 | "n_agents": 8, 59 | "n_enemies": 8, 60 | "limit": 120, 61 | "a_race": "T", 62 | "b_race": "T", 63 | "unit_type_bits": 0, 64 | "map_type": "marines", 65 | "map_name": "8m", 66 | }, 67 | "25m": { 68 | "n_agents": 25, 69 | "n_enemies": 25, 70 | "limit": 150, 71 | "a_race": "T", 72 | "b_race": "T", 73 | "unit_type_bits": 0, 74 | "map_type": "marines", 75 | "map_name": "25m", 76 | }, 77 | "5m_vs_6m": { 78 | "n_agents": 5, 79 | "n_enemies": 6, 80 | "limit": 70, 81 | "a_race": "T", 82 | "b_race": "T", 83 | "unit_type_bits": 0, 84 | "map_type": "marines", 85 | "map_name": "5m_vs_6m", 86 | }, 87 | "8m_vs_9m": { 88 | "n_agents": 8, 89 | "n_enemies": 9, 90 | "limit": 120, 91 | "a_race": "T", 92 | "b_race": "T", 93 | "unit_type_bits": 0, 94 | "map_type": "marines", 95 | "map_name": "8m_vs_9m", 96 | }, 97 | "10m_vs_11m": { 98 | "n_agents": 10, 99 | "n_enemies": 11, 100 | "limit": 150, 101 | "a_race": "T", 102 | "b_race": "T", 103 | "unit_type_bits": 0, 104 | "map_type": "marines", 105 | "map_name": "10m_vs_11m", 106 | }, 107 | "27m_vs_30m": { 108 | "n_agents": 27, 109 | "n_enemies": 30, 110 | "limit": 180, 111 | "a_race": "T", 112 | "b_race": "T", 113 | "unit_type_bits": 0, 114 | "map_type": "marines", 115 | "map_name": "27m_vs_30m", 116 | }, 117 | "MMM": { 
118 | "n_agents": 10, 119 | "n_enemies": 10, 120 | "limit": 150, 121 | "a_race": "T", 122 | "b_race": "T", 123 | "unit_type_bits": 3, 124 | "map_type": "MMM", 125 | "map_name": "MMM", 126 | }, 127 | "MMM2": { 128 | "n_agents": 10, 129 | "n_enemies": 12, 130 | "limit": 180, 131 | "a_race": "T", 132 | "b_race": "T", 133 | "unit_type_bits": 3, 134 | "map_type": "MMM", 135 | "map_name": "MMM2", 136 | }, 137 | "2s3z": { 138 | "n_agents": 5, 139 | "n_enemies": 5, 140 | "limit": 120, 141 | "a_race": "P", 142 | "b_race": "P", 143 | "unit_type_bits": 2, 144 | "map_type": "stalkers_and_zealots", 145 | "map_name": "2s3z", 146 | }, 147 | "3s5z": { 148 | "n_agents": 8, 149 | "n_enemies": 8, 150 | "limit": 150, 151 | "a_race": "P", 152 | "b_race": "P", 153 | "unit_type_bits": 2, 154 | "map_type": "stalkers_and_zealots", 155 | "map_name": "3s5z", 156 | }, 157 | "3s5z_vs_3s6z": { 158 | "n_agents": 8, 159 | "n_enemies": 9, 160 | "limit": 170, 161 | "a_race": "P", 162 | "b_race": "P", 163 | "unit_type_bits": 2, 164 | "map_type": "stalkers_and_zealots", 165 | "map_name": "3s5z_vs_3s6z", 166 | }, 167 | "3s_vs_3z": { 168 | "n_agents": 3, 169 | "n_enemies": 3, 170 | "limit": 150, 171 | "a_race": "P", 172 | "b_race": "P", 173 | "unit_type_bits": 0, 174 | "map_type": "stalkers", 175 | "map_name": "3s_vs_3z", 176 | }, 177 | "3s_vs_4z": { 178 | "n_agents": 3, 179 | "n_enemies": 4, 180 | "limit": 200, 181 | "a_race": "P", 182 | "b_race": "P", 183 | "unit_type_bits": 0, 184 | "map_type": "stalkers", 185 | "map_name": "3s_vs_4z", 186 | }, 187 | "3s_vs_5z": { 188 | "n_agents": 3, 189 | "n_enemies": 5, 190 | "limit": 250, 191 | "a_race": "P", 192 | "b_race": "P", 193 | "unit_type_bits": 0, 194 | "map_type": "stalkers", 195 | "map_name": "3s_vs_5z", 196 | }, 197 | "1c3s5z": { 198 | "n_agents": 9, 199 | "n_enemies": 9, 200 | "limit": 180, 201 | "a_race": "P", 202 | "b_race": "P", 203 | "unit_type_bits": 3, 204 | "map_type": "colossi_stalkers_zealots", 205 | "map_name": "1c3s5z", 206 | }, 207 | "2m_vs_1z": { 208 | "n_agents": 2, 209 | "n_enemies": 1, 210 | "limit": 150, 211 | "a_race": "T", 212 | "b_race": "P", 213 | "unit_type_bits": 0, 214 | "map_type": "marines", 215 | "map_name": "2m_vs_1z", 216 | }, 217 | "corridor": { 218 | "n_agents": 6, 219 | "n_enemies": 24, 220 | "limit": 400, 221 | "a_race": "P", 222 | "b_race": "Z", 223 | "unit_type_bits": 0, 224 | "map_type": "zealots", 225 | "map_name": "corridor", 226 | }, 227 | "6h_vs_8z": { 228 | "n_agents": 6, 229 | "n_enemies": 8, 230 | "limit": 150, 231 | "a_race": "Z", 232 | "b_race": "P", 233 | "unit_type_bits": 0, 234 | "map_type": "hydralisks", 235 | "map_name": "6h_vs_8z", 236 | }, 237 | "2s_vs_1sc": { 238 | "n_agents": 2, 239 | "n_enemies": 1, 240 | "limit": 300, 241 | "a_race": "P", 242 | "b_race": "Z", 243 | "unit_type_bits": 0, 244 | "map_type": "stalkers", 245 | "map_name": "2s_vs_1sc", 246 | }, 247 | "so_many_baneling": { 248 | "n_agents": 7, 249 | "n_enemies": 32, 250 | "limit": 100, 251 | "a_race": "P", 252 | "b_race": "Z", 253 | "unit_type_bits": 0, 254 | "map_type": "zealots", 255 | "map_name": "so_many_baneling", 256 | }, 257 | "bane_vs_bane": { 258 | "n_agents": 24, 259 | "n_enemies": 24, 260 | "limit": 200, 261 | "a_race": "Z", 262 | "b_race": "Z", 263 | "unit_type_bits": 2, 264 | "map_type": "bane", 265 | "map_name": "bane_vs_bane", 266 | }, 267 | "2c_vs_64zg": { 268 | "n_agents": 2, 269 | "n_enemies": 64, 270 | "limit": 400, 271 | "a_race": "P", 272 | "b_race": "Z", 273 | "unit_type_bits": 0, 274 | "map_type": "colossus", 275 | "map_name": 
"2c_vs_64zg", 276 | }, 277 | } 278 | 279 | 280 | def get_smac_map_registry(): 281 | return map_param_registry 282 | 283 | 284 | for name, map_params in map_param_registry.items(): 285 | globals()[name] = type( 286 | name, (SMACMap,), dict(filename=map_params["map_name"]) 287 | ) 288 | -------------------------------------------------------------------------------- /smacv2/env/starcraft2/render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import subprocess 4 | import platform 5 | from absl import logging 6 | import math 7 | import time 8 | import collections 9 | import os 10 | import pygame 11 | import queue 12 | 13 | from pysc2.lib import colors 14 | from pysc2.lib import point 15 | from pysc2.lib.renderer_human import _Surface 16 | from pysc2.lib import transform 17 | from pysc2.lib import features 18 | 19 | 20 | def clamp(n, smallest, largest): 21 | return max(smallest, min(n, largest)) 22 | 23 | 24 | def _get_desktop_size(): 25 | """Get the desktop size.""" 26 | if platform.system() == "Linux": 27 | try: 28 | xrandr_query = subprocess.check_output(["xrandr", "--query"]) 29 | sizes = re.findall( 30 | r"\bconnected primary (\d+)x(\d+)", str(xrandr_query) 31 | ) 32 | if sizes[0]: 33 | return point.Point(int(sizes[0][0]), int(sizes[0][1])) 34 | except ValueError: 35 | logging.error("Failed to get the resolution from xrandr.") 36 | 37 | # Most general, but doesn't understand multiple monitors. 38 | display_info = pygame.display.Info() 39 | return point.Point(display_info.current_w, display_info.current_h) 40 | 41 | 42 | class StarCraft2Renderer: 43 | def __init__(self, env, mode): 44 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide" 45 | 46 | self.env = env 47 | self.mode = mode 48 | self.obs = None 49 | self._window_scale = 0.75 50 | self.game_info = game_info = self.env._controller.game_info() 51 | self.static_data = self.env._controller.data() 52 | 53 | self._obs_queue = queue.Queue() 54 | self._game_times = collections.deque( 55 | maxlen=100 56 | ) # Avg FPS over 100 frames. 
# pytype: disable=wrong-keyword-args 57 | self._render_times = collections.deque( 58 | maxlen=100 59 | ) # pytype: disable=wrong-keyword-args 60 | self._last_time = time.time() 61 | self._last_game_loop = 0 62 | self._name_lengths = {} 63 | 64 | self._map_size = point.Point.build(game_info.start_raw.map_size) 65 | self._playable = point.Rect( 66 | point.Point.build(game_info.start_raw.playable_area.p0), 67 | point.Point.build(game_info.start_raw.playable_area.p1), 68 | ) 69 | 70 | window_size_px = point.Point( 71 | self.env.window_size[0], self.env.window_size[1] 72 | ) 73 | window_size_px = self._map_size.scale_max_size( 74 | window_size_px * self._window_scale 75 | ).ceil() 76 | self._scale = window_size_px.y // 32 77 | 78 | self.display = pygame.Surface(window_size_px) 79 | 80 | if mode == "human": 81 | self.display = pygame.display.set_mode(window_size_px, 0, 32) 82 | pygame.display.init() 83 | 84 | pygame.display.set_caption("Starcraft Viewer") 85 | pygame.font.init() 86 | self._world_to_world_tl = transform.Linear( 87 | point.Point(1, -1), point.Point(0, self._map_size.y) 88 | ) 89 | self._world_tl_to_screen = transform.Linear(scale=window_size_px / 32) 90 | self.screen_transform = transform.Chain( 91 | self._world_to_world_tl, self._world_tl_to_screen 92 | ) 93 | 94 | surf_loc = point.Rect(point.origin, window_size_px) 95 | sub_surf = self.display.subsurface( 96 | pygame.Rect(surf_loc.tl, surf_loc.size) 97 | ) 98 | self._surf = _Surface( 99 | sub_surf, 100 | None, 101 | surf_loc, 102 | self.screen_transform, 103 | None, 104 | self.draw_screen, 105 | ) 106 | 107 | self._font_small = pygame.font.Font(None, int(self._scale * 0.5)) 108 | self._font_large = pygame.font.Font(None, self._scale) 109 | 110 | def close(self): 111 | pygame.display.quit() 112 | pygame.quit() 113 | 114 | def _get_units(self): 115 | for u in sorted( 116 | self.obs.observation.raw_data.units, 117 | key=lambda u: (u.pos.z, u.owner != 16, -u.radius, u.tag), 118 | ): 119 | yield u, point.Point.build(u.pos) 120 | 121 | def get_unit_name(self, surf, name, radius): 122 | """Get a length limited unit name for drawing units.""" 123 | key = (name, radius) 124 | if key not in self._name_lengths: 125 | max_len = surf.world_to_surf.fwd_dist(radius * 1.6) 126 | for i in range(len(name)): 127 | if self._font_small.size(name[: i + 1])[0] > max_len: 128 | self._name_lengths[key] = name[:i] 129 | break 130 | else: 131 | self._name_lengths[key] = name 132 | return self._name_lengths[key] 133 | 134 | def render(self, mode): 135 | self.obs = self.env._obs 136 | self.score = self.env.reward 137 | self.step = self.env._episode_steps 138 | 139 | now = time.time() 140 | self._game_times.append( 141 | ( 142 | now - self._last_time, 143 | max( 144 | 1, 145 | self.obs.observation.game_loop 146 | - self._last_game_loop, 147 | ), 148 | ) 149 | ) 150 | 151 | if mode == "human": 152 | pygame.event.pump() 153 | 154 | self._surf.draw(self._surf) 155 | 156 | observation = np.array(pygame.surfarray.pixels3d(self.display)) 157 | 158 | if mode == "human": 159 | pygame.display.flip() 160 | 161 | self._last_time = now 162 | self._last_game_loop = self.obs.observation.game_loop 163 | # self._obs_queue.put(self.obs) 164 | return ( 165 | np.transpose(observation, axes=(1, 0, 2)) 166 | if mode == "rgb_array" 167 | else None 168 | ) 169 | 170 | def draw_base_map(self, surf): 171 | """Draw the base map.""" 172 | hmap_feature = features.SCREEN_FEATURES.height_map 173 | hmap = self.env.terrain_height * 255 174 | hmap = hmap.astype(np.uint8) 175 | if 
( 176 | self.env.map_name == "corridor" 177 | or self.env.map_name == "so_many_baneling" 178 | or self.env.map_name == "2s_vs_1sc" 179 | ): 180 | hmap = np.flip(hmap) 181 | else: 182 | hmap = np.rot90(hmap, axes=(1, 0)) 183 | if not hmap.any(): 184 | hmap = hmap + 100 # pylint: disable=g-no-augmented-assignment 185 | hmap_color = hmap_feature.color(hmap) 186 | out = hmap_color * 0.6 187 | 188 | surf.blit_np_array(out) 189 | 190 | def draw_units(self, surf): 191 | """Draw the units.""" 192 | unit_dict = None # Cache the units {tag: unit_proto} for orders. 193 | tau = 2 * math.pi 194 | for u, p in self._get_units(): 195 | fraction_damage = clamp( 196 | (u.health_max - u.health) / (u.health_max or 1), 0, 1 197 | ) 198 | surf.draw_circle( 199 | colors.PLAYER_ABSOLUTE_PALETTE[u.owner], p, u.radius 200 | ) 201 | 202 | if fraction_damage > 0: 203 | surf.draw_circle( 204 | colors.PLAYER_ABSOLUTE_PALETTE[u.owner] // 2, 205 | p, 206 | u.radius * fraction_damage, 207 | ) 208 | surf.draw_circle(colors.black, p, u.radius, thickness=1) 209 | 210 | if self.static_data.unit_stats[u.unit_type].movement_speed > 0: 211 | surf.draw_arc( 212 | colors.white, 213 | p, 214 | u.radius, 215 | u.facing - 0.1, 216 | u.facing + 0.1, 217 | thickness=1, 218 | ) 219 | 220 | def draw_arc_ratio( 221 | color, world_loc, radius, start, end, thickness=1 222 | ): 223 | surf.draw_arc( 224 | color, world_loc, radius, start * tau, end * tau, thickness 225 | ) 226 | 227 | if u.shield and u.shield_max: 228 | draw_arc_ratio( 229 | colors.blue, p, u.radius - 0.05, 0, u.shield / u.shield_max 230 | ) 231 | 232 | if u.energy and u.energy_max: 233 | draw_arc_ratio( 234 | colors.purple * 0.9, 235 | p, 236 | u.radius - 0.1, 237 | 0, 238 | u.energy / u.energy_max, 239 | ) 240 | elif u.orders and 0 < u.orders[0].progress < 1: 241 | draw_arc_ratio( 242 | colors.cyan, p, u.radius - 0.15, 0, u.orders[0].progress 243 | ) 244 | if u.buff_duration_remain and u.buff_duration_max: 245 | draw_arc_ratio( 246 | colors.white, 247 | p, 248 | u.radius - 0.2, 249 | 0, 250 | u.buff_duration_remain / u.buff_duration_max, 251 | ) 252 | if u.attack_upgrade_level: 253 | draw_arc_ratio( 254 | self.upgrade_colors[u.attack_upgrade_level], 255 | p, 256 | u.radius - 0.25, 257 | 0.18, 258 | 0.22, 259 | thickness=3, 260 | ) 261 | if u.armor_upgrade_level: 262 | draw_arc_ratio( 263 | self.upgrade_colors[u.armor_upgrade_level], 264 | p, 265 | u.radius - 0.25, 266 | 0.23, 267 | 0.27, 268 | thickness=3, 269 | ) 270 | if u.shield_upgrade_level: 271 | draw_arc_ratio( 272 | self.upgrade_colors[u.shield_upgrade_level], 273 | p, 274 | u.radius - 0.25, 275 | 0.28, 276 | 0.32, 277 | thickness=3, 278 | ) 279 | 280 | def write_small(loc, s): 281 | surf.write_world(self._font_small, colors.white, loc, str(s)) 282 | 283 | name = self.get_unit_name( 284 | surf, 285 | self.static_data.units.get(u.unit_type, ""), 286 | u.radius, 287 | ) 288 | 289 | if name: 290 | write_small(p, name) 291 | 292 | start_point = p 293 | for o in u.orders: 294 | target_point = None 295 | if o.HasField("target_unit_tag"): 296 | if unit_dict is None: 297 | unit_dict = { 298 | t.tag: t 299 | for t in self.obs.observation.raw_data.units 300 | } 301 | target_unit = unit_dict.get(o.target_unit_tag) 302 | if target_unit: 303 | target_point = point.Point.build(target_unit.pos) 304 | if target_point: 305 | surf.draw_line(colors.cyan, start_point, target_point) 306 | start_point = target_point 307 | else: 308 | break 309 | 310 | def draw_overlay(self, surf): 311 | """Draw the overlay describing resources.""" 312 | 
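# self._game_times holds (wall_clock_delta, game_loop_delta) pairs for
# the most recent renders; the ratio of their sums below is the average
# number of game steps rendered per wall-clock second. SC2 advances
# 22.4 game loops per real-time second, hence the conversion when the
# elapsed time is formatted as minutes:seconds.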
obs = self.obs.observation 313 | times, steps = zip(*self._game_times) 314 | sec = obs.game_loop // 22.4 315 | surf.write_screen( 316 | self._font_large, 317 | colors.green, 318 | (-0.2, 0.2), 319 | "Score: %s, Step: %s, %.1f/s, Time: %d:%02d" 320 | % ( 321 | self.score, 322 | self.step, 323 | sum(steps) / (sum(times) or 1), 324 | sec // 60, 325 | sec % 60, 326 | ), 327 | align="right", 328 | ) 329 | surf.write_screen( 330 | self._font_large, 331 | colors.green * 0.8, 332 | (-0.2, 1.2), 333 | "APM: %d, EPM: %d, FPS: O:%.1f, R:%.1f" 334 | % ( 335 | obs.score.score_details.current_apm, 336 | obs.score.score_details.current_effective_apm, 337 | len(times) / (sum(times) or 1), 338 | len(self._render_times) / (sum(self._render_times) or 1), 339 | ), 340 | align="right", 341 | ) 342 | 343 | def draw_screen(self, surf): 344 | """Draw the screen area.""" 345 | self.draw_base_map(surf) 346 | self.draw_units(surf) 347 | self.draw_overlay(surf) 348 | -------------------------------------------------------------------------------- /smacv2/env/starcraft2/wrapper.py: -------------------------------------------------------------------------------- 1 | from smacv2.env.starcraft2.distributions import get_distribution 2 | from smacv2.env.starcraft2.starcraft2 import StarCraft2Env, CannotResetException 3 | from smacv2.env import MultiAgentEnv 4 | 5 | 6 | class StarCraftCapabilityEnvWrapper(MultiAgentEnv): 7 | def __init__(self, **kwargs): 8 | self.distribution_config = kwargs["capability_config"] 9 | self.env_key_to_distribution_map = {} 10 | self._parse_distribution_config() 11 | self.env = StarCraft2Env(**kwargs) 12 | assert ( 13 | self.distribution_config.keys() 14 | == kwargs["capability_config"].keys() 15 | ), "Must give distribution config and capability config the same keys" 16 | 17 | def _parse_distribution_config(self): 18 | for env_key, config in self.distribution_config.items(): 19 | if env_key == "n_units" or env_key == "n_enemies": 20 | continue 21 | config["env_key"] = env_key 22 | # every distribution also needs the team sizes 23 | config["n_units"] = self.distribution_config["n_units"] 24 | config["n_enemies"] = self.distribution_config["n_enemies"] 25 | distribution = get_distribution(config["dist_type"])(config) 26 | self.env_key_to_distribution_map[env_key] = distribution 27 | 28 | def reset(self): 29 | try: 30 | reset_config = {} 31 | for distribution in self.env_key_to_distribution_map.values(): 32 | reset_config = {**reset_config, **distribution.generate()} 33 | 34 | return self.env.reset(reset_config) 35 | except CannotResetException: 36 | # the sampled config was invalid; sample a new one and retry 37 | return self.reset() 38 | 39 | def __getattr__(self, name): 40 | if hasattr(self.env, name): 41 | return getattr(self.env, name) 42 | else: 43 | raise AttributeError(name) 44 | 45 | def get_obs(self): 46 | return self.env.get_obs() 47 | 48 | def get_obs_feature_names(self): 49 | return self.env.get_obs_feature_names() 50 | 51 | def get_state(self): 52 | return self.env.get_state() 53 | 54 | def get_state_feature_names(self): 55 | return self.env.get_state_feature_names() 56 | 57 | def get_avail_actions(self): 58 | return self.env.get_avail_actions() 59 | 60 | def get_env_info(self): 61 | return self.env.get_env_info() 62 | 63 | def get_obs_size(self): 64 | return self.env.get_obs_size() 65 | 66 | def get_state_size(self): 67 | return self.env.get_state_size() 68 | 69 | def get_total_actions(self): 70 | return self.env.get_total_actions() 71 | 72 | def get_capabilities(self): 73 | return self.env.get_capabilities() 74 | 75 | def get_obs_agent(self, agent_id): 76 | 
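# Pure pass-through: the wrapper only changes behaviour at reset(),
# where the freshly sampled capability config is handed to the env.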
return self.env.get_obs_agent(agent_id) 77 | 78 | def get_avail_agent_actions(self, agent_id): 79 | return self.env.get_avail_agent_actions(agent_id) 80 | 81 | def render(self, mode="human"): 82 | return self.env.render(mode=mode) 83 | 84 | def step(self, actions): 85 | return self.env.step(actions) 86 | 87 | def get_stats(self): 88 | return self.env.get_stats() 89 | 90 | def full_restart(self): 91 | return self.env.full_restart() 92 | 93 | def save_replay(self): 94 | self.env.save_replay() 95 | 96 | def close(self): 97 | return self.env.close() 98 | -------------------------------------------------------------------------------- /smacv2/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/examples/__init__.py -------------------------------------------------------------------------------- /smacv2/examples/configs/sc2_gen_protoss.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_protoss" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | obs_own_pos: True 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "stalker" 38 | - "zealot" 39 | - "colossus" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | observe: True 45 | start_positions: 46 | dist_type: "surrounded_and_reflect" 47 | p: 0.5 48 | map_x: 32 49 | map_y: 32 50 | 51 | # enemy_mask: 52 | # dist_type: "mask" 53 | # mask_probability: 0.5 54 | # n_enemies: 5 55 | state_last_action: True 56 | state_timestep_number: False 57 | step_mul: 8 58 | heuristic_ai: False 59 | # heuristic_rest: False 60 | debug: False 61 | prob_obs_enemy: 1.0 62 | action_mask: True 63 | 64 | test_nepisode: 32 65 | test_interval: 10000 66 | log_interval: 2000 67 | runner_log_interval: 2000 68 | learner_log_interval: 2000 69 | t_max: 10050000 70 | -------------------------------------------------------------------------------- /smacv2/examples/configs/sc2_gen_protoss_epo.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_protoss" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | obs_own_pos: 
True 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "stalker" 38 | - "zealot" 39 | - "colossus" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | observe: True 45 | start_positions: 46 | dist_type: "surrounded_and_reflect" 47 | p: 0.5 48 | map_x: 32 49 | map_y: 32 50 | 51 | # enemy_mask: 52 | # dist_type: "mask" 53 | # mask_probability: 0.5 54 | # n_enemies: 5 55 | state_last_action: True 56 | state_timestep_number: False 57 | step_mul: 8 58 | heuristic_ai: False 59 | # heuristic_rest: False 60 | debug: False 61 | # Most severe partial obs setting: 62 | prob_obs_enemy: 0.0 63 | action_mask: False 64 | 65 | test_nepisode: 32 66 | test_interval: 10000 67 | log_interval: 2000 68 | runner_log_interval: 2000 69 | learner_log_interval: 2000 70 | t_max: 10050000 -------------------------------------------------------------------------------- /smacv2/examples/configs/sc2_gen_terran.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_terran" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | obs_own_pos: True 28 | use_unit_ranges: True 29 | min_attack_range: 2 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "marine" 38 | - "marauder" 39 | - "medivac" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | exception_unit_types: 45 | - "medivac" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | prob_obs_enemy: 1.0 64 | action_mask: True 65 | 66 | test_nepisode: 32 67 | test_interval: 10000 68 | log_interval: 2000 69 | runner_log_interval: 2000 70 | learner_log_interval: 2000 71 | t_max: 10050000 72 | -------------------------------------------------------------------------------- /smacv2/examples/configs/sc2_gen_terran_epo.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_terran" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | obs_own_pos: 
True 28 | use_unit_ranges: True 29 | min_attack_range: 2 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "marine" 38 | - "marauder" 39 | - "medivac" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | exception_unit_types: 45 | - "medivac" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | # Most severe partial obs setting: 64 | prob_obs_enemy: 0.0 65 | action_mask: False 66 | 67 | test_nepisode: 32 68 | test_interval: 10000 69 | log_interval: 2000 70 | runner_log_interval: 2000 71 | learner_log_interval: 2000 72 | t_max: 10050000 73 | -------------------------------------------------------------------------------- /smacv2/examples/configs/sc2_gen_zerg.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_zerg" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | num_fov_actions: 12 30 | obs_own_pos: True 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "zergling" 38 | - "baneling" 39 | - "hydralisk" 40 | weights: 41 | - 0.45 42 | - 0.1 43 | - 0.45 44 | exception_unit_types: 45 | - "baneling" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | prob_obs_enemy: 1.0 64 | action_mask: True 65 | 66 | test_nepisode: 32 67 | test_interval: 10000 68 | log_interval: 2000 69 | runner_log_interval: 2000 70 | learner_log_interval: 2000 71 | t_max: 10050000 72 | -------------------------------------------------------------------------------- /smacv2/examples/configs/sc2_gen_zerg_epo.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_zerg" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | 
reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | num_fov_actions: 12 30 | obs_own_pos: True 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "zergling" 38 | - "baneling" 39 | - "hydralisk" 40 | weights: 41 | - 0.45 42 | - 0.1 43 | - 0.45 44 | exception_unit_types: 45 | - "baneling" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | # Most severe partial obs setting: 64 | prob_obs_enemy: 0.0 65 | action_mask: False 66 | 67 | test_nepisode: 32 68 | test_interval: 10000 69 | log_interval: 2000 70 | runner_log_interval: 2000 71 | learner_log_interval: 2000 72 | t_max: 10050000 -------------------------------------------------------------------------------- /smacv2/examples/pettingzoo/README.rst: -------------------------------------------------------------------------------- 1 | SMAC on PettingZoo 2 | ================== 3 | 4 | This example shows how to run SMAC environments with the PettingZoo multi-agent API. 5 | 6 | Instructions 7 | ------------ 8 | 9 | To get started, first install PettingZoo with ``pip install pettingzoo``. 10 | 11 | The SMAC environment for PettingZoo, ``StarCraft2PZEnv``, can be initialized with two different API templates. 12 | * **AEC**: PettingZoo is based on the *Agent Environment Cycle* (AEC) game model; more information about AEC can be found in the following `paper `_. To create a SMAC environment as an AEC PettingZoo game model, use: :: 13 | 14 | from smacv2.env.pettingzoo import StarCraft2PZEnv 15 | 16 | env = StarCraft2PZEnv.env() 17 | 18 | * **Parallel**: PettingZoo also supports parallel environments, in which all agents act and observe simultaneously. This type of environment can be created as follows: :: 19 | 20 | from smacv2.env.pettingzoo import StarCraft2PZEnv 21 | 22 | env = StarCraft2PZEnv.parallel_env() 23 | 24 | ``pettingzoo_demo.py`` contains an example of a SMAC environment driven through the AEC API. Calling ``env.render()`` draws the current state of the environment in a pygame window, which is useful for debugging. 25 | 26 | | See https://www.pettingzoo.ml/api for more documentation. 27 | -------------------------------------------------------------------------------- /smacv2/examples/pettingzoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/examples/pettingzoo/__init__.py -------------------------------------------------------------------------------- /smacv2/examples/pettingzoo/pettingzoo_demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import numpy as np 7 | from smacv2.env.pettingzoo import StarCraft2PZEnv 8 | 9 | 10 | def main(): 11 | """ 12 | Runs an env object with random actions. 
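Each agent samples uniformly from the actions its action mask
currently allows; once an agent is done it submits a None action,
as the AEC API requires.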
13 | """ 14 | env = StarCraft2PZEnv.env() 15 | episodes = 10 16 | 17 | total_reward = 0 18 | done = False 19 | completed_episodes = 0 20 | 21 | while completed_episodes < episodes: 22 | env.reset() 23 | for agent in env.agent_iter(): 24 | env.render() 25 | 26 | obs, reward, done, _ = env.last() 27 | total_reward += reward 28 | if done: 29 | action = None 30 | elif isinstance(obs, dict) and "action_mask" in obs: 31 | action = random.choice(np.flatnonzero(obs["action_mask"])) 32 | else: 33 | action = env.action_spaces[agent].sample() 34 | env.step(action) 35 | 36 | completed_episodes += 1 37 | 38 | env.close() 39 | 40 | print("Average total reward", total_reward / episodes) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /smacv2/examples/random_agents.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import time 4 | from os import replace 5 | 6 | import numpy as np 7 | from absl import logging 8 | from smacv2.env import StarCraft2Env 9 | from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper 10 | 11 | logging.set_verbosity(logging.DEBUG) 12 | 13 | 14 | def main(): 15 | 16 | distribution_config = { 17 | "n_units": 10, 18 | "n_enemies": 11, 19 | "team_gen": { 20 | "dist_type": "weighted_teams", 21 | "unit_types": ["marine", "marauder", "medivac"], 22 | "weights": [0.45, 0.45, 0.1], 23 | "observe": True, 24 | "exception_unit_types": ["medivac"], 25 | }, 26 | 27 | "start_positions": { 28 | "dist_type": "surrounded_and_reflect", 29 | "p": 0.5, 30 | "map_x": 32, 31 | "map_y": 32, 32 | } 33 | 34 | } 35 | env = StarCraftCapabilityEnvWrapper( 36 | capability_config=distribution_config, 37 | map_name="10gen_terran", 38 | debug=False, 39 | conic_fov=False, 40 | use_unit_ranges=True, 41 | min_attack_range=2, 42 | obs_own_pos=True, 43 | fully_observable=False, 44 | ) 45 | 46 | env_info = env.get_env_info() 47 | 48 | n_actions = env_info["n_actions"] 49 | n_agents = env_info["n_agents"] 50 | cap_size = env_info["cap_shape"] 51 | 52 | n_episodes = 10 53 | print("Training episodes") 54 | env.reset() 55 | for e in range(n_episodes): 56 | env.reset() 57 | terminated = False 58 | episode_reward = 0 59 | state_features = env.get_state_feature_names() 60 | obs_features = env.get_obs_feature_names() 61 | 62 | while not terminated: 63 | obs = env.get_obs() 64 | print(f"Obs size: {obs[0].shape}") 65 | state = env.get_state() 66 | cap = env.get_capabilities() 67 | # env.render() # Uncomment for rendering 68 | 69 | actions = [] 70 | for agent_id in range(n_agents): 71 | avail_actions = env.get_avail_agent_actions(agent_id) 72 | avail_actions_ind = np.nonzero(avail_actions)[0] 73 | action = np.random.choice(avail_actions_ind) 74 | actions.append(action) 75 | 76 | reward, terminated, _ = env.step(actions) 77 | time.sleep(0.15) 78 | episode_reward += reward 79 | 80 | # print("Total reward in episode {} = {}".format(e, episode_reward)) 81 | assert len(state) == len(state_features) 82 | assert len(obs[0]) == len(obs_features) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /smacv2/examples/results/smac2_training_results.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oxwhirl/smacv2/577ab5a2cff2391f8df582da5731ea9cd6adf3c6/smacv2/examples/results/smac2_training_results.pkl -------------------------------------------------------------------------------- /smacv2/examples/rllib/README.rst: -------------------------------------------------------------------------------- 1 | SMAC on RLlib 2 | ============= 3 | 4 | This example shows how to run SMAC environments with the RLlib multi-agent API. 5 | 6 | Instructions 7 | ------------ 8 | 9 | To get started, first install RLlib with ``pip install -U ray[rllib]``. You will also need TensorFlow installed. 10 | 11 | In ``run_ppo.py``, each agent is controlled by an independent PPO policy (the policies share weights). This setup serves as an independent-learning baseline for this task. 12 | 13 | In ``run_qmix.py``, the agents are controlled by the multi-agent QMIX policy. This setup is an example of centralized training with decentralized execution. 14 | 15 | See https://ray.readthedocs.io/en/latest/rllib.html for more documentation. 16 | -------------------------------------------------------------------------------- /smacv2/examples/rllib/__init__.py: -------------------------------------------------------------------------------- 1 | from smacv2.examples.rllib.env import RLlibStarCraft2Env 2 | from smacv2.examples.rllib.model import MaskedActionsModel 3 | 4 | __all__ = ["RLlibStarCraft2Env", "MaskedActionsModel"] 5 | -------------------------------------------------------------------------------- /smacv2/examples/rllib/env.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | import numpy as np 8 | 9 | from gym.spaces import Discrete, Box, Dict 10 | 11 | from ray import rllib 12 | 13 | from smacv2.env import StarCraft2Env 14 | 15 | 16 | class RLlibStarCraft2Env(rllib.MultiAgentEnv): 17 | """Wraps a SMAC StarCraft II env to be compatible with RLlib multi-agent.""" 18 | 19 | def __init__(self, **smac_args): 20 | """Create a new multi-agent StarCraft env compatible with RLlib. 21 | 22 | Arguments: 23 | smac_args (dict): Arguments to pass to the underlying 24 | smacv2.env.StarCraft2Env instance. 25 | 26 | Examples: 27 | >>> from smacv2.examples.rllib import RLlibStarCraft2Env 28 | >>> env = RLlibStarCraft2Env(map_name="8m") 29 | >>> print(env.reset()) 30 | """ 31 | 32 | self._env = StarCraft2Env(**smac_args) 33 | self._ready_agents = [] 34 | self.observation_space = Dict( 35 | { 36 | "obs": Box(-1, 1, shape=(self._env.get_obs_size(),)), 37 | "action_mask": Box( 38 | 0, 1, shape=(self._env.get_total_actions(),) 39 | ), 40 | } 41 | ) 42 | self.action_space = Discrete(self._env.get_total_actions()) 43 | 44 | def reset(self): 45 | """Resets the env and returns observations from ready agents. 46 | 47 | Returns: 48 | obs (dict): New observations for each ready agent. 49 | """ 50 | 51 | obs_list, state_list = self._env.reset() 52 | return_obs = {} 53 | for i, obs in enumerate(obs_list): 54 | return_obs[i] = { 55 | "action_mask": np.array(self._env.get_avail_agent_actions(i)), 56 | "obs": obs, 57 | } 58 | 59 | self._ready_agents = list(range(len(obs_list))) 60 | return return_obs 61 | 62 | def step(self, action_dict): 63 | """Returns observations from ready agents. 64 | 65 | The returns are dicts mapping from agent IDs to values. The 66 | number of agents in the env can vary over time.
67 | 68 | Returns 69 | ------- 70 | obs (dict): New observations for each ready agent. 71 | rewards (dict): Reward values for each ready agent. If the 72 | episode is just started, the value will be None. 73 | dones (dict): Done values for each ready agent. The special key 74 | "__all__" (required) is used to indicate env termination. 75 | infos (dict): Optional info values for each agent id. 76 | """ 77 | 78 | actions = [] 79 | for i in self._ready_agents: 80 | if i not in action_dict: 81 | raise ValueError( 82 | "You must supply an action for agent: {}".format(i) 83 | ) 84 | actions.append(action_dict[i]) 85 | 86 | if len(actions) != len(self._ready_agents): 87 | raise ValueError( 88 | "Unexpected number of actions: {}".format( 89 | action_dict, 90 | ) 91 | ) 92 | 93 | rew, done, info = self._env.step(actions) 94 | obs_list = self._env.get_obs() 95 | return_obs = {} 96 | for i, obs in enumerate(obs_list): 97 | return_obs[i] = { 98 | "action_mask": np.array(self._env.get_avail_agent_actions(i)), 99 | "obs": obs, 100 | } 101 | rews = {i: rew / len(obs_list) for i in range(len(obs_list))}  # split the shared team reward evenly across agents 102 | dones = {i: done for i in range(len(obs_list))} 103 | dones["__all__"] = done 104 | infos = {i: info for i in range(len(obs_list))} 105 | 106 | self._ready_agents = list(range(len(obs_list))) 107 | return return_obs, rews, dones, infos 108 | 109 | def close(self): 110 | """Close the environment.""" 111 | self._env.close() 112 | 113 | def seed(self, seed): 114 | random.seed(seed)  # seed both the Python and NumPy RNGs 115 | np.random.seed(seed) 116 | -------------------------------------------------------------------------------- /smacv2/examples/rllib/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from ray.rllib.models import Model 8 | from ray.rllib.models.tf.misc import normc_initializer 9 | 10 | 11 | class MaskedActionsModel(Model): 12 | """Custom RLlib model that emits -inf logits for invalid actions. 13 | 14 | This is used to handle the variable-length StarCraft action space.
15 | """ 16 | 17 | def _build_layers_v2(self, input_dict, num_outputs, options): 18 | action_mask = input_dict["obs"]["action_mask"] 19 | if num_outputs != action_mask.shape[1].value: 20 | raise ValueError( 21 | "This model assumes num outputs is equal to max avail actions", 22 | num_outputs, 23 | action_mask, 24 | ) 25 | 26 | # Standard fully connected network 27 | last_layer = input_dict["obs"]["obs"] 28 | hiddens = options.get("fcnet_hiddens") 29 | for i, size in enumerate(hiddens): 30 | label = "fc{}".format(i) 31 | last_layer = tf.layers.dense( 32 | last_layer, 33 | size, 34 | kernel_initializer=normc_initializer(1.0), 35 | activation=tf.nn.tanh, 36 | name=label, 37 | ) 38 | action_logits = tf.layers.dense( 39 | last_layer, 40 | num_outputs, 41 | kernel_initializer=normc_initializer(0.01), 42 | activation=None, 43 | name="fc_out", 44 | ) 45 | 46 | # Mask out invalid actions (use tf.float32.min for stability) 47 | inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) 48 | masked_logits = inf_mask + action_logits 49 | 50 | return masked_logits, last_layer 51 | -------------------------------------------------------------------------------- /smacv2/examples/rllib/run_ppo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | """Example of running StarCraft2 with RLlib PPO. 6 | 7 | In this setup, each agent will be controlled by an independent PPO policy. 8 | However the policies share weights. 9 | 10 | Increase the level of parallelism by changing --num-workers. 11 | """ 12 | 13 | import argparse 14 | 15 | import ray 16 | from ray.tune import run_experiments, register_env 17 | from ray.rllib.models import ModelCatalog 18 | 19 | from smacv2.examples.rllib.env import RLlibStarCraft2Env 20 | from smacv2.examples.rllib.model import MaskedActionsModel 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--num-iters", type=int, default=100) 26 | parser.add_argument("--num-workers", type=int, default=2) 27 | parser.add_argument("--map-name", type=str, default="8m") 28 | args = parser.parse_args() 29 | 30 | ray.init() 31 | 32 | register_env("smac", lambda smac_args: RLlibStarCraft2Env(**smac_args)) 33 | ModelCatalog.register_custom_model("mask_model", MaskedActionsModel) 34 | 35 | run_experiments( 36 | { 37 | "ppo_sc2": { 38 | "run": "PPO", 39 | "env": "smac", 40 | "stop": { 41 | "training_iteration": args.num_iters, 42 | }, 43 | "config": { 44 | "num_workers": args.num_workers, 45 | "observation_filter": "NoFilter", # breaks the action mask 46 | "vf_share_layers": True, # no separate value model 47 | "env_config": { 48 | "map_name": args.map_name, 49 | }, 50 | "model": { 51 | "custom_model": "mask_model", 52 | }, 53 | }, 54 | }, 55 | } 56 | ) 57 | -------------------------------------------------------------------------------- /smacv2/examples/rllib/run_qmix.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | """Example of running StarCraft2 with RLlib QMIX. 6 | 7 | This assumes all agents are homogeneous. The agents are grouped and assigned 8 | to the multi-agent QMIX policy. Note that the default hyperparameters for 9 | RLlib QMIX are different from pymarl's QMIX. 
10 | """ 11 | 12 | import argparse 13 | from gym.spaces import Tuple 14 | 15 | import ray 16 | from ray.tune import run_experiments, register_env 17 | 18 | from smacv2.examples.rllib.env import RLlibStarCraft2Env 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--num-iters", type=int, default=100) 24 | parser.add_argument("--num-workers", type=int, default=2) 25 | parser.add_argument("--map-name", type=str, default="8m") 26 | args = parser.parse_args() 27 | 28 | def env_creator(smac_args): 29 | env = RLlibStarCraft2Env(**smac_args) 30 | agent_list = list(range(env._env.n_agents)) 31 | grouping = { 32 | "group_1": agent_list, 33 | } 34 | obs_space = Tuple([env.observation_space for i in agent_list]) 35 | act_space = Tuple([env.action_space for i in agent_list]) 36 | return env.with_agent_groups( 37 | grouping, obs_space=obs_space, act_space=act_space 38 | ) 39 | 40 | ray.init() 41 | register_env("sc2_grouped", env_creator) 42 | 43 | run_experiments( 44 | { 45 | "qmix_sc2": { 46 | "run": "QMIX", 47 | "env": "sc2_grouped", 48 | "stop": { 49 | "training_iteration": args.num_iters, 50 | }, 51 | "config": { 52 | "num_workers": args.num_workers, 53 | "env_config": { 54 | "map_name": args.map_name, 55 | }, 56 | }, 57 | }, 58 | } 59 | ) 60 | --------------------------------------------------------------------------------