├── .gitignore ├── pysc2 ├── maps │ ├── mini_games │ │ ├── BuildMarines.SC2Map │ │ ├── DefeatRoaches.SC2Map │ │ ├── MoveToBeacon.SC2Map │ │ ├── CollectMineralShards.SC2Map │ │ ├── CollectMineralsAndGas.SC2Map │ │ ├── FindAndDefeatZerglings.SC2Map │ │ └── DefeatZerglingsAndBanelings.SC2Map │ ├── melee.py │ ├── __init__.py │ ├── mini_games.py │ ├── lib.py │ └── ladder.py ├── agents │ ├── __init__.py │ ├── random_agent.py │ ├── base_agent.py │ └── scripted_agent.py ├── bin │ ├── __init__.py │ ├── run_tests.py │ ├── map_list.py │ ├── gen_versions.py │ ├── valid_actions.py │ ├── gen_units.py │ ├── benchmark_observe.py │ ├── replay_info.py │ ├── mem_leak_check.py │ └── agent.py ├── env │ ├── __init__.py │ ├── available_actions_printer.py │ ├── base_env_wrapper.py │ ├── run_loop.py │ ├── mock_sc2_env_comparison_test.py │ ├── environment.py │ ├── mock_sc2_env_test.py │ └── host_remote_agent.py ├── lib │ ├── __init__.py │ ├── gfile.py │ ├── video_writer.py │ ├── portspicker_test.py │ ├── metrics.py │ ├── point_flag.py │ ├── portspicker.py │ ├── run_parallel_test.py │ ├── static_data.py │ ├── run_parallel.py │ ├── transform.py │ ├── stopwatch_test.py │ ├── protocol.py │ ├── sc_process.py │ ├── units.py │ └── point_test.py ├── tests │ ├── __init__.py │ ├── utils.py │ ├── ping_test.py │ ├── protocol_error_test.py │ ├── random_agent_test.py │ ├── step_mul_override_test.py │ ├── multi_player_env_test.py │ ├── observer_test.py │ ├── maps_test.py │ ├── host_remote_agent_test.py │ ├── debug_test.py │ ├── render_test.py │ ├── easy_scripted_test.py │ ├── versions_test.py │ ├── multi_player_test.py │ └── obs_spec_test.py ├── __init__.py └── run_configs │ ├── __init__.py │ ├── platforms.py │ └── lib.py ├── CONTRIBUTING.md ├── docs ├── maps.md └── mini_games.md └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *_pb2.py 3 | -------------------------------------------------------------------------------- 
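The `.gitignore` above excludes two kinds of files: compiled bytecode (`*.pyc`) and protobuf-generated Python modules (`*_pb2.py`, the naming scheme `protoc` uses for its Python output). The sketch below checks which paths those two patterns would exclude. It uses the standard library's `fnmatch.fnmatchcase` as an approximation. Git's own matcher also supports negation (`!`), directory-only patterns, and anchored paths, and none of those are modeled here. The helper name `is_ignored` is hypothetical and is not part of pysc2.

```python
import posixpath
from fnmatch import fnmatchcase

# Patterns copied from the repository's .gitignore.
IGNORE_PATTERNS = ("*.pyc", "*_pb2.py")


def is_ignored(path, patterns=IGNORE_PATTERNS):
    """Return True if `path` would be excluded by one of `patterns`.

    Approximation (hypothetical helper): a gitignore pattern without a slash
    matches the final path component at any depth, so only the basename is
    tested. Matching is case-sensitive, like git's default on Linux.
    """
    name = posixpath.basename(path)
    return any(fnmatchcase(name, pattern) for pattern in patterns)


if __name__ == "__main__":
    for p in ["pysc2/lib/stopwatch.pyc",
              "s2clientprotocol/sc2api_pb2.py",
              "pysc2/agents/random_agent.py"]:
        print(p, is_ignored(p))
```

Because neither pattern contains a slash, the basename check mirrors how git applies them at every directory level.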
/pysc2/maps/mini_games/BuildMarines.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/BuildMarines.SC2Map -------------------------------------------------------------------------------- /pysc2/maps/mini_games/DefeatRoaches.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/DefeatRoaches.SC2Map -------------------------------------------------------------------------------- /pysc2/maps/mini_games/MoveToBeacon.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/MoveToBeacon.SC2Map -------------------------------------------------------------------------------- /pysc2/maps/mini_games/CollectMineralShards.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/CollectMineralShards.SC2Map -------------------------------------------------------------------------------- /pysc2/maps/mini_games/CollectMineralsAndGas.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/CollectMineralsAndGas.SC2Map -------------------------------------------------------------------------------- /pysc2/maps/mini_games/FindAndDefeatZerglings.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/FindAndDefeatZerglings.SC2Map -------------------------------------------------------------------------------- /pysc2/maps/mini_games/DefeatZerglingsAndBanelings.SC2Map: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pysc2/HEAD/pysc2/maps/mini_games/DefeatZerglingsAndBanelings.SC2Map -------------------------------------------------------------------------------- /pysc2/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pysc2/bin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /pysc2/env/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pysc2/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pysc2/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /pysc2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """PySC2 module: https://github.com/deepmind/pysc2 .""" 15 | 16 | import os 17 | 18 | 19 | def load_tests(loader, standard_tests, unused_pattern): 20 | """Our tests end in `_test.py`, so need to override the test discovery.""" 21 | this_dir = os.path.dirname(__file__) 22 | package_tests = loader.discover(start_dir=this_dir, pattern="*_test.py") 23 | standard_tests.addTests(package_tests) 24 | return standard_tests 25 | -------------------------------------------------------------------------------- /pysc2/bin/run_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Find and run the tests. 16 | 17 | Run as: python -m pysc2.bin.run_tests 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | from absl.testing import absltest 25 | 26 | import pysc2 27 | import pysc2.run_configs.platforms # So that the version flags work. 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main(module=pysc2) 32 | -------------------------------------------------------------------------------- /pysc2/bin/map_list.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Print the list of defined maps.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | 23 | from pysc2 import maps 24 | 25 | 26 | def main(unused_argv): 27 | for _, map_class in sorted(maps.get_maps().items()): 28 | mp = map_class() 29 | if mp.path: 30 | print(mp, "\n") 31 | 32 | 33 | if __name__ == "__main__": 34 | app.run(main) 35 | -------------------------------------------------------------------------------- /pysc2/lib/gfile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This replaces Google's gfile, which is used for network storage.
15 | 16 | A more complete public version of gfile: 17 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/platform/gfile.py 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import os 25 | 26 | # pylint: disable=invalid-name 27 | Exists = os.path.exists 28 | IsDirectory = os.path.isdir 29 | ListDir = os.listdir 30 | MakeDirs = os.makedirs 31 | Open = open 32 | -------------------------------------------------------------------------------- /pysc2/agents/random_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """A random agent for starcraft.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import numpy 21 | 22 | from pysc2.agents import base_agent 23 | from pysc2.lib import actions 24 | 25 | 26 | class RandomAgent(base_agent.BaseAgent): 27 | """A random agent for starcraft.""" 28 | 29 | def step(self, obs): 30 | super(RandomAgent, self).step(obs) 31 | function_id = numpy.random.choice(obs.observation.available_actions) 32 | args = [[numpy.random.randint(0, size) for size in arg.sizes] 33 | for arg in self.action_spec.functions[function_id].args] 34 | return actions.FunctionCall(function_id, args) 35 | -------------------------------------------------------------------------------- /pysc2/lib/video_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Write a video based on a numpy array.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import skvideo.io 21 | 22 | 23 | class VideoWriter(skvideo.io.FFmpegWriter): 24 | """Write a video based on a numpy array. 25 | 26 | Subclass/wrap FFmpegWriter to make it easy to switch to a different library.
27 | """ 28 | 29 | def __init__(self, filename, frame_rate): 30 | super(VideoWriter, self).__init__( 31 | filename, outputdict={"-r": str(frame_rate)}) 32 | 33 | def add(self, frame): 34 | """Add a frame to the video based on a numpy array.""" 35 | self.writeFrame(frame) 36 | 37 | def __del__(self): 38 | self.close() 39 | 40 | -------------------------------------------------------------------------------- /pysc2/tests/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Unit test tools.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl import logging 21 | 22 | from absl.testing import absltest 23 | from pysc2.lib import stopwatch 24 | 25 | 26 | class TestCase(absltest.TestCase): 27 | """A test base class that enables stopwatch profiling.""" 28 | 29 | def setUp(self): 30 | super(TestCase, self).setUp() 31 | stopwatch.sw.clear() 32 | self._sw_enabled = stopwatch.sw.enabled 33 | stopwatch.sw.enabled = True 34 | 35 | def tearDown(self): 36 | super(TestCase, self).tearDown() 37 | s = str(stopwatch.sw) 38 | if s: 39 | logging.info("Stop watch profile:\n%s", s) 40 | stopwatch.sw.enabled = self._sw_enabled 41 | -------------------------------------------------------------------------------- /pysc2/maps/melee.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Define the melee map configs.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from pysc2.maps import lib 21 | 22 | 23 | class Melee(lib.Map): 24 | directory = "Melee" 25 | download = "https://github.com/Blizzard/s2client-proto#map-packs" 26 | players = 2 27 | game_steps_per_episode = 16 * 60 * 30 # 30 minute limit. 28 | 29 | 30 | melee_maps = [ 31 | # "Empty128", # Not really playable, but may be useful in the future. 32 | "Flat32", 33 | "Flat48", 34 | "Flat64", 35 | "Flat96", 36 | "Flat128", 37 | "Simple64", 38 | "Simple96", 39 | "Simple128", 40 | ] 41 | 42 | for name in melee_maps: 43 | globals()[name] = type(name, (Melee,), dict(filename=name)) 44 | -------------------------------------------------------------------------------- /pysc2/maps/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Register/import the maps, and offer a way to create one by name. 15 | 16 | Users of maps should import this module: 17 | from pysc2 import maps 18 | and create the maps by name: 19 | maps.get("MapName") 20 | 21 | If you want to create your own map, then import the map lib and subclass Map. 
22 | Your subclass will be implicitly registered as a map that can be constructed by 23 | name, as long as it is imported somewhere. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | from pysc2.maps import ladder 31 | from pysc2.maps import lib 32 | from pysc2.maps import melee 33 | from pysc2.maps import mini_games 34 | 35 | 36 | # Use `get` to create a map by name. 37 | get = lib.get 38 | get_maps = lib.get_maps 39 | -------------------------------------------------------------------------------- /pysc2/tests/ping_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Benchmark the ping rate of SC2.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from future.builtins import range # pylint: disable=redefined-builtin 23 | 24 | from pysc2 import run_configs 25 | from pysc2.lib import stopwatch 26 | from pysc2.tests import utils 27 | 28 | 29 | class TestPing(utils.TestCase): 30 | 31 | def test_ping(self): 32 | count = 100 33 | 34 | with run_configs.get().start(want_rgb=False) as controller: 35 | with stopwatch.sw("first"): 36 | controller.ping() 37 | 38 | for _ in range(count): 39 | controller.ping() 40 | 41 | self.assertEqual(stopwatch.sw["ping"].num, count) 42 | 43 | 44 | if __name__ == "__main__": 45 | absltest.main() 46 | -------------------------------------------------------------------------------- /pysc2/maps/mini_games.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Define the mini game map configs. 
These are maps made by DeepMind.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from pysc2.maps import lib 21 | 22 | 23 | class MiniGame(lib.Map): 24 | directory = "mini_games" 25 | download = "https://github.com/deepmind/pysc2#get-the-maps" 26 | players = 1 27 | score_index = 0 28 | game_steps_per_episode = 0 29 | step_mul = 8 30 | 31 | 32 | mini_games = [ 33 | "BuildMarines", # 900s 34 | "CollectMineralsAndGas", # 420s 35 | "CollectMineralShards", # 120s 36 | "DefeatRoaches", # 120s 37 | "DefeatZerglingsAndBanelings", # 120s 38 | "FindAndDefeatZerglings", # 180s 39 | "MoveToBeacon", # 120s 40 | ] 41 | 42 | 43 | for name in mini_games: 44 | globals()[name] = type(name, (MiniGame,), dict(filename=name)) 45 | -------------------------------------------------------------------------------- /pysc2/bin/gen_versions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Generate the list of versions for run_configs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | import requests 23 | 24 | # raw version of: 25 | # https://github.com/Blizzard/s2client-proto/blob/master/buildinfo/versions.json 26 | VERSIONS_FILE = "https://raw.githubusercontent.com/Blizzard/s2client-proto/master/buildinfo/versions.json" 27 | 28 | 29 | def main(argv): 30 | del argv # Unused. 31 | 32 | versions = requests.get(VERSIONS_FILE).json() 33 | 34 | for v in versions: 35 | version_str = v["label"] 36 | if version_str.count(".") == 1: 37 | version_str += ".0" 38 | print(' Version("%s", %i, "%s", None),' % ( 39 | version_str, v["base-version"], v["data-hash"])) 40 | 41 | 42 | if __name__ == "__main__": 43 | app.run(main) 44 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at 2 | the end). 3 | 4 | ### Before you contribute 5 | 6 | Before we can use your code, you must sign the 7 | [Google Individual Contributor License Agreement] 8 | (https://cla.developers.google.com/about/google-individual) 9 | (CLA), which you can do online. The CLA is necessary mainly because you own the 10 | copyright to your changes, even after your contribution becomes part of our 11 | codebase, so we need your permission to use and distribute your code. We also 12 | need to be sure of various other things—for instance that you'll tell us if you 13 | know that your code infringes on other people's patents. You don't have to sign 14 | the CLA until after you've submitted your code for review and a member has 15 | approved it, but you must do it before we can put your code into our codebase. 
16 | Before you start working on a larger contribution, you should get in touch with 17 | us first through the issue tracker with your idea so that we can help out and 18 | possibly guide you. Coordinating up front makes it much easier to avoid 19 | frustration later on. 20 | 21 | ### Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. 25 | 26 | ### The small print 27 | 28 | Contributions made by corporations are covered by a different agreement than 29 | the one above, the 30 | [Software Grant and Corporate Contributor License Agreement] 31 | (https://cla.developers.google.com/about/google-corporate). 32 | -------------------------------------------------------------------------------- /pysc2/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A base agent to write custom scripted agents.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from pysc2.lib import actions 21 | 22 | 23 | class BaseAgent(object): 24 | """A base agent to write custom scripted agents. 25 | 26 | It can also act as a passive agent that does nothing but no-ops.
27 | """ 28 | 29 | def __init__(self): 30 | self.reward = 0 31 | self.episodes = 0 32 | self.steps = 0 33 | self.obs_spec = None 34 | self.action_spec = None 35 | 36 | def setup(self, obs_spec, action_spec): 37 | self.obs_spec = obs_spec 38 | self.action_spec = action_spec 39 | 40 | def reset(self): 41 | self.episodes += 1 42 | 43 | def step(self, obs): 44 | self.steps += 1 45 | self.reward += obs.reward 46 | return actions.FunctionCall(actions.FUNCTIONS.no_op.id, []) 47 | -------------------------------------------------------------------------------- /pysc2/lib/portspicker_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tests for portspicker.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | from pysc2.lib import portspicker 24 | 25 | 26 | class PortsTest(parameterized.TestCase): 27 | 28 | @parameterized.parameters(range(10)) 29 | def testNonContiguousReservation(self, num_ports): 30 | reserved = portspicker.pick_unused_ports(num_ports) 31 | self.assertEqual(len(reserved), num_ports) 32 | portspicker.return_ports(reserved) 33 | 34 | @parameterized.parameters(range(10)) 35 | def testContiguousReservation(self, num_ports): 36 | reserved = portspicker.pick_contiguous_unused_ports(num_ports) 37 | self.assertEqual(len(reserved), num_ports) 38 | portspicker.return_ports(reserved) 39 | 40 | 41 | if __name__ == "__main__": 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /pysc2/env/available_actions_printer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """An env wrapper to print the available actions.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from pysc2.env import base_env_wrapper 21 | 22 | 23 | class AvailableActionsPrinter(base_env_wrapper.BaseEnvWrapper): 24 | """An env wrapper to print the available actions.""" 25 | 26 | def __init__(self, env): 27 | super(AvailableActionsPrinter, self).__init__(env) 28 | self._seen = set() 29 | self._action_spec = self.action_spec()[0] 30 | 31 | def step(self, *args, **kwargs): 32 | all_obs = super(AvailableActionsPrinter, self).step(*args, **kwargs) 33 | for obs in all_obs: 34 | for avail in obs.observation["available_actions"]: 35 | if avail not in self._seen: 36 | self._seen.add(avail) 37 | self._print(self._action_spec.functions[avail].str(True)) 38 | return all_obs 39 | 40 | def _print(self, s): 41 | print(s) 42 | -------------------------------------------------------------------------------- /pysc2/tests/protocol_error_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Verify that we blow up if SC2 thinks we did something wrong.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from pysc2 import run_configs 23 | from pysc2.lib import protocol 24 | from pysc2.lib import remote_controller 25 | from pysc2.tests import utils 26 | 27 | from s2clientprotocol import sc2api_pb2 as sc_pb 28 | 29 | 30 | class TestProtocolError(utils.TestCase): 31 | """Verify that we blow up if SC2 thinks we did something wrong.""" 32 | 33 | def test_error(self): 34 | with run_configs.get().start(want_rgb=False) as controller: 35 | with self.assertRaises(remote_controller.RequestError): 36 | controller.create_game(sc_pb.RequestCreateGame()) # Missing map, etc. 37 | 38 | with self.assertRaises(protocol.ProtocolError): 39 | controller.join_game(sc_pb.RequestJoinGame()) # No game to join. 40 | 41 | 42 | if __name__ == "__main__": 43 | absltest.main() 44 | -------------------------------------------------------------------------------- /pysc2/env/base_env_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """A base env wrapper so we don't need to override everything every time.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from pysc2.env import environment 21 | 22 | 23 | class BaseEnvWrapper(environment.Base): 24 | """A base env wrapper so we don't need to override everything every time.""" 25 | 26 | def __init__(self, env): 27 | self._env = env 28 | 29 | def close(self, *args, **kwargs): 30 | return self._env.close(*args, **kwargs) 31 | 32 | def action_spec(self, *args, **kwargs): 33 | return self._env.action_spec(*args, **kwargs) 34 | 35 | def observation_spec(self, *args, **kwargs): 36 | return self._env.observation_spec(*args, **kwargs) 37 | 38 | def reset(self, *args, **kwargs): 39 | return self._env.reset(*args, **kwargs) 40 | 41 | def step(self, *args, **kwargs): 42 | return self._env.step(*args, **kwargs) 43 | 44 | def save_replay(self, *args, **kwargs): 45 | return self._env.save_replay(*args, **kwargs) 46 | 47 | @property 48 | def state(self): 49 | return self._env.state 50 | -------------------------------------------------------------------------------- /pysc2/run_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configs for various ways to run StarCraft.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl import flags 21 | 22 | from pysc2.lib import sc_process 23 | from pysc2.run_configs import platforms 24 | from pysc2.run_configs import lib 25 | 26 | flags.DEFINE_string("sc2_run_config", None, 27 | "Which run_config to use to spawn the binary.") 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | def get(): 32 | """Get the config chosen by the flags.""" 33 | configs = {c.name(): c 34 | for c in lib.RunConfig.all_subclasses() if c.priority()} 35 | 36 | if not configs: 37 | raise sc_process.SC2LaunchError("No valid run_configs found.") 38 | 39 | if FLAGS.sc2_run_config is None: # Find the highest priority as default. 40 | return max(configs.values(), key=lambda c: c.priority())() 41 | 42 | try: 43 | return configs[FLAGS.sc2_run_config]() 44 | except KeyError: 45 | raise sc_process.SC2LaunchError( 46 | "Invalid run_config. Valid configs are: %s" % ( 47 | ", ".join(sorted(configs.keys())))) 48 | -------------------------------------------------------------------------------- /pysc2/lib/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Interface for tracking the number and/or latency of episodes and steps.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | 21 | class _EventTimer(object): 22 | """Example event timer to measure step and observation times.""" 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, unused_exception_type, unused_exc_value, unused_traceback): 28 | pass 29 | 30 | 31 | class Metrics(object): 32 | """Interface for tracking the number and/or latency of episodes and steps.""" 33 | 34 | def __init__(self, map_name): 35 | pass 36 | 37 | def increment_instance(self): 38 | pass 39 | 40 | def increment_episode(self): 41 | pass 42 | 43 | def measure_step_time(self, num_steps=1): 44 | """Return a context manager to measure the time to perform N game steps.""" 45 | del num_steps 46 | return _EventTimer() 47 | 48 | def measure_observation_time(self): 49 | """Return a context manager to measure the time to get an observation.""" 50 | return _EventTimer() 51 | 52 | def close(self): 53 | pass 54 | 55 | def __del__(self): 56 | self.close() 57 | -------------------------------------------------------------------------------- /docs/maps.md: -------------------------------------------------------------------------------- 1 | # StarCraft II Maps 2 | 3 | ## Map config 4 | 5 | SC2Map files are the map files loaded by the SC2 game, but the same file can be 6 | used in different ways, and those differences are defined in our map configs. 7 | The config gives information like how long the episodes last, how many players 8 | the map supports, and how to score it. 9 | 10 | To create your own map config, just subclass the base Map class and override 11 | some of the settings. The most important ones are the directory and filename 12 | of the SC2Map. Any Map subclass will be automatically picked up as long as it's 13 | imported somewhere.
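
The registration rule above (any imported `Map` subclass is picked up automatically) can be illustrated with a small, self-contained sketch. It assumes a simplified base class with `directory`, `filename`, and `players` attributes and a hypothetical `get_maps()` helper that collects every subclass naming a file; these names only mirror the idea and are not necessarily the exact `pysc2.maps.lib` API. `MiniGame` and `MoveToBeacon` are illustrative configs.

```python
import os


class Map(object):
  """Hypothetical base map config: subclasses override these settings."""
  directory = ""   # Directory holding the SC2Map, relative to a maps root.
  filename = None  # SC2Map name without the extension; None means "abstract".
  players = None   # Number of players the map supports.

  @property
  def path(self):
    """Relative path of the SC2Map file for this config."""
    return os.path.join(self.directory, self.filename + ".SC2Map")

  @classmethod
  def all_subclasses(cls):
    """Yield every direct and indirect subclass defined (imported) so far."""
    for subclass in cls.__subclasses__():
      yield subclass
      for sub in subclass.all_subclasses():
        yield sub


def get_maps():
  """Hypothetical lookup: concrete configs (those naming a file) by class name."""
  return {m.__name__: m for m in Map.all_subclasses() if m.filename}


# Merely defining these classes registers them; no explicit list is kept.
class MiniGame(Map):
  """Shared defaults for a group of maps; it names no file itself."""
  directory = "mini_games"
  players = 1


class MoveToBeacon(MiniGame):
  filename = "MoveToBeacon"


print(sorted(get_maps()))  # ['MoveToBeacon']
```

Because `MiniGame` leaves `filename` unset, it acts as a group of shared defaults and is filtered out, while `MoveToBeacon` inherits `directory` and `players` from it and only needs to name its file.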
14 | 15 | ## DeepMind Mini-Games 16 | 17 | The [mini-games](mini_games.md) are designed to be single-player and fixed-length, 18 | and each exercises a different aspect of the game. They expose a score/reward that 19 | lets the agent know how well it is doing. The score should differentiate poor 20 | agents (e.g. random) from good agents. 21 | 22 | ## Ladder 23 | 24 | [Ladder maps](http://wiki.teamliquid.net/starcraft2/Maps/Ladder_Maps/Legacy_of_the_Void) 25 | are the maps played by human players on Battle.net. Only a handful are 26 | active at a time. Every few months a new season starts, bringing a new set of 27 | maps. 28 | 29 | Some of the maps have the suffix LE or TE. LE means Ladder Edition: these are 30 | community maps that Blizzard edited to fix bugs and made ready for the 31 | ladder pool. TE means Tournament Edition: these maps were used in tournaments. 32 | 33 | Ladder maps are all multiplayer maps with fairly long time limits. 34 | 35 | ## Melee 36 | 37 | These are maps made specifically for machine learning. They resemble 38 | ladder maps in format, but may be smaller and aren't necessarily balanced 39 | for high-level play. 40 | 41 | The **Flat** maps have no special terrain features, which makes attacking 42 | easy. The number specifies the map size. 43 | 44 | The **Simple** maps are more conventional, with expansions, ramps, and lanes of attack, 45 | but are smaller than normal ladder maps. The number specifies the map size. 46 | -------------------------------------------------------------------------------- /pysc2/lib/point_flag.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Define a flag type for points.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl import flags 21 | import six 22 | 23 | from pysc2.lib import point 24 | 25 | 26 | class PointParser(flags.ArgumentParser): 27 | """Parse a flag into a point.""" 28 | 29 | def parse(self, argument): 30 | if not argument or argument == "0": 31 | return None 32 | 33 | if isinstance(argument, int): 34 | args = [argument] 35 | elif isinstance(argument, (list, tuple)): 36 | args = argument 37 | elif isinstance(argument, six.string_types): 38 | args = argument.split(",") 39 | else: 40 | raise ValueError( 41 | "Invalid point: '%r'. Valid: '<int>' or '<int>,<int>'." % argument) 42 | 43 | args = [int(v) for v in args] 44 | 45 | if len(args) == 1: 46 | args *= 2 47 | if len(args) == 2: 48 | return point.Point(args[0], args[1]) 49 | raise ValueError( 50 | "Invalid point: '%s'. Valid: '<int>' or '<int>,<int>'." % argument) 51 | 52 | def flag_type(self): 53 | return "pysc2 point" 54 | 55 | 56 | def DEFINE_point(name, default, help): # pylint: disable=invalid-name,redefined-builtin 57 | """Registers a flag whose value parses as a point.""" 58 | flags.DEFINE(PointParser(), name, default, help) 59 | -------------------------------------------------------------------------------- /pysc2/env/run_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A run loop for agent/environment interaction.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import time 21 | 22 | 23 | def run_loop(agents, env, max_frames=0, max_episodes=0): 24 | """A run loop to have agents and an environment interact.""" 25 | total_frames = 0 26 | total_episodes = 0 27 | start_time = time.time() 28 | 29 | observation_spec = env.observation_spec() 30 | action_spec = env.action_spec() 31 | for agent, obs_spec, act_spec in zip(agents, observation_spec, action_spec): 32 | agent.setup(obs_spec, act_spec) 33 | 34 | try: 35 | while not max_episodes or total_episodes < max_episodes: 36 | total_episodes += 1 37 | timesteps = env.reset() 38 | for a in agents: 39 | a.reset() 40 | while True: 41 | total_frames += 1 42 | actions = [agent.step(timestep) 43 | for agent, timestep in zip(agents, timesteps)] 44 | if max_frames and total_frames >= max_frames: 45 | return 46 | if timesteps[0].last(): 47 | break 48 | timesteps = env.step(actions) 49 | except KeyboardInterrupt: 50 | pass 51 | finally: 52 | elapsed_time = time.time() - start_time 53 | print("Took %.3f seconds for %s steps: %.3f fps" % ( 54 | elapsed_time, total_frames, total_frames / elapsed_time)) 55 | -------------------------------------------------------------------------------- /pysc2/bin/valid_actions.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Print the valid actions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | from absl import flags 23 | 24 | from pysc2.lib import actions 25 | from pysc2.lib import features 26 | from pysc2.lib import point_flag 27 | 28 | FLAGS = flags.FLAGS 29 | point_flag.DEFINE_point("screen_size", "84", "Resolution for screen actions.") 30 | point_flag.DEFINE_point("minimap_size", "64", "Resolution for minimap actions.") 31 | flags.DEFINE_bool("hide_specific", False, "Hide the specific actions") 32 | 33 | 34 | def main(unused_argv): 35 | """Print the valid actions.""" 36 | feats = features.Features( 37 | # Actually irrelevant whether it's feature or rgb size. 
38 | features.AgentInterfaceFormat( 39 | feature_dimensions=features.Dimensions( 40 | screen=FLAGS.screen_size, 41 | minimap=FLAGS.minimap_size))) 42 | action_spec = feats.action_spec() 43 | flattened = 0 44 | count = 0 45 | for func in action_spec.functions: 46 | if FLAGS.hide_specific and actions.FUNCTIONS[func.id].general_id != 0: 47 | continue 48 | count += 1 49 | act_flat = 1 50 | for arg in func.args: 51 | for size in arg.sizes: 52 | act_flat *= size 53 | flattened += act_flat 54 | print(func.str(True)) 55 | print("Total base actions:", count) 56 | print("Total possible actions (flattened):", flattened) 57 | 58 | 59 | if __name__ == "__main__": 60 | app.run(main) 61 | -------------------------------------------------------------------------------- /pysc2/tests/random_agent_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Run a random agent for a few steps.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | from pysc2.agents import random_agent 25 | from pysc2.env import run_loop 26 | from pysc2.env import sc2_env 27 | from pysc2.tests import utils 28 | 29 | 30 | class TestRandomAgent(parameterized.TestCase, utils.TestCase): 31 | 32 | @parameterized.named_parameters( 33 | ("features", sc2_env.AgentInterfaceFormat( 34 | feature_dimensions=sc2_env.Dimensions(screen=84, minimap=64))), 35 | ("rgb", sc2_env.AgentInterfaceFormat( 36 | rgb_dimensions=sc2_env.Dimensions(screen=128, minimap=64))), 37 | ("all", sc2_env.AgentInterfaceFormat( 38 | feature_dimensions=sc2_env.Dimensions(screen=84, minimap=64), 39 | rgb_dimensions=sc2_env.Dimensions(screen=128, minimap=64), 40 | action_space=sc2_env.ActionSpace.FEATURES, 41 | use_unit_counts=True, 42 | use_feature_units=True)), 43 | ) 44 | def test_random_agent(self, agent_interface_format): 45 | steps = 250 46 | step_mul = 8 47 | with sc2_env.SC2Env( 48 | map_name="Simple64", 49 | agent_interface_format=agent_interface_format, 50 | step_mul=step_mul, 51 | game_steps_per_episode=steps * step_mul//2) as env: 52 | agent = random_agent.RandomAgent() 53 | run_loop.run_loop([agent], env, steps) 54 | 55 | self.assertEqual(agent.steps, steps) 56 | 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /pysc2/tests/step_mul_override_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Test that stepping without observing works correctly for multiple players.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | 23 | from pysc2.env import sc2_env 24 | from pysc2.lib import actions 25 | from pysc2.tests import utils 26 | 27 | 28 | AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat( 29 | feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32) 30 | ) 31 | 32 | 33 | class StepMulOverrideTest(utils.TestCase): 34 | 35 | def test_returns_game_loop_zero_on_first_step_despite_override(self): 36 | with sc2_env.SC2Env( 37 | map_name="DefeatRoaches", 38 | players=[sc2_env.Agent(sc2_env.Race.random)], 39 | step_mul=1, 40 | agent_interface_format=AGENT_INTERFACE_FORMAT) as env: 41 | timestep = env.step( 42 | actions=[actions.FUNCTIONS.no_op()], 43 | step_mul=1234) 44 | 45 | self.assertEqual( 46 | timestep[0].observation.game_loop[0], 47 | 0) 48 | 49 | def test_respects_override(self): 50 | with sc2_env.SC2Env( 51 | map_name="DefeatRoaches", 52 | players=[sc2_env.Agent(sc2_env.Race.random)], 53 | step_mul=1, 54 | agent_interface_format=AGENT_INTERFACE_FORMAT) as env: 55 | 56 | expected_game_loop = 0 57 | for delta in range(10): 58 | timestep = env.step( 59 | actions=[actions.FUNCTIONS.no_op()], 60 | step_mul=delta) 61 | 62 | expected_game_loop += delta 63 | self.assertEqual( 64 | timestep[0].observation.game_loop[0], 65 | expected_game_loop) 66 | 67 | 68 | if __name__ == 
"__main__": 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /pysc2/lib/portspicker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """portpicker for multiple ports.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import time 21 | import portpicker 22 | 23 | 24 | def pick_unused_ports(num_ports, retry_interval_secs=3, retry_attempts=5): 25 | """Reserves and returns a list of `num_ports` unused ports.""" 26 | ports = set() 27 | for _ in range(retry_attempts): 28 | ports.update( 29 | portpicker.pick_unused_port() for _ in range(num_ports - len(ports))) 30 | ports.discard(None) # portpicker returns None on error. 31 | if len(ports) == num_ports: 32 | return list(ports) 33 | # Duplicate ports can be returned, especially when insufficient ports are 34 | # free. Wait for more ports to be freed and retry. 35 | time.sleep(retry_interval_secs) 36 | 37 | # Could not obtain enough ports. Release what we do have. 38 | return_ports(ports) 39 | 40 | raise RuntimeError("Unable to obtain %d unused ports." 
% num_ports) 41 | 42 | 43 | def pick_contiguous_unused_ports( 44 | num_ports, 45 | retry_interval_secs=3, 46 | retry_attempts=5): 47 | """Reserves and returns a list of `num_ports` contiguous unused ports.""" 48 | for _ in range(retry_attempts): 49 | start_port = portpicker.pick_unused_port() 50 | if start_port is not None: 51 | ports = [start_port + p for p in range(num_ports)] 52 | if all(portpicker.is_port_free(p) for p in ports): 53 | return ports 54 | else: 55 | return_ports(ports) 56 | 57 | time.sleep(retry_interval_secs) 58 | 59 | raise RuntimeError("Unable to obtain %d contiguous unused ports." % num_ports) 60 | 61 | 62 | def return_ports(ports): 63 | """Returns previously reserved ports so that they may be reused.""" 64 | for port in ports: 65 | portpicker.return_port(port) 66 | -------------------------------------------------------------------------------- /pysc2/tests/multi_player_env_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Test that the multiplayer environment works.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | from future.builtins import range # pylint: disable=redefined-builtin 24 | 25 | from pysc2.agents import random_agent 26 | from pysc2.env import run_loop 27 | from pysc2.env import sc2_env 28 | from pysc2.tests import utils 29 | 30 | 31 | class TestMultiplayerEnv(parameterized.TestCase, utils.TestCase): 32 | 33 | @parameterized.named_parameters( 34 | ("features", sc2_env.AgentInterfaceFormat( 35 | feature_dimensions=sc2_env.Dimensions(screen=84, minimap=64))), 36 | ("rgb", sc2_env.AgentInterfaceFormat( 37 | rgb_dimensions=sc2_env.Dimensions(screen=84, minimap=64))), 38 | ("features_and_rgb", [ 39 | sc2_env.AgentInterfaceFormat( 40 | feature_dimensions=sc2_env.Dimensions(screen=84, minimap=64)), 41 | sc2_env.AgentInterfaceFormat( 42 | rgb_dimensions=sc2_env.Dimensions(screen=128, minimap=32)) 43 | ]), 44 | ) 45 | def test_multi_player_env(self, agent_interface_format): 46 | steps = 100 47 | step_mul = 16 48 | players = 2 49 | with sc2_env.SC2Env( 50 | map_name="Simple64", 51 | players=[sc2_env.Agent(sc2_env.Race.random, "random"), 52 | sc2_env.Agent(sc2_env.Race.random, "random")], 53 | step_mul=step_mul, 54 | game_steps_per_episode=steps * step_mul // 2, 55 | agent_interface_format=agent_interface_format) as env: 56 | agents = [random_agent.RandomAgent() for _ in range(players)] 57 | run_loop.run_loop(agents, env, steps) 58 | 59 | 60 | if __name__ == "__main__": 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /pysc2/env/mock_sc2_env_comparison_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests that mock environment has same shape outputs as true environment.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | 23 | from pysc2.env import mock_sc2_env 24 | from pysc2.env import sc2_env 25 | 26 | 27 | class TestCompareEnvironments(absltest.TestCase): 28 | 29 | @classmethod 30 | def setUpClass(cls): 31 | players = [ 32 | sc2_env.Agent(race=sc2_env.Race.terran), 33 | sc2_env.Agent(race=sc2_env.Race.protoss), 34 | ] 35 | kwargs = { 36 | 'map_name': 'Flat64', 37 | 'players': players, 38 | 'agent_interface_format': [ 39 | sc2_env.AgentInterfaceFormat( 40 | feature_dimensions=sc2_env.Dimensions( 41 | screen=(32, 64), 42 | minimap=(8, 16) 43 | ), 44 | rgb_dimensions=sc2_env.Dimensions( 45 | screen=(31, 63), 46 | minimap=(7, 15) 47 | ), 48 | action_space=sc2_env.ActionSpace.FEATURES 49 | ), 50 | sc2_env.AgentInterfaceFormat( 51 | rgb_dimensions=sc2_env.Dimensions(screen=64, minimap=32) 52 | ) 53 | ] 54 | } 55 | cls._env = sc2_env.SC2Env(**kwargs) 56 | cls._mock_env = mock_sc2_env.SC2TestEnv(**kwargs) 57 | 58 | @classmethod 59 | def tearDownClass(cls): 60 | cls._env.close() 61 | cls._mock_env.close() 62 | 63 | def test_observation_spec(self): 64 | self.assertEqual(self._env.observation_spec(), 65 | self._mock_env.observation_spec()) 66 | 67 | 
def test_action_spec(self): 68 | self.assertEqual(self._env.action_spec(), self._mock_env.action_spec()) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /pysc2/tests/observer_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Test that two built-in bots can be watched by an observer.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from future.builtins import range # pylint: disable=redefined-builtin 23 | 24 | from pysc2 import maps 25 | from pysc2 import run_configs 26 | from pysc2.tests import utils 27 | 28 | from s2clientprotocol import common_pb2 as sc_common 29 | from s2clientprotocol import sc2api_pb2 as sc_pb 30 | 31 | 32 | class TestObserver(utils.TestCase): 33 | 34 | def test_observer(self): 35 | run_config = run_configs.get() 36 | map_inst = maps.get("Simple64") 37 | 38 | with run_config.start(want_rgb=False) as controller: 39 | create = sc_pb.RequestCreateGame(local_map=sc_pb.LocalMap( 40 | map_path=map_inst.path, map_data=map_inst.data(run_config))) 41 | create.player_setup.add( 42 | type=sc_pb.Computer, race=sc_common.Random, difficulty=sc_pb.VeryEasy) 43 | create.player_setup.add( 44 | type=sc_pb.Computer, race=sc_common.Random, difficulty=sc_pb.VeryHard) 45 | create.player_setup.add(type=sc_pb.Observer) 46 | controller.create_game(create) 47 | 48 | join = sc_pb.RequestJoinGame( 49 | options=sc_pb.InterfaceOptions(), # cheap observations 50 | observed_player_id=0) 51 | controller.join_game(join) 52 | 53 | outcome = False 54 | for _ in range(60 * 60): # 60 minutes should be plenty.
55 | controller.step(16) 56 | obs = controller.observe() 57 | if obs.player_result: 58 | print("Outcome after %s game loops (%0.1f game minutes):" % ( 59 | obs.observation.game_loop, obs.observation.game_loop / (16 * 60))) 60 | for r in obs.player_result: 61 | print("Player %s: %s" % (r.player_id, sc_pb.Result.Name(r.result))) 62 | outcome = True 63 | break 64 | 65 | self.assertTrue(outcome) 66 | 67 | 68 | if __name__ == "__main__": 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /pysc2/bin/gen_units.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Generate the unit definitions for units.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | from absl import app 24 | from pysc2 import maps 25 | from pysc2 import run_configs 26 | from pysc2.lib import static_data 27 | 28 | from s2clientprotocol import common_pb2 as sc_common 29 | from s2clientprotocol import sc2api_pb2 as sc_pb 30 | 31 | 32 | def get_data(): 33 | """Get the game's static data from an actual game.""" 34 | run_config = run_configs.get() 35 | 36 | with run_config.start(want_rgb=False) as controller: 37 | m = maps.get("Sequencer") # Arbitrary ladder map. 38 | create = sc_pb.RequestCreateGame(local_map=sc_pb.LocalMap( 39 | map_path=m.path, map_data=m.data(run_config))) 40 | create.player_setup.add(type=sc_pb.Participant) 41 | create.player_setup.add(type=sc_pb.Computer, race=sc_common.Random, 42 | difficulty=sc_pb.VeryEasy) 43 | join = sc_pb.RequestJoinGame(race=sc_common.Random, 44 | options=sc_pb.InterfaceOptions(raw=True)) 45 | 46 | controller.create_game(create) 47 | controller.join_game(join) 48 | return controller.data_raw() 49 | 50 | 51 | def generate_py_units(data): 52 | """Generate the list of units in units.py.""" 53 | units = collections.defaultdict(list) 54 | for unit in sorted(data.units, key=lambda a: a.name): 55 | if unit.unit_id in static_data.UNIT_TYPES: 56 | units[unit.race].append(unit) 57 | 58 | def print_race(name, race): 59 | print("class %s(enum.IntEnum):" % name) 60 | print(' """%s units."""' % name) 61 | for unit in units[race]: 62 | print(" %s = %s" % (unit.name, unit.unit_id)) 63 | print("\n") 64 | 65 | print_race("Neutral", sc_common.NoRace) 66 | print_race("Protoss", sc_common.Protoss) 67 | print_race("Terran", sc_common.Terran) 68 | print_race("Zerg", sc_common.Zerg) 69 | 70 | 71 | def main(unused_argv): 72 | data = get_data() 73 | print("-" * 60) 74 | 75 | generate_py_units(data) 76 | 77 | 78 
| if __name__ == "__main__": 79 | app.run(main) 80 | -------------------------------------------------------------------------------- /pysc2/lib/run_parallel_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for lib.run_parallel.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import threading 22 | 23 | from absl.testing import absltest 24 | from pysc2.lib import run_parallel 25 | 26 | 27 | class Barrier(object): 28 | 29 | def __init__(self, n): 30 | self.n = n 31 | self.count = 0 32 | self.cond = threading.Condition() 33 | 34 | def wait(self): 35 | self.cond.acquire() 36 | me = self.count 37 | self.count += 1 38 | if self.count < self.n: 39 | self.cond.wait() 40 | else: 41 | self.count = 0 42 | self.cond.notify_all() 43 | self.cond.release() 44 | return me 45 | 46 | def clear(self): 47 | self.cond.acquire() 48 | self.cond.notify_all() 49 | self.cond.release() 50 | 51 | 52 | def bad(): 53 | raise ValueError() 54 | 55 | 56 | class RunParallelTest(absltest.TestCase): 57 | 58 | def test_returns_expected_values(self): 59 | pool = run_parallel.RunParallel() 60 | out = pool.run([int]) 61 | self.assertListEqual(out, [0]) 62 | out = pool.run([lambda: 1, lambda: 2, 
lambda: "asdf", lambda: {1: 2}]) 63 | self.assertListEqual(out, [1, 2, "asdf", {1: 2}]) 64 | 65 | def test_run_in_parallel(self): 66 | b = Barrier(3) 67 | pool = run_parallel.RunParallel() 68 | out = pool.run([b.wait, b.wait, b.wait]) 69 | self.assertItemsEqual(out, [0, 1, 2]) 70 | 71 | def test_avoids_deadlock(self): 72 | b = Barrier(2) 73 | pool = run_parallel.RunParallel(timeout=2) 74 | with self.assertRaises(ValueError): 75 | pool.run([int, b.wait, bad]) 76 | # Release the thread waiting on the barrier so the process can exit cleanly. 77 | b.clear() 78 | 79 | def test_exception(self): 80 | pool = run_parallel.RunParallel() 81 | out = pool.run([lambda: 1, ValueError]) 82 | self.assertEqual(out[0], 1) 83 | self.assertIsInstance(out[1], ValueError) 84 | with self.assertRaises(ValueError): 85 | pool.run([bad]) 86 | with self.assertRaises(ValueError): 87 | pool.run([int, bad]) 88 | 89 | def test_partial(self): 90 | pool = run_parallel.RunParallel() 91 | out = pool.run((max, 0, i - 2) for i in range(5)) 92 | self.assertListEqual(out, [0, 0, 0, 1, 2]) 93 | 94 | 95 | if __name__ == "__main__": 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /pysc2/lib/static_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Expose static data in a more useful form than the raw protos.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import six 21 | 22 | 23 | class StaticData(object): 24 | """Expose static data in a more useful form than the raw protos.""" 25 | 26 | def __init__(self, data): 27 | """Takes data from RequestData.""" 28 | self._units = {u.unit_id: u.name for u in data.units} 29 | self._unit_stats = {u.unit_id: u for u in data.units} 30 | self._abilities = {a.ability_id: a for a in data.abilities} 31 | self._general_abilities = {a.remaps_to_ability_id 32 | for a in data.abilities 33 | if a.remaps_to_ability_id} 34 | 35 | for a in six.itervalues(self._abilities): 36 | a.hotkey = a.hotkey.lower() 37 | 38 | @property 39 | def abilities(self): 40 | return self._abilities 41 | 42 | @property 43 | def units(self): 44 | return self._units 45 | 46 | @property 47 | def unit_stats(self): 48 | return self._unit_stats 49 | 50 | @property 51 | def general_abilities(self): 52 | return self._general_abilities 53 | 54 | 55 | # List of known unit types. 
It is taken from: 56 | # https://github.com/Blizzard/s2client-api/blob/master/include/sc2api/sc2_typeenums.h 57 | UNIT_TYPES = [ 58 | 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 59 | 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 60 | 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 61 | 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 62 | 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 63 | 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 64 | 115, 116, 117, 118, 119, 120, 125, 126, 127, 128, 129, 130, 131, 132, 133, 65 | 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 149, 66 | 150, 151, 268, 289, 311, 322, 324, 330, 335, 341, 342, 343, 344, 365, 371, 67 | 373, 376, 377, 472, 473, 474, 483, 484, 485, 486, 487, 488, 489, 490, 493, 68 | 494, 495, 496, 498, 499, 500, 501, 502, 503, 504, 517, 518, 559, 560, 561, 69 | 562, 563, 564, 588, 589, 590, 591, 608, 630, 638, 639, 640, 641, 643, 661, 70 | 663, 664, 665, 666, 687, 688, 689, 690, 691, 692, 693, 694, 732, 733, 734, 71 | 796, 797, 801, 824, 830, 880, 881, 884, 885, 886, 887, 892, 893, 894, 1904, 72 | 1908, 1910, 1911, 1912, 1913, 73 | ] 74 | -------------------------------------------------------------------------------- /pysc2/lib/run_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A thread pool for running a set of functions synchronously in parallel. 15 | 16 | This is mainly intended for use where the functions have a barrier and none will 17 | return until all have been called. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import functools 25 | 26 | from concurrent import futures 27 | 28 | 29 | class RunParallel(object): 30 | """Run all funcs in parallel.""" 31 | 32 | def __init__(self, timeout=None): 33 | self._timeout = timeout 34 | self._executor = None 35 | self._workers = 0 36 | 37 | def run(self, funcs): 38 | """Run a set of functions in parallel, returning their results. 39 | 40 | Make sure any function you pass exits with a reasonable timeout. If it 41 | doesn't return within the timeout, or its result is ignored due to an exception 42 | in a separate thread, it will continue to run until it finishes, 43 | including blocking process exit. 44 | 45 | Args: 46 | funcs: An iterable of functions or iterable of args to functools.partial. 47 | 48 | Returns: 49 | A list of return values with the values matching the order in funcs. 50 | 51 | Raises: 52 | Propagates the first exception encountered in one of the functions. 53 | """ 54 | funcs = [f if callable(f) else functools.partial(*f) for f in funcs] 55 | if len(funcs) == 1: # Ignore threads if they're not needed. 56 | return [funcs[0]()] 57 | if len(funcs) > self._workers: # Lazy init and grow as needed.
58 | self.shutdown() 59 | self._workers = len(funcs) 60 | self._executor = futures.ThreadPoolExecutor(self._workers) 61 | futs = [self._executor.submit(f) for f in funcs] 62 | done, not_done = futures.wait(futs, self._timeout, futures.FIRST_EXCEPTION) 63 | # Make sure to propagate any exceptions. 64 | for f in done: 65 | if not f.cancelled() and f.exception() is not None: 66 | if not_done: 67 | # If there are some calls that haven't finished, cancel and recreate 68 | # the thread pool. Otherwise we may have a thread running forever 69 | # blocking parallel calls. 70 | for nd in not_done: 71 | nd.cancel() 72 | self.shutdown(False) # Don't wait, they may be deadlocked. 73 | raise f.exception() 74 | # Either done or timed out, so don't wait again. 75 | return [f.result(timeout=0) for f in futs] 76 | 77 | def shutdown(self, wait=True): 78 | if self._executor: 79 | self._executor.shutdown(wait) 80 | self._executor = None 81 | self._workers = 0 82 | 83 | def __del__(self): 84 | self.shutdown() 85 | -------------------------------------------------------------------------------- /pysc2/tests/maps_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Test that some of the maps work.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import logging 22 | import os 23 | import random 24 | 25 | from absl.testing import absltest 26 | from absl.testing import parameterized 27 | from future.builtins import range # pylint: disable=redefined-builtin 28 | 29 | from pysc2 import maps 30 | from pysc2 import run_configs 31 | from pysc2.tests import utils 32 | 33 | from s2clientprotocol import common_pb2 as sc_common 34 | from s2clientprotocol import sc2api_pb2 as sc_pb 35 | 36 | 37 | def get_maps(count=None): 38 | """Test only a few random maps to minimize time.""" 39 | all_maps = maps.get_maps() 40 | count = count or len(all_maps) 41 | return sorted(random.sample(list(all_maps.keys()), min(count, len(all_maps)))) 42 | 43 | 44 | class MapsTest(parameterized.TestCase, utils.TestCase): 45 | 46 | @parameterized.parameters(get_maps()) 47 | def test_list_all_maps(self, map_name): 48 | """Make sure all maps can be read.""" 49 | run_config = run_configs.get() 50 | map_inst = maps.get(map_name) 51 | logging.info("map: %s", map_inst.name) 52 | self.assertTrue(map_inst.data(run_config), msg="Failed on %s" % map_inst) 53 | 54 | @parameterized.parameters(get_maps(5)) 55 | def test_load_random_map(self, map_name): 56 | """Test loading a few random maps.""" 57 | m = maps.get(map_name) 58 | run_config = run_configs.get() 59 | 60 | with run_config.start(want_rgb=False) as controller: 61 | logging.info("Loading map: %s", m.name) 62 | create = sc_pb.RequestCreateGame(local_map=sc_pb.LocalMap( 63 | map_path=m.path, map_data=m.data(run_config))) 64 | create.player_setup.add(type=sc_pb.Participant) 65 | create.player_setup.add(type=sc_pb.Computer, race=sc_common.Random, 66 | difficulty=sc_pb.VeryEasy) 67 | join = sc_pb.RequestJoinGame(race=sc_common.Random, 68 | options=sc_pb.InterfaceOptions(raw=True)) 69 | 70 | controller.create_game(create)
71 | controller.join_game(join) 72 | 73 | # Verify it has the right mods and isn't running into licensing issues. 74 | info = controller.game_info() 75 | logging.info("Mods for %s: %s", m.name, info.mod_names) 76 | self.assertIn("Mods/Void.SC2Mod", info.mod_names) 77 | self.assertIn("Mods/VoidMulti.SC2Mod", info.mod_names) 78 | 79 | # Verify it can be played without making actions. 80 | for _ in range(3): 81 | controller.step() 82 | controller.observe() 83 | 84 | 85 | if __name__ == "__main__": 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /pysc2/tests/host_remote_agent_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Test host_remote_agent.py.""" 16 | 17 | from absl.testing import absltest 18 | 19 | from pysc2.env import host_remote_agent 20 | from pysc2.lib import remote_controller 21 | from pysc2.lib import run_parallel 22 | from pysc2.tests import utils 23 | 24 | from s2clientprotocol import common_pb2 as sc_common 25 | from s2clientprotocol import sc2api_pb2 as sc_pb 26 | 27 | 28 | NUM_MATCHES = 2 29 | STEPS = 100 30 | 31 | 32 | class TestHostRemoteAgent(utils.TestCase): 33 | 34 | def testVsBot(self): 35 | bot_first = True 36 | for _ in range(NUM_MATCHES): 37 | with host_remote_agent.VsBot() as game: 38 | game.create_game( 39 | map_name="Simple64", 40 | bot_difficulty=sc_pb.VeryHard, 41 | bot_first=bot_first) 42 | controller = remote_controller.RemoteController( 43 | host=game.host, 44 | port=game.host_port) 45 | 46 | join = sc_pb.RequestJoinGame(options=sc_pb.InterfaceOptions(raw=True)) 47 | join.race = sc_common.Random 48 | controller.join_game(join) 49 | for _ in range(STEPS): 50 | controller.step() 51 | response_observation = controller.observe() 52 | if response_observation.player_result: 53 | break 54 | 55 | controller.leave() 56 | controller.close() 57 | bot_first = not bot_first 58 | 59 | def testVsAgent(self): 60 | parallel = run_parallel.RunParallel() 61 | for _ in range(NUM_MATCHES): 62 | with host_remote_agent.VsAgent() as game: 63 | game.create_game("Simple64") 64 | controllers = [ 65 | remote_controller.RemoteController( 66 | host=host, 67 | port=host_port) 68 | for host, host_port in zip(game.hosts, game.host_ports)] 69 | 70 | join = sc_pb.RequestJoinGame(options=sc_pb.InterfaceOptions(raw=True)) 71 | join.race = sc_common.Random 72 | join.shared_port = 0 73 | join.server_ports.game_port = game.lan_ports[0] 74 | join.server_ports.base_port = game.lan_ports[1] 75 | join.client_ports.add( 76 | game_port=game.lan_ports[2], 77 | base_port=game.lan_ports[3]) 78 | 79 | parallel.run((c.join_game, join) for c in controllers) 80 | for _ in range(STEPS): 81 | 
parallel.run(c.step for c in controllers) 82 | response_observations = [c.observe() for c in controllers] 83 | 84 | if response_observations[0].player_result: 85 | break 86 | 87 | parallel.run(c.leave for c in controllers) 88 | parallel.run(c.close for c in controllers) 89 | 90 | 91 | if __name__ == "__main__": 92 | absltest.main() 93 | -------------------------------------------------------------------------------- /pysc2/tests/debug_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Test that the debug commands work.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | 23 | from pysc2 import maps 24 | from pysc2 import run_configs 25 | from pysc2.lib import units 26 | 27 | from s2clientprotocol import common_pb2 as sc_common 28 | from s2clientprotocol import debug_pb2 as sc_debug 29 | from s2clientprotocol import sc2api_pb2 as sc_pb 30 | 31 | 32 | class DebugTest(absltest.TestCase): 33 | 34 | def test_multi_player(self): 35 | run_config = run_configs.get() 36 | map_inst = maps.get("Simple64") 37 | 38 | with run_config.start(want_rgb=False) as controller: 39 | 40 | create = sc_pb.RequestCreateGame( 41 | local_map=sc_pb.LocalMap( 42 | map_path=map_inst.path, map_data=map_inst.data(run_config))) 43 | create.player_setup.add(type=sc_pb.Participant) 44 | create.player_setup.add( 45 | type=sc_pb.Computer, 46 | race=sc_common.Terran, 47 | difficulty=sc_pb.VeryEasy) 48 | join = sc_pb.RequestJoinGame(race=sc_common.Terran, 49 | options=sc_pb.InterfaceOptions(raw=True)) 50 | 51 | controller.create_game(create) 52 | controller.join_game(join) 53 | 54 | info = controller.game_info() 55 | map_size = info.start_raw.map_size 56 | 57 | controller.step(2) 58 | 59 | obs = controller.observe() 60 | 61 | def get_marines(obs): 62 | return {u.tag: u for u in obs.observation.raw_data.units 63 | if u.unit_type == units.Terran.Marine} 64 | 65 | self.assertEmpty(get_marines(obs)) 66 | 67 | controller.debug(sc_debug.DebugCommand( 68 | create_unit=sc_debug.DebugCreateUnit( 69 | unit_type=units.Terran.Marine, 70 | owner=1, 71 | pos=sc_common.Point2D(x=map_size.x // 2, y=map_size.y // 2), 72 | quantity=5))) 73 | 74 | controller.step(2) 75 | 76 | obs = controller.observe() 77 | 78 | marines = get_marines(obs) 79 | self.assertEqual(5, len(marines)) 80 | 81 | tags = sorted(marines.keys()) 82 | 83 | controller.debug([ 84 | 
sc_debug.DebugCommand(kill_unit=sc_debug.DebugKillUnit( 85 | tag=[tags[0]])), 86 | sc_debug.DebugCommand(unit_value=sc_debug.DebugSetUnitValue( 87 | unit_value=sc_debug.DebugSetUnitValue.Life, value=5, 88 | unit_tag=tags[1])), 89 | ]) 90 | 91 | controller.step(2) 92 | 93 | obs = controller.observe() 94 | 95 | marines = get_marines(obs) 96 | self.assertEqual(4, len(marines)) 97 | self.assertNotIn(tags[0], marines) 98 | self.assertEqual(marines[tags[1]].health, 5) 99 | 100 | 101 | if __name__ == "__main__": 102 | absltest.main() 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module setuptools script.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from setuptools import setup 21 | 22 | description = """PySC2 - StarCraft II Learning Environment 23 | 24 | PySC2 is DeepMind's Python component of the StarCraft II Learning Environment 25 | (SC2LE). It exposes Blizzard Entertainment's StarCraft II Machine Learning API 26 | as a Python RL Environment. This is a collaboration between DeepMind and 27 | Blizzard to develop StarCraft II into a rich environment for RL research. 
PySC2 28 | provides an interface for RL agents to interact with StarCraft II, getting 29 | observations and sending actions. 30 | 31 | We have published an accompanying blog post and paper 32 | https://deepmind.com/blog/deepmind-and-blizzard-open-starcraft-ii-ai-research-environment/ 33 | which outline our motivation for using StarCraft II for deep RL research, and 34 | some initial research results using the environment. 35 | 36 | Read the README at https://github.com/deepmind/pysc2 for more information. 37 | """ 38 | 39 | setup( 40 | name='PySC2', 41 | version='2.0.2', 42 | description='StarCraft II environment and library for training agents.', 43 | long_description=description, 44 | author='DeepMind', 45 | author_email='pysc2@deepmind.com', 46 | license='Apache License, Version 2.0', 47 | keywords='StarCraft AI', 48 | url='https://github.com/deepmind/pysc2', 49 | packages=[ 50 | 'pysc2', 51 | 'pysc2.agents', 52 | 'pysc2.bin', 53 | 'pysc2.env', 54 | 'pysc2.lib', 55 | 'pysc2.maps', 56 | 'pysc2.run_configs', 57 | 'pysc2.tests', 58 | ], 59 | install_requires=[ 60 | 'absl-py>=0.1.0', 61 | 'enum34', 62 | 'future', 63 | 'futures; python_version == "2.7"', 64 | 'mock', 65 | 'mpyq', 66 | 'numpy>=1.10', 67 | 'portpicker>=1.2.0', 68 | 'protobuf>=2.6', 69 | 'pygame', 70 | 'requests', 71 | 's2clientprotocol>=4.6.0.67926.0', 72 | 'six', 73 | 'sk-video', 74 | 'websocket-client', 75 | 'whichcraft', 76 | ], 77 | entry_points={ 78 | 'console_scripts': [ 79 | 'pysc2_agent = pysc2.bin.agent:entry_point', 80 | 'pysc2_play = pysc2.bin.play:entry_point', 81 | 'pysc2_replay_info = pysc2.bin.replay_info:entry_point', 82 | ], 83 | }, 84 | classifiers=[ 85 | 'Development Status :: 4 - Beta', 86 | 'Environment :: Console', 87 | 'Intended Audience :: Science/Research', 88 | 'License :: OSI Approved :: Apache Software License', 89 | 'Operating System :: POSIX :: Linux', 90 | 'Operating System :: Microsoft :: Windows', 91 | 'Operating System :: MacOS :: MacOS X', 92 | 'Programming Language ::
Python :: 2.7', 93 | 'Programming Language :: Python :: 3.4', 94 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 95 | ], 96 | ) 97 | -------------------------------------------------------------------------------- /pysc2/tests/render_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Verify that the game renders rgb pixels.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from future.builtins import range # pylint: disable=redefined-builtin 23 | 24 | import numpy as np 25 | 26 | from pysc2 import maps 27 | from pysc2 import run_configs 28 | from pysc2.lib import features 29 | from pysc2.tests import utils 30 | 31 | from s2clientprotocol import common_pb2 as sc_common 32 | from s2clientprotocol import sc2api_pb2 as sc_pb 33 | 34 | 35 | class TestRender(utils.TestCase): 36 | 37 | def test_render(self): 38 | interface = sc_pb.InterfaceOptions() 39 | interface.raw = True 40 | interface.score = True 41 | interface.feature_layer.width = 24 42 | interface.feature_layer.resolution.x = 84 43 | interface.feature_layer.resolution.y = 84 44 | interface.feature_layer.minimap_resolution.x = 64 45 | interface.feature_layer.minimap_resolution.y = 64 46 | interface.render.resolution.x = 256 47 | interface.render.resolution.y = 256 48 | interface.render.minimap_resolution.x = 128 49 | interface.render.minimap_resolution.y = 128 50 | 51 | run_config = run_configs.get() 52 | with run_config.start() as controller: 53 | map_inst = maps.get("Simple64") 54 | create = sc_pb.RequestCreateGame( 55 | realtime=False, disable_fog=False, 56 | local_map=sc_pb.LocalMap(map_path=map_inst.path, 57 | map_data=map_inst.data(run_config))) 58 | create.player_setup.add(type=sc_pb.Participant) 59 | create.player_setup.add( 60 | type=sc_pb.Computer, race=sc_common.Random, difficulty=sc_pb.VeryEasy) 61 | join = sc_pb.RequestJoinGame(race=sc_common.Random, options=interface) 62 | controller.create_game(create) 63 | controller.join_game(join) 64 | 65 | game_info = controller.game_info() 66 | 67 | # Can fail if rendering is disabled. 
68 | self.assertEqual(interface, game_info.options) 69 | 70 | for _ in range(50): 71 | controller.step(8) 72 | observation = controller.observe() 73 | 74 | obs = observation.observation 75 | rgb_screen = features.Feature.unpack_rgb_image(obs.render_data.map) 76 | rgb_minimap = features.Feature.unpack_rgb_image(obs.render_data.minimap) 77 | fl_screen = np.stack([f.unpack(obs) for f in features.SCREEN_FEATURES]) 78 | fl_minimap = np.stack([f.unpack(obs) for f in features.MINIMAP_FEATURES]) 79 | 80 | # Right shapes. 81 | self.assertEqual(rgb_screen.shape, (256, 256, 3)) 82 | self.assertEqual(rgb_minimap.shape, (128, 128, 3)) 83 | self.assertEqual(fl_screen.shape, 84 | (len(features.SCREEN_FEATURES), 84, 84)) 85 | self.assertEqual(fl_minimap.shape, 86 | (len(features.MINIMAP_FEATURES), 64, 64)) 87 | 88 | # Not all black. 89 | self.assertTrue(rgb_screen.any()) 90 | self.assertTrue(rgb_minimap.any()) 91 | self.assertTrue(fl_screen.any()) 92 | self.assertTrue(fl_minimap.any()) 93 | 94 | if observation.player_result: 95 | break 96 | 97 | if __name__ == "__main__": 98 | absltest.main() 99 | -------------------------------------------------------------------------------- /pysc2/tests/easy_scripted_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Solve the nm_easy map using a fixed policy by reading the feature layers.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | 23 | from pysc2.agents import scripted_agent 24 | from pysc2.env import run_loop 25 | from pysc2.env import sc2_env 26 | from pysc2.tests import utils 27 | 28 | 29 | class TestEasy(utils.TestCase): 30 | steps = 200 31 | step_mul = 16 32 | 33 | def test_move_to_beacon(self): 34 | with sc2_env.SC2Env( 35 | map_name="MoveToBeacon", 36 | agent_interface_format=sc2_env.AgentInterfaceFormat( 37 | feature_dimensions=sc2_env.Dimensions( 38 | screen=84, 39 | minimap=64)), 40 | step_mul=self.step_mul, 41 | game_steps_per_episode=self.steps * self.step_mul) as env: 42 | agent = scripted_agent.MoveToBeacon() 43 | run_loop.run_loop([agent], env, self.steps) 44 | 45 | # Get some points 46 | self.assertLessEqual(agent.episodes, agent.reward) 47 | self.assertEqual(agent.steps, self.steps) 48 | 49 | def test_collect_mineral_shards(self): 50 | with sc2_env.SC2Env( 51 | map_name="CollectMineralShards", 52 | agent_interface_format=sc2_env.AgentInterfaceFormat( 53 | feature_dimensions=sc2_env.Dimensions( 54 | screen=84, 55 | minimap=64)), 56 | step_mul=self.step_mul, 57 | game_steps_per_episode=self.steps * self.step_mul) as env: 58 | agent = scripted_agent.CollectMineralShards() 59 | run_loop.run_loop([agent], env, self.steps) 60 | 61 | # Get some points 62 | self.assertLessEqual(agent.episodes, agent.reward) 63 | self.assertEqual(agent.steps, self.steps) 64 | 65 | def test_collect_mineral_shards_feature_units(self): 66 | with sc2_env.SC2Env( 67 | map_name="CollectMineralShards", 68 | agent_interface_format=sc2_env.AgentInterfaceFormat( 69 | feature_dimensions=sc2_env.Dimensions( 70 | screen=84, 71 | minimap=64), 72 | use_feature_units=True), 73 | step_mul=self.step_mul, 74 | game_steps_per_episode=self.steps * self.step_mul) 
as env: 75 | agent = scripted_agent.CollectMineralShardsFeatureUnits() 76 | run_loop.run_loop([agent], env, self.steps) 77 | 78 | # Get some points 79 | self.assertLessEqual(agent.episodes, agent.reward) 80 | self.assertEqual(agent.steps, self.steps) 81 | 82 | def test_defeat_roaches(self): 83 | with sc2_env.SC2Env( 84 | map_name="DefeatRoaches", 85 | agent_interface_format=sc2_env.AgentInterfaceFormat( 86 | feature_dimensions=sc2_env.Dimensions( 87 | screen=84, 88 | minimap=64)), 89 | step_mul=self.step_mul, 90 | game_steps_per_episode=self.steps * self.step_mul) as env: 91 | agent = scripted_agent.DefeatRoaches() 92 | run_loop.run_loop([agent], env, self.steps) 93 | 94 | # Get some points 95 | self.assertLessEqual(agent.episodes, agent.reward) 96 | self.assertEqual(agent.steps, self.steps) 97 | 98 | 99 | if __name__ == "__main__": 100 | absltest.main() 101 | -------------------------------------------------------------------------------- /pysc2/lib/transform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Transform coordinates for rendering in various ways. 15 | 16 | It's best to name these as `a_to_b` for example `screen_to_world`. The 17 | `fwd` methods take a point or distance in coordinate system `a` and 18 | convert it to a point or distance in coordinate system `b`. 
The `back` methods 19 | do the reverse going from `b` to `a`. 20 | 21 | These can then be chained as b_to_c.fwd_pt(a_to_b.fwd_pt(pt)), which takes 22 | something in `a` and returns something in `c`. It's better to use the Chain 23 | transform to create `a_to_c`. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import numbers 31 | 32 | from pysc2.lib import point 33 | 34 | 35 | class Transform(object): 36 | """Base class for coordinate transforms.""" 37 | 38 | def fwd_dist(self, dist): 39 | raise NotImplementedError() 40 | 41 | def fwd_pt(self, pt): 42 | raise NotImplementedError() 43 | 44 | def back_dist(self, dist): 45 | raise NotImplementedError() 46 | 47 | def back_pt(self, pt): 48 | raise NotImplementedError() 49 | 50 | 51 | class Linear(Transform): 52 | """A linear transform with a scale and offset.""" 53 | 54 | def __init__(self, scale=None, offset=None): 55 | if scale is None: 56 | self.scale = point.Point(1, 1) 57 | elif isinstance(scale, numbers.Number): 58 | self.scale = point.Point(scale, scale) 59 | else: 60 | self.scale = scale 61 | assert self.scale.x != 0 and self.scale.y != 0 62 | self.offset = offset or point.Point(0, 0) 63 | 64 | def fwd_dist(self, dist): 65 | return dist * self.scale.x 66 | 67 | def fwd_pt(self, pt): 68 | return pt * self.scale + self.offset 69 | 70 | def back_dist(self, dist): 71 | return dist / self.scale.x 72 | 73 | def back_pt(self, pt): 74 | return (pt - self.offset) / self.scale 75 | 76 | def __str__(self): 77 | return "Linear(scale=%s, offset=%s)" % (self.scale, self.offset) 78 | 79 | 80 | class Chain(Transform): 81 | """Chain a set of transforms: Chain(a_to_b, b_to_c) => a_to_c.""" 82 | 83 | def __init__(self, *args): 84 | self.transforms = args 85 | 86 | def fwd_dist(self, dist): 87 | for transform in self.transforms: 88 | dist = transform.fwd_dist(dist) 89 | return dist 90 | 91 | def fwd_pt(self, pt): 92 | for transform in
self.transforms: 93 | pt = transform.fwd_pt(pt) 94 | return pt 95 | 96 | def back_dist(self, dist): 97 | for transform in reversed(self.transforms): 98 | dist = transform.back_dist(dist) 99 | return dist 100 | 101 | def back_pt(self, pt): 102 | for transform in reversed(self.transforms): 103 | pt = transform.back_pt(pt) 104 | return pt 105 | 106 | def __str__(self): 107 | return "Chain(%s)" % (self.transforms,) 108 | 109 | 110 | class PixelToCoord(Transform): 111 | """Snap a point within a pixel to its top-left (fwd) or its center (back).""" 112 | 113 | def fwd_dist(self, dist): 114 | return dist 115 | 116 | def fwd_pt(self, pt): 117 | return pt.floor() 118 | 119 | def back_dist(self, dist): 120 | return dist 121 | 122 | def back_pt(self, pt): 123 | return pt.floor() + 0.5 124 | 125 | def __str__(self): 126 | return "PixelToCoord()" 127 | 128 | -------------------------------------------------------------------------------- /pysc2/bin/benchmark_observe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Benchmark observation times.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import time 22 | 23 | from absl import app 24 | from absl import flags 25 | from future.builtins import range # pylint: disable=redefined-builtin 26 | 27 | from pysc2 import maps 28 | from pysc2 import run_configs 29 | 30 | from s2clientprotocol import common_pb2 as sc_common 31 | from s2clientprotocol import sc2api_pb2 as sc_pb 32 | 33 | 34 | flags.DEFINE_integer("count", 500, "How many observations to run.") 35 | flags.DEFINE_integer("step_mul", 8, "How many game steps per observation.") 36 | FLAGS = flags.FLAGS 37 | 38 | 39 | def interface_options(score=False, raw=False, features=None, rgb=None): 40 | """Get an InterfaceOptions for the config.""" 41 | interface = sc_pb.InterfaceOptions() 42 | interface.score = score 43 | interface.raw = raw 44 | if features: 45 | interface.feature_layer.width = 24 46 | interface.feature_layer.resolution.x = features 47 | interface.feature_layer.resolution.y = features 48 | interface.feature_layer.minimap_resolution.x = features 49 | interface.feature_layer.minimap_resolution.y = features 50 | if rgb: 51 | interface.render.resolution.x = rgb 52 | interface.render.resolution.y = rgb 53 | interface.render.minimap_resolution.x = rgb 54 | interface.render.minimap_resolution.y = rgb 55 | return interface 56 | 57 | 58 | def main(unused_argv): 59 | configs = [ 60 | ("raw", interface_options(raw=True)), 61 | ("raw-feat-48", interface_options(raw=True, features=48)), 62 | ("feat-32", interface_options(features=32)), 63 | ("feat-48", interface_options(features=48)), 64 | ("feat-72", interface_options(features=72)), 65 | ("feat-96", interface_options(features=96)), 66 | ("feat-128", interface_options(features=128)), 67 | ("rgb-64", interface_options(rgb=64)), 68 | ("rgb-128", interface_options(rgb=128)), 69 | ] 70 | 71 | results = [] 72 | try: 73 | for config, 
interface in configs: 74 | timeline = [] 75 | 76 | run_config = run_configs.get() 77 | with run_config.start( 78 | want_rgb=interface.HasField("render")) as controller: 79 | map_inst = maps.get("Catalyst") 80 | create = sc_pb.RequestCreateGame( 81 | realtime=False, disable_fog=False, random_seed=1, 82 | local_map=sc_pb.LocalMap(map_path=map_inst.path, 83 | map_data=map_inst.data(run_config))) 84 | create.player_setup.add(type=sc_pb.Participant) 85 | create.player_setup.add(type=sc_pb.Computer, race=sc_common.Terran, 86 | difficulty=sc_pb.VeryEasy) 87 | join = sc_pb.RequestJoinGame(race=sc_common.Protoss, options=interface) 88 | controller.create_game(create) 89 | controller.join_game(join) 90 | 91 | for _ in range(FLAGS.count): 92 | controller.step(FLAGS.step_mul) 93 | start = time.time() 94 | obs = controller.observe() 95 | timeline.append(time.time() - start) 96 | if obs.player_result: 97 | break 98 | 99 | results.append((config, timeline)) 100 | except KeyboardInterrupt: 101 | pass 102 | 103 | names, values = zip(*results) 104 | 105 | print("\n\nTimeline:\n") 106 | print(",".join(names)) 107 | for times in zip(*values): 108 | print(",".join("%0.2f" % (t * 1000) for t in times)) 109 | 110 | 111 | if __name__ == "__main__": 112 | app.run(main) 113 | -------------------------------------------------------------------------------- /pysc2/bin/replay_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Query one or more replays for information.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl import app 24 | from future.builtins import str # pylint: disable=redefined-builtin 25 | 26 | from pysc2 import run_configs 27 | from pysc2.lib import remote_controller 28 | 29 | from pysc2.lib import gfile 30 | from s2clientprotocol import common_pb2 as sc_common 31 | from s2clientprotocol import sc2api_pb2 as sc_pb 32 | 33 | 34 | def _replay_index(replay_dir): 35 | """Output information for a directory of replays.""" 36 | run_config = run_configs.get() 37 | replay_dir = run_config.abs_replay_path(replay_dir) 38 | print("Checking: ", replay_dir) 39 | 40 | with run_config.start(want_rgb=False) as controller: 41 | print("-" * 60) 42 | print(",".join(( 43 | "filename", 44 | "build", 45 | "map_name", 46 | "game_duration_loops", 47 | "players", 48 | "P1-outcome", 49 | "P1-race", 50 | "P1-apm", 51 | "P2-race", 52 | "P2-apm", 53 | ))) 54 | 55 | try: 56 | bad_replays = [] 57 | for file_path in run_config.replay_paths(replay_dir): 58 | file_name = os.path.basename(file_path) 59 | try: 60 | info = controller.replay_info(run_config.replay_data(file_path)) 61 | except remote_controller.RequestError as e: 62 | bad_replays.append("%s: %s" % (file_name, e)) 63 | continue 64 | if info.HasField("error"): 65 | print("failed:", file_name, info.error, info.error_details) 66 | bad_replays.append(file_name) 67 | else: 68 | out 
= [ 69 | file_name, 70 | info.base_build, 71 | info.map_name, 72 | info.game_duration_loops, 73 | len(info.player_info), 74 | sc_pb.Result.Name(info.player_info[0].player_result.result), 75 | sc_common.Race.Name(info.player_info[0].player_info.race_actual), 76 | info.player_info[0].player_apm, 77 | ] 78 | if len(info.player_info) >= 2: 79 | out += [ 80 | sc_common.Race.Name( 81 | info.player_info[1].player_info.race_actual), 82 | info.player_info[1].player_apm, 83 | ] 84 | print(u",".join(str(s) for s in out)) 85 | except KeyboardInterrupt: 86 | pass 87 | finally: 88 | if bad_replays: 89 | print("\n") 90 | print("Replays with errors:") 91 | print("\n".join(bad_replays)) 92 | 93 | 94 | def _replay_info(replay_path): 95 | """Query a replay for information.""" 96 | if not replay_path.lower().endswith("sc2replay"): 97 | print("Must be a replay.") 98 | return 99 | 100 | run_config = run_configs.get() 101 | with run_config.start(want_rgb=False) as controller: 102 | info = controller.replay_info(run_config.replay_data(replay_path)) 103 | print("-" * 60) 104 | print(info) 105 | 106 | 107 | def main(argv): 108 | if len(argv) < 2: # argv[0] is the program name. 109 | raise ValueError("No replay directory or path specified.") 110 | if len(argv) > 2: 111 | raise ValueError("Too many arguments provided.") 112 | path = argv[1] 113 | 114 | try: 115 | if gfile.IsDirectory(path): 116 | return _replay_index(path) 117 | else: 118 | return _replay_info(path) 119 | except KeyboardInterrupt: 120 | pass 121 | 122 | 123 | def entry_point(): # Needed so the setup.py scripts work. 124 | app.run(main) 125 | 126 | 127 | if __name__ == "__main__": 128 | app.run(main) 129 | -------------------------------------------------------------------------------- /pysc2/tests/versions_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Test that every version in run_configs actually runs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import logging 22 | 23 | from absl.testing import absltest 24 | from pysc2 import maps 25 | from pysc2 import run_configs 26 | 27 | from s2clientprotocol import common_pb2 as sc_common 28 | from s2clientprotocol import sc2api_pb2 as sc_pb 29 | 30 | 31 | def major_version(v): 32 | return ".".join(v.split(".")[:2]) 33 | 34 | 35 | def log_center(s, *args): 36 | logging.info(((" " + s + " ") % args).center(80, "-")) 37 | 38 | 39 | class TestVersions(absltest.TestCase): 40 | 41 | def test_version_numbers(self): 42 | run_config = run_configs.get() 43 | failures = [] 44 | for game_version, version in sorted(run_config.get_versions().items()): 45 | try: 46 | self.assertEqual(game_version, version.game_version) 47 | log_center("starting version check: %s", game_version) 48 | with run_config.start(version=game_version, 49 | want_rgb=False) as controller: 50 | ping = controller.ping() 51 | logging.info("expected: %s", version) 52 | logging.info("actual: %s", ", ".join(str(ping).strip().split("\n"))) 53 | self.assertEqual(version.build_version, ping.base_build) 54 | if version.game_version != "latest": 55 | self.assertEqual(major_version(ping.game_version), 56 | 
major_version(version.game_version)) 57 | self.assertEqual(version.data_version.lower(), 58 | ping.data_version.lower()) 59 | log_center("success: %s", game_version) 60 | except: # pylint: disable=bare-except 61 | log_center("failure: %s", game_version) 62 | logging.exception("Failed") 63 | failures.append(game_version) 64 | self.assertEmpty(failures) 65 | 66 | def test_versions_create_game(self): 67 | run_config = run_configs.get() 68 | failures = [] 69 | for game_version in sorted(run_config.get_versions().keys()): 70 | try: 71 | log_center("starting create game: %s", game_version) 72 | with run_config.start(version=game_version, 73 | want_rgb=False) as controller: 74 | interface = sc_pb.InterfaceOptions() 75 | interface.raw = True 76 | interface.score = True 77 | interface.feature_layer.width = 24 78 | interface.feature_layer.resolution.x = 84 79 | interface.feature_layer.resolution.y = 84 80 | interface.feature_layer.minimap_resolution.x = 64 81 | interface.feature_layer.minimap_resolution.y = 64 82 | 83 | map_inst = maps.get("Simple64") 84 | create = sc_pb.RequestCreateGame(local_map=sc_pb.LocalMap( 85 | map_path=map_inst.path, map_data=map_inst.data(run_config))) 86 | create.player_setup.add(type=sc_pb.Participant) 87 | create.player_setup.add(type=sc_pb.Computer, race=sc_common.Terran, 88 | difficulty=sc_pb.VeryEasy) 89 | join = sc_pb.RequestJoinGame(race=sc_common.Terran, options=interface) 90 | 91 | controller.create_game(create) 92 | controller.join_game(join) 93 | 94 | for _ in range(5): 95 | controller.step(16) 96 | controller.observe() 97 | 98 | log_center("success: %s", game_version) 99 | except: # pylint: disable=bare-except 100 | logging.exception("Failed") 101 | log_center("failure: %s", game_version) 102 | failures.append(game_version) 103 | self.assertEmpty(failures) 104 | 105 | 106 | if __name__ == "__main__": 107 | absltest.main() 108 | -------------------------------------------------------------------------------- 
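In `test_version_numbers` above, the game version reported by `ping` is compared with the configured version only on its major.minor prefix, via `major_version`. The standalone sketch below copies that helper and wraps the comparison in a hypothetical `same_major_version` function (not part of pysc2); the version strings are made up for illustration.

```python
def major_version(v):
  """Keep only the first two dot-separated components, e.g. "4.1.2" -> "4.1"."""
  return ".".join(v.split(".")[:2])


def same_major_version(reported, expected):
  """Hypothetical helper mirroring the major-version assertion in the test."""
  return major_version(reported) == major_version(expected)


if __name__ == "__main__":
  # Illustrative version strings only.
  print(major_version("4.1.2"))                 # 4.1
  print(same_major_version("4.1.2", "4.1.4"))   # True
  print(same_major_version("4.1.2", "4.2.0"))   # False
```

A non-numeric tag such as `"latest"` passes through `major_version` unchanged and would never match a real build's version, which is why the test only runs this check when `version.game_version != "latest"`.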
/pysc2/bin/mem_leak_check.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Test for memory leaks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=g-import-not-at-top 22 | 23 | import collections 24 | import sys 25 | import time 26 | 27 | from absl import app 28 | from absl import flags 29 | from future.builtins import range # pylint: disable=redefined-builtin 30 | 31 | try: 32 | import psutil 33 | except ImportError: 34 | sys.exit( 35 | "`psutil` library required to track memory. 
This can be installed with:\n" 36 | "$ pip install psutil\n" 37 | "and needs the python-dev headers installed, for example:\n" 38 | "$ apt install python-dev") 39 | 40 | from pysc2 import maps 41 | from pysc2 import run_configs 42 | from pysc2.lib import protocol 43 | 44 | from s2clientprotocol import common_pb2 as sc_common 45 | from s2clientprotocol import sc2api_pb2 as sc_pb 46 | 47 | # pylint: enable=g-import-not-at-top 48 | 49 | 50 | flags.DEFINE_integer("mem_limit", 2000, "Max memory usage in Mb.") 51 | flags.DEFINE_integer("episodes", 200, "Max number of episodes.") 52 | FLAGS = flags.FLAGS 53 | 54 | 55 | class MemoryException(Exception): 56 | pass 57 | 58 | 59 | class Timestep(collections.namedtuple( 60 | "Timestep", ["episode", "time", "cpu", "memory", "name"])): 61 | 62 | def __str__(self): 63 | return "[%3d: %7.3f] cpu: %5.1f s, mem: %4d Mb; %s" % self 64 | 65 | 66 | def main(unused_argv): 67 | interface = sc_pb.InterfaceOptions() 68 | interface.raw = True 69 | interface.score = True 70 | interface.feature_layer.width = 24 71 | interface.feature_layer.resolution.x = 84 72 | interface.feature_layer.resolution.y = 84 73 | interface.feature_layer.minimap_resolution.x = 64 74 | interface.feature_layer.minimap_resolution.y = 64 75 | 76 | timeline = [] 77 | 78 | start = time.time() 79 | run_config = run_configs.get() 80 | proc = run_config.start(want_rgb=interface.HasField("render")) 81 | process = psutil.Process(proc.pid) 82 | episode = 0 83 | 84 | def add(s): 85 | cpu = process.cpu_times().user 86 | mem = process.memory_info().rss / 2 ** 20 # In Mb 87 | step = Timestep(episode, time.time() - start, cpu, mem, s) 88 | print(step) 89 | timeline.append(step) 90 | if mem > FLAGS.mem_limit: 91 | raise MemoryException("%s Mb mem limit exceeded" % FLAGS.mem_limit) 92 | 93 | try: 94 | add("Started process") 95 | 96 | controller = proc.controller 97 | map_inst = maps.get("Simple64") 98 | create = sc_pb.RequestCreateGame( 99 | realtime=False, disable_fog=False, 
random_seed=1, 100 | local_map=sc_pb.LocalMap(map_path=map_inst.path, 101 | map_data=map_inst.data(run_config))) 102 | create.player_setup.add(type=sc_pb.Participant) 103 | create.player_setup.add(type=sc_pb.Computer, race=sc_common.Protoss, 104 | difficulty=sc_pb.CheatInsane) 105 | join = sc_pb.RequestJoinGame(race=sc_common.Protoss, options=interface) 106 | controller.create_game(create) 107 | 108 | add("Created game") 109 | 110 | controller.join_game(join) 111 | 112 | episode += 1 113 | add("Joined game") 114 | 115 | for _ in range(FLAGS.episodes): 116 | for i in range(2000): 117 | controller.step(16) 118 | obs = controller.observe() 119 | if obs.player_result: 120 | add("Lost on step %s" % i) 121 | break 122 | if i > 0 and i % 100 == 0: 123 | add("Step %s" % i) 124 | controller.restart() 125 | episode += 1 126 | add("Restarted") 127 | add("Done") 128 | except KeyboardInterrupt: 129 | pass 130 | except (MemoryException, protocol.ConnectionError) as e: 131 | print(e) 132 | finally: 133 | proc.close() 134 | 135 | print("Timeline:") 136 | for t in timeline: 137 | print(t) 138 | 139 | 140 | if __name__ == "__main__": 141 | app.run(main) 142 | -------------------------------------------------------------------------------- /pysc2/lib/stopwatch_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for stopwatch.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl.testing import absltest 24 | from future.builtins import range # pylint: disable=redefined-builtin 25 | 26 | import mock 27 | from pysc2.lib import stopwatch 28 | 29 | 30 | def ham_dist(str1, str2): 31 | """Hamming distance. Count the number of differences between str1 and str2.""" 32 | assert len(str1) == len(str2) 33 | return sum(c1 != c2 for c1, c2 in zip(str1, str2)) 34 | 35 | 36 | class StatTest(absltest.TestCase): 37 | 38 | def testRange(self): 39 | stat = stopwatch.Stat() 40 | stat.add(1) 41 | stat.add(5) 42 | stat.add(3) 43 | self.assertEqual(stat.num, 3) 44 | self.assertEqual(stat.sum, 9) 45 | self.assertEqual(stat.min, 1) 46 | self.assertEqual(stat.max, 5) 47 | self.assertEqual(stat.avg, 3) 48 | 49 | def testParse(self): 50 | stat = stopwatch.Stat() 51 | stat.add(1) 52 | stat.add(3) 53 | out = str(stat) 54 | self.assertEqual(out, "sum: 4.0000, avg: 2.0000, dev: 1.0000, " 55 | "min: 1.0000, max: 3.0000, num: 2") 56 | # Allow a few small rounding errors 57 | self.assertLess(ham_dist(out, str(stopwatch.Stat.parse(out))), 5) 58 | 59 | 60 | class StopwatchTest(absltest.TestCase): 61 | 62 | @mock.patch("time.time") 63 | def testStopwatch(self, mock_time): 64 | mock_time.return_value = 0 65 | sw = stopwatch.StopWatch() 66 | with sw("one"): 67 | mock_time.return_value += 0.002 68 | with sw("one"): 69 | mock_time.return_value += 0.004 70 | with sw("two"): 71 | with sw("three"): 72 | mock_time.return_value += 0.006 73 | 74 | @sw.decorate 75 | def four(): 76 | mock_time.return_value += 0.004 77 | four() 78 | 79 | @sw.decorate("five") 80 | def foo(): 81 | mock_time.return_value += 0.005 82 | foo() 83 | 84 | out = str(sw) 85 | 86 | # The names should be in 
sorted order. 87 | names = [l.split(None)[0] for l in out.splitlines()[1:]] 88 | self.assertEqual(names, ["five", "four", "one", "two", "two.three"]) 89 | 90 | one_line = out.splitlines()[3].split(None) 91 | self.assertLess(one_line[5], one_line[6]) # min < max 92 | self.assertEqual(one_line[7], "2") # num 93 | # Can't test the rest since they'll be flaky. 94 | 95 | # Allow a few small rounding errors for the round trip. 96 | round_trip = str(stopwatch.StopWatch.parse(out)) 97 | self.assertLess(ham_dist(out, round_trip), 15, 98 | "%s != %s" % (out, round_trip)) 99 | 100 | def testDivideZero(self): 101 | sw = stopwatch.StopWatch() 102 | with sw("zero"): 103 | pass 104 | 105 | # Just make sure this doesn't have a divide by 0 for when the total is 0. 106 | self.assertIn("zero", str(sw)) 107 | 108 | @mock.patch.dict(os.environ, {"SC2_NO_STOPWATCH": "1"}) 109 | def testDecoratorDisabled(self): 110 | sw = stopwatch.StopWatch() 111 | self.assertEqual(round, sw.decorate(round)) 112 | self.assertEqual(round, sw.decorate("name")(round)) 113 | 114 | @mock.patch.dict(os.environ, {"SC2_NO_STOPWATCH": ""}) 115 | def testDecoratorEnabled(self): 116 | sw = stopwatch.StopWatch() 117 | self.assertNotEqual(round, sw.decorate(round)) 118 | self.assertNotEqual(round, sw.decorate("name")(round)) 119 | 120 | def testSpeed(self): 121 | count = 1000 122 | 123 | def run(): 124 | for _ in range(count): 125 | with sw("name"): 126 | pass 127 | 128 | sw = stopwatch.StopWatch() 129 | for _ in range(10): 130 | sw.enabled = True 131 | sw.trace = False 132 | with sw("enabled"): 133 | run() 134 | 135 | sw.enabled = True 136 | sw.trace = True 137 | with sw("trace"): 138 | run() 139 | 140 | sw.enabled = True # To catch "disabled". 141 | with sw("disabled"): 142 | sw.enabled = False 143 | run() 144 | 145 | # No asserts. Succeed but print the timings. 
146 | print(sw) 147 | 148 | 149 | if __name__ == "__main__": 150 | absltest.main() 151 | -------------------------------------------------------------------------------- /pysc2/maps/lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The library and base Map for defining full maps. 15 | 16 | To define your own map, import this library and subclass Map. It will be 17 | automatically registered for creation by `get`. 18 | 19 | class NewMap(lib.Map): 20 | directory = "map_dir" 21 | filename = "map_name" 22 | players = 3 23 | 24 | You can build a hierarchy of classes to make your definitions less verbose. 25 | 26 | To use a map, either import the map module and instantiate the map directly, or 27 | import the maps lib and use `get`. Using `get` from this lib will work, but only 28 | if you've imported the map module somewhere. 29 | """ 30 | 31 | from __future__ import absolute_import 32 | from __future__ import division 33 | from __future__ import print_function 34 | 35 | from absl import logging 36 | import os 37 | 38 | 39 | class DuplicateMapException(Exception): 40 | pass 41 | 42 | 43 | class NoMapException(Exception): 44 | pass 45 | 46 | 47 | class Map(object): 48 | """Base map object to configure a map. To define a map, just subclass this.
49 | 50 | Properties: 51 | directory: Directory for the map. 52 | filename: Actual filename. You can skip the ".SC2Map" file ending. 53 | download: Where to download the map. 54 | game_steps_per_episode: Game steps per episode, independent of the step_mul. 55 | 0 (default) means no limit. 56 | step_mul: How many game steps to advance per agent step. 57 | score_index: Which score to give for this map. -1 means the win/loss 58 | reward. >=0 is the index into score_cumulative. 59 | score_multiplier: A multiplier applied to the score, to scale up small scores. 60 | players: Max number of players for this map. 61 | """ 62 | directory = "" 63 | filename = None 64 | download = None 65 | game_steps_per_episode = 0 66 | step_mul = 8 67 | score_index = -1 68 | score_multiplier = 1 69 | players = None 70 | 71 | @property 72 | def path(self): 73 | """The full path to the map file: directory, filename and file ending.""" 74 | if self.filename: 75 | map_path = os.path.join(self.directory, self.filename) 76 | if not map_path.endswith(".SC2Map"): 77 | map_path += ".SC2Map" 78 | return map_path 79 | 80 | def data(self, run_config): 81 | """Return the map data.""" 82 | try: 83 | return run_config.map_data(self.path) 84 | except (IOError, OSError) as e: # Catch both for python 2/3 compatibility.
85 | if self.download and hasattr(e, "filename"): 86 | logging.error("Error reading map '%s' from: %s", self.name, e.filename) 87 | logging.error("Download the map from: %s", self.download) 88 | raise 89 | 90 | @property 91 | def name(self): 92 | return self.__class__.__name__ 93 | 94 | def __str__(self): 95 | return "\n".join([ 96 | self.name, 97 | " %s" % self.path, 98 | " players: %s, score_index: %s, score_multiplier: %s" % ( 99 | self.players, self.score_index, self.score_multiplier), 100 | " step_mul: %s, game_steps_per_episode: %s" % ( 101 | self.step_mul, self.game_steps_per_episode), 102 | ]) 103 | 104 | @classmethod 105 | def all_subclasses(cls): 106 | """An iterator over all subclasses of `cls`.""" 107 | for s in cls.__subclasses__(): 108 | yield s 109 | for c in s.all_subclasses(): 110 | yield c 111 | 112 | 113 | def get_maps(): 114 | """Get the full dict of maps {map_name: map_class}.""" 115 | maps = {} 116 | for mp in Map.all_subclasses(): 117 | if mp.filename: 118 | map_name = mp.__name__ 119 | if map_name in maps: 120 | raise DuplicateMapException("Duplicate map found: " + map_name) 121 | maps[map_name] = mp 122 | return maps 123 | 124 | 125 | def get(map_name): 126 | """Get an instance of a map by name. Errors if the map doesn't exist.""" 127 | if isinstance(map_name, Map): 128 | return map_name 129 | 130 | # Get the list of maps. This isn't at module scope to avoid problems of maps 131 | # being defined after this module is imported. 132 | maps = get_maps() 133 | map_class = maps.get(map_name) 134 | if map_class: 135 | return map_class() 136 | raise NoMapException("Map doesn't exist: %s" % map_name) 137 | -------------------------------------------------------------------------------- /pysc2/tests/multi_player_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2017 Google Inc. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS-IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Test that multiplayer works independently of the SC2Env.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import logging 22 | import os 23 | 24 | from absl.testing import absltest 25 | from future.builtins import range # pylint: disable=redefined-builtin 26 | 27 | from pysc2 import maps 28 | from pysc2 import run_configs 29 | from pysc2.lib import point 30 | from pysc2.lib import portspicker 31 | from pysc2.lib import run_parallel 32 | from pysc2.tests import utils 33 | 34 | from s2clientprotocol import common_pb2 as sc_common 35 | from s2clientprotocol import sc2api_pb2 as sc_pb 36 | 37 | 38 | def print_stage(stage): 39 | logging.info((" %s " % stage).center(80, "-")) 40 | 41 | 42 | class TestMultiplayer(utils.TestCase): 43 | 44 | def test_multi_player(self): 45 | players = 2 46 | run_config = run_configs.get() 47 | parallel = run_parallel.RunParallel() 48 | map_inst = maps.get("Simple64") 49 | 50 | screen_size_px = point.Point(64, 64) 51 | minimap_size_px = point.Point(32, 32) 52 | interface = sc_pb.InterfaceOptions() 53 | screen_size_px.assign_to(interface.feature_layer.resolution) 54 | minimap_size_px.assign_to(interface.feature_layer.minimap_resolution) 55 | 56 | # Reserve a whole bunch of ports for the weird multiplayer implementation. 
    ports = portspicker.pick_unused_ports(players * 2)
    logging.info("Valid Ports: %s", ports)

    # Actually launch the game processes.
    print_stage("start")
    sc2_procs = [run_config.start(extra_ports=ports, want_rgb=False)
                 for _ in range(players)]
    controllers = [p.controller for p in sc2_procs]

    try:
      # Save the map so they can access it.
      map_path = os.path.basename(map_inst.path)
      print_stage("save_map")
      parallel.run((c.save_map, map_path, map_inst.data(run_config))
                   for c in controllers)

      # Create the create request.
      create = sc_pb.RequestCreateGame(
          local_map=sc_pb.LocalMap(map_path=map_path))
      for _ in range(players):
        create.player_setup.add(type=sc_pb.Participant)

      # Create the join request.
      join = sc_pb.RequestJoinGame(race=sc_common.Random, options=interface)
      join.shared_port = 0  # unused
      join.server_ports.game_port = ports.pop(0)
      join.server_ports.base_port = ports.pop(0)
      for _ in range(players - 1):
        join.client_ports.add(game_port=ports.pop(0), base_port=ports.pop(0))

      # Play a few short games.
      for _ in range(2):  # 2 episodes
        # Create and Join
        print_stage("create")
        controllers[0].create_game(create)
        print_stage("join")
        parallel.run((c.join_game, join) for c in controllers)

        print_stage("run")
        for game_loop in range(1, 10):  # steps per episode
          # Step the game
          parallel.run(c.step for c in controllers)

          # Observe
          obs = parallel.run(c.observe for c in controllers)
          for p_id, o in enumerate(obs):
            self.assertEqual(o.observation.game_loop, game_loop)
            self.assertEqual(o.observation.player_common.player_id, p_id + 1)

          # Act
          actions = [sc_pb.Action() for _ in range(players)]
          for action in actions:
            pt = (point.Point.unit_rand() * minimap_size_px).floor()
            pt.assign_to(action.action_feature_layer.camera_move.center_minimap)
          parallel.run((c.act, a) for c, a in zip(controllers, actions))

        # Done this game.
        print_stage("leave")
        parallel.run(c.leave for c in controllers)
    finally:
      print_stage("quit")
      # Done, shut down. Don't depend on parallel since it might be broken.
      for c in controllers:
        c.quit()
      for p in sc2_procs:
        p.close()
      portspicker.return_ports(ports)


if __name__ == "__main__":
  absltest.main()

--------------------------------------------------------------------------------
/pysc2/tests/obs_spec_test.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Verify that the observations match the observation spec."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest
from future.builtins import range  # pylint: disable=redefined-builtin
import six

from pysc2.agents import random_agent
from pysc2.env import sc2_env
from pysc2.tests import utils


class TestObservationSpec(utils.TestCase):

  def test_observation_matches_obs_spec(self):
    with sc2_env.SC2Env(
        map_name="Simple64",
        agent_interface_format=sc2_env.AgentInterfaceFormat(
            feature_dimensions=sc2_env.Dimensions(
                screen=(84, 87),
                minimap=(64, 67)))) as env:

      multiplayer_obs_spec = env.observation_spec()
      self.assertIsInstance(multiplayer_obs_spec, tuple)
      self.assertLen(multiplayer_obs_spec, 1)
      obs_spec = multiplayer_obs_spec[0]

      multiplayer_action_spec = env.action_spec()
      self.assertIsInstance(multiplayer_action_spec, tuple)
      self.assertLen(multiplayer_action_spec, 1)
      action_spec = multiplayer_action_spec[0]

      agent = random_agent.RandomAgent()
      agent.setup(obs_spec, action_spec)

      multiplayer_obs = env.reset()
      agent.reset()
      for _ in range(100):
        self.assertIsInstance(multiplayer_obs, tuple)
        self.assertLen(multiplayer_obs, 1)
        raw_obs = multiplayer_obs[0]
        obs = raw_obs.observation
        self.check_observation_matches_spec(obs, obs_spec)

        act = agent.step(raw_obs)
        multiplayer_act = (act,)
        multiplayer_obs = env.step(multiplayer_act)

  def test_heterogeneous_observations(self):
    with sc2_env.SC2Env(
        map_name="Simple64",
        players=[
            sc2_env.Agent(sc2_env.Race.random),
            sc2_env.Agent(sc2_env.Race.random)
        ],
        agent_interface_format=[
            sc2_env.AgentInterfaceFormat(
                feature_dimensions=sc2_env.Dimensions(
                    screen=(84, 87),
                    minimap=(64, 67)
                )
            ),
            sc2_env.AgentInterfaceFormat(
                rgb_dimensions=sc2_env.Dimensions(
                    screen=128,
                    minimap=64
                )
            )
        ]) as env:

      obs_specs = env.observation_spec()
      self.assertIsInstance(obs_specs, tuple)
      self.assertLen(obs_specs, 2)

      actions_specs = env.action_spec()
      self.assertIsInstance(actions_specs, tuple)
      self.assertLen(actions_specs, 2)

      agents = []
      for obs_spec, action_spec in zip(obs_specs, actions_specs):
        agent = random_agent.RandomAgent()
        agent.setup(obs_spec, action_spec)
        agent.reset()
        agents.append(agent)

      time_steps = env.reset()
      for _ in range(100):
        self.assertIsInstance(time_steps, tuple)
        self.assertLen(time_steps, 2)

        actions = []
        for i, agent in enumerate(agents):
          time_step = time_steps[i]
          obs = time_step.observation
          self.check_observation_matches_spec(obs, obs_specs[i])
          actions.append(agent.step(time_step))

        time_steps = env.step(actions)

  def check_observation_matches_spec(self, obs, obs_spec):
    self.assertItemsEqual(obs_spec.keys(), obs.keys())
    for k, o in six.iteritems(obs):
      descr = "%s: spec: %s != obs: %s" % (k, obs_spec[k], o.shape)

      if o.shape == (0,):  # Empty tensor can't have a shape.
        self.assertIn(0, obs_spec[k], descr)
      else:
        self.assertEqual(len(obs_spec[k]), len(o.shape), descr)
        for a, b in zip(obs_spec[k], o.shape):
          if a != 0:
            self.assertEqual(a, b, descr)


if __name__ == "__main__":
  absltest.main()

--------------------------------------------------------------------------------
/pysc2/maps/ladder.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define the ladder map configs.

Refer to the map descriptions here:
http://wiki.teamliquid.net/starcraft2/Maps/Ladder_Maps/Legacy_of_the_Void
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pysc2.maps import lib


class Ladder(lib.Map):
  players = 2
  game_steps_per_episode = 16 * 60 * 30  # 30 minute limit.
  download = "https://github.com/Blizzard/s2client-proto#map-packs"


ladder_seasons = [
    "Ladder2017Season1",
    "Ladder2017Season2",
    "Ladder2017Season3",
    "Ladder2017Season4",
    "Ladder2018Season1",
    "Ladder2018Season2",
    "Ladder2018Season3",
    "Ladder2018Season4",
    "Ladder2019Season1",
]

for name in ladder_seasons:
  globals()[name] = type(name, (Ladder,), dict(directory=name))


# pylint: disable=bad-whitespace, undefined-variable
# pytype: disable=name-error
ladder_maps = [
    ("16Bit", Ladder2018Season2, "(2)16-BitLE", 2),
    ("Abiogenesis", Ladder2018Season1, "AbiogenesisLE", 2),
    ("AbyssalReef", Ladder2017Season4, "AbyssalReefLE", 2),
    ("AcidPlant", Ladder2018Season3, "AcidPlantLE", 2),
    ("Acolyte", Ladder2017Season3, "AcolyteLE", 2),
    ("AscensiontoAiur", Ladder2017Season4, "AscensiontoAiurLE", 2),
    ("Automaton", Ladder2019Season1, "AutomatonLE", 2),
    ("Backwater", Ladder2018Season1, "BackwaterLE", 2),
    ("BattleontheBoardwalk", Ladder2017Season4, "BattleontheBoardwalkLE", 2),
    ("BelShirVestige", Ladder2017Season1, "BelShirVestigeLE", 2),
    ("Blackpink", Ladder2018Season1, "BlackpinkLE", 2),
    ("BloodBoil", Ladder2017Season2, "BloodBoilLE", 2),
    ("Blueshift", Ladder2018Season4, "BlueshiftLE", 2),
    ("CactusValley", Ladder2017Season1, "CactusValleyLE", 4),
    ("Catalyst", Ladder2018Season2, "(2)CatalystLE", 2),
    ("CeruleanFall", Ladder2018Season4, "CeruleanFallLE", 2),
    ("CyberForest", Ladder2019Season1, "CyberForestLE", 2),
    ("DarknessSanctuary", Ladder2018Season2, "(4)DarknessSanctuaryLE", 4),
    ("DefendersLanding", Ladder2017Season2, "DefendersLandingLE", 2),
    ("Dreamcatcher", Ladder2018Season3, "DreamcatcherLE", 2),
    ("Eastwatch", Ladder2018Season1, "EastwatchLE", 2),
    ("Fracture", Ladder2018Season3, "FractureLE", 2),
    ("Frost", Ladder2017Season3, "FrostLE", 2),
    ("Honorgrounds", Ladder2017Season1, "HonorgroundsLE", 4),
    ("Interloper", Ladder2017Season3, "InterloperLE", 2),
    ("KairosJunction", Ladder2019Season1, "KairosJunctionLE", 2),
    ("KingsCove", Ladder2019Season1, "KingsCoveLE", 2),
    ("LostandFound", Ladder2018Season3, "LostandFoundLE", 2),
    ("MechDepot", Ladder2017Season3, "MechDepotLE", 2),
    ("NewRepugnancy", Ladder2019Season1, "NewRepugnancyLE", 2),
    ("NewkirkPrecinct", Ladder2017Season1, "NewkirkPrecinctTE", 2),
    ("Odyssey", Ladder2017Season4, "OdysseyLE", 2),
    ("PaladinoTerminal", Ladder2017Season1, "PaladinoTerminalLE", 2),
    ("ParaSite", Ladder2018Season4, "ParaSiteLE", 2),
    ("PortAleksander", Ladder2019Season1, "PortAleksanderLE", 2),
    ("ProximaStation", Ladder2017Season2, "ProximaStationLE", 2),
    ("Redshift", Ladder2018Season2, "(2)RedshiftLE", 2),
    ("Sequencer", Ladder2017Season2, "SequencerLE", 2),
    ("Stasis", Ladder2018Season4, "StasisLE", 2),
    ("YearZero", Ladder2019Season1, "YearZeroLE", 2),

    # Disabled due to failing on 4.1.2 on Linux (Websocket Timeout).
    # ("NeonVioletSquare", Ladder2018Season1, "NeonVioletSquareLE", 2),
]
# pylint: enable=bad-whitespace, undefined-variable
# pytype: enable=name-error

# Create the classes dynamically, putting them into the module scope. They all
# inherit from a parent and set the players based on the map filename.
for name, parent, map_file, players in ladder_maps:
  globals()[name] = type(name, (parent,), dict(filename=map_file,
                                               players=players))


--------------------------------------------------------------------------------
/pysc2/env/environment.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python RL Environment API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import enum
import six


class TimeStep(collections.namedtuple(
    'TimeStep', ['step_type', 'reward', 'discount', 'observation'])):
  """Returned with every call to `step` and `reset` on an environment.

  A `TimeStep` contains the data emitted by an environment at each step of
  interaction. A `TimeStep` holds a `step_type`, an `observation`, and an
  associated `reward` and `discount`.

  The first `TimeStep` in a sequence will have `StepType.FIRST`. The final
  `TimeStep` will have `StepType.LAST`. All other `TimeStep`s in a sequence will
  have `StepType.MID`.

  Attributes:
    step_type: A `StepType` enum value.
    reward: A scalar, or 0 if `step_type` is `StepType.FIRST`, i.e. at the
      start of a sequence.
    discount: A discount value in the range `[0, 1]`, or 0 if `step_type`
      is `StepType.FIRST`, i.e. at the start of a sequence.
    observation: A NumPy array, or a dict, list or tuple of arrays.
  """
  __slots__ = ()

  def first(self):
    return self.step_type is StepType.FIRST

  def mid(self):
    return self.step_type is StepType.MID

  def last(self):
    return self.step_type is StepType.LAST


class StepType(enum.IntEnum):
  """Defines the status of a `TimeStep` within a sequence."""
  # Denotes the first `TimeStep` in a sequence.
  FIRST = 0
  # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
  MID = 1
  # Denotes the last `TimeStep` in a sequence.
  LAST = 2


@six.add_metaclass(abc.ABCMeta)
class Base(object):  # pytype: disable=ignored-abstractmethod
  """Abstract base class for Python RL environments."""

  @abc.abstractmethod
  def reset(self):
    """Starts a new sequence and returns the first `TimeStep` of this sequence.

    Returns:
      A `TimeStep` namedtuple containing:
        step_type: A `StepType` of `FIRST`.
        reward: Zero.
        discount: Zero.
        observation: A NumPy array, or a dict, list or tuple of arrays
          corresponding to `observation_spec()`.
    """

  @abc.abstractmethod
  def step(self, action):
    """Updates the environment according to the action and returns a `TimeStep`.

    If the environment returned a `TimeStep` with `StepType.LAST` at the
    previous step, this call to `step` will start a new sequence and `action`
    will be ignored.

    This method will also start a new sequence if called after the environment
    has been constructed and `reset` has not been called. Again, in this case
    `action` will be ignored.

    Args:
      action: A NumPy array, or a dict, list or tuple of arrays corresponding to
        `action_spec()`.

    Returns:
      A `TimeStep` namedtuple containing:
        step_type: A `StepType` value.
        reward: Reward at this timestep.
        discount: A discount in the range [0, 1].
        observation: A NumPy array, or a dict, list or tuple of arrays
          corresponding to `observation_spec()`.
    """

  @abc.abstractmethod
  def observation_spec(self):
    """Defines the observations provided by the environment.

    Returns:
      A tuple of specs (one per agent), where each spec is a dict of shape
        tuples.
    """

  @abc.abstractmethod
  def action_spec(self):
    """Defines the actions that should be provided to `step`.

    Returns:
      A tuple of specs (one per agent), where each spec is something that
        defines the shape of the actions.
    """

  def close(self):
    """Frees any resources used by the environment.

    Implement this method for an environment backed by an external process.

    This method can be used directly

    ```python
    env = Env(...)
    # Use env.
    env.close()
    ```

    or via a context manager

    ```python
    with Env(...) as env:
      # Use env.
    ```
    """
    pass

  def __enter__(self):
    """Allows the environment to be used in a with-statement context."""
    return self

  def __exit__(self, unused_exception_type, unused_exc_value, unused_traceback):
    """Allows the environment to be used in a with-statement context."""
    self.close()

  def __del__(self):
    self.close()


--------------------------------------------------------------------------------
/pysc2/agents/scripted_agent.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Scripted agents."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy

from pysc2.agents import base_agent
from pysc2.lib import actions
from pysc2.lib import features

_PLAYER_SELF = features.PlayerRelative.SELF
_PLAYER_NEUTRAL = features.PlayerRelative.NEUTRAL  # beacon/minerals
_PLAYER_ENEMY = features.PlayerRelative.ENEMY

FUNCTIONS = actions.FUNCTIONS


def _xy_locs(mask):
  """Mask should be a set of bools from comparison with a feature layer."""
  y, x = mask.nonzero()
  return list(zip(x, y))


class MoveToBeacon(base_agent.BaseAgent):
  """An agent specifically for solving the MoveToBeacon map."""

  def step(self, obs):
    super(MoveToBeacon, self).step(obs)
    if FUNCTIONS.Move_screen.id in obs.observation.available_actions:
      player_relative = obs.observation.feature_screen.player_relative
      beacon = _xy_locs(player_relative == _PLAYER_NEUTRAL)
      if not beacon:
        return FUNCTIONS.no_op()
      beacon_center = numpy.mean(beacon, axis=0).round()
      return FUNCTIONS.Move_screen("now", beacon_center)
    else:
      return FUNCTIONS.select_army("select")


class CollectMineralShards(base_agent.BaseAgent):
  """An agent specifically for solving the CollectMineralShards map."""

  def step(self, obs):
    super(CollectMineralShards, self).step(obs)
    if FUNCTIONS.Move_screen.id in obs.observation.available_actions:
      player_relative = obs.observation.feature_screen.player_relative
      minerals = _xy_locs(player_relative == _PLAYER_NEUTRAL)
      if not minerals:
        return FUNCTIONS.no_op()
      marines = _xy_locs(player_relative == _PLAYER_SELF)
      marine_xy = numpy.mean(marines, axis=0).round()  # Average location.
      distances = numpy.linalg.norm(numpy.array(minerals) - marine_xy, axis=1)
      closest_mineral_xy = minerals[numpy.argmin(distances)]
      return FUNCTIONS.Move_screen("now", closest_mineral_xy)
    else:
      return FUNCTIONS.select_army("select")


class CollectMineralShardsFeatureUnits(base_agent.BaseAgent):
  """An agent for solving the CollectMineralShards map with feature units.

  Controls the two marines independently:
  - select marine
  - move to nearest mineral shard that wasn't the previous target
  - swap marine and repeat
  """

  def setup(self, obs_spec, action_spec):
    super(CollectMineralShardsFeatureUnits, self).setup(obs_spec, action_spec)
    if "feature_units" not in obs_spec:
      raise Exception("This agent requires the feature_units observation.")

  def reset(self):
    super(CollectMineralShardsFeatureUnits, self).reset()
    self._marine_selected = False
    self._previous_mineral_xy = [-1, -1]

  def step(self, obs):
    super(CollectMineralShardsFeatureUnits, self).step(obs)
    marines = [unit for unit in obs.observation.feature_units
               if unit.alliance == _PLAYER_SELF]
    if not marines:
      return FUNCTIONS.no_op()
    marine_unit = next((m for m in marines
                        if m.is_selected == self._marine_selected), marines[0])
    marine_xy = [marine_unit.x, marine_unit.y]

    if not marine_unit.is_selected:
      # Nothing selected or the wrong marine is selected.
      self._marine_selected = True
      return FUNCTIONS.select_point("select", marine_xy)

    if FUNCTIONS.Move_screen.id in obs.observation.available_actions:
      # Find and move to the nearest mineral.
      minerals = [[unit.x, unit.y] for unit in obs.observation.feature_units
                  if unit.alliance == _PLAYER_NEUTRAL]

      if self._previous_mineral_xy in minerals:
        # Don't go for the same mineral shard as the other marine.
        minerals.remove(self._previous_mineral_xy)

      if minerals:
        # Find the closest.
        distances = numpy.linalg.norm(
            numpy.array(minerals) - numpy.array(marine_xy), axis=1)
        closest_mineral_xy = minerals[numpy.argmin(distances)]

        # Swap to the other marine.
        self._marine_selected = False
        self._previous_mineral_xy = closest_mineral_xy
        return FUNCTIONS.Move_screen("now", closest_mineral_xy)

    return FUNCTIONS.no_op()


class DefeatRoaches(base_agent.BaseAgent):
  """An agent specifically for solving the DefeatRoaches map."""

  def step(self, obs):
    super(DefeatRoaches, self).step(obs)
    if FUNCTIONS.Attack_screen.id in obs.observation.available_actions:
      player_relative = obs.observation.feature_screen.player_relative
      roaches = _xy_locs(player_relative == _PLAYER_ENEMY)
      if not roaches:
        return FUNCTIONS.no_op()

      # Find the roach with max y coord.
      target = roaches[numpy.argmax(numpy.array(roaches)[:, 1])]
      return FUNCTIONS.Attack_screen("now", target)

    if FUNCTIONS.select_army.id in obs.observation.available_actions:
      return FUNCTIONS.select_army("select")

    return FUNCTIONS.no_op()

--------------------------------------------------------------------------------
/pysc2/env/mock_sc2_env_test.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of the StarCraft2 mock environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest
import mock
import numpy as np

from pysc2.env import environment
from pysc2.env import mock_sc2_env
from pysc2.lib import features


class _TestMixin(object):

  def assert_spec(self, array, shape, dtype):
    self.assertSequenceEqual(array.shape, shape)
    self.assertEqual(array.dtype, dtype)

  def assert_equal(self, actual, expected):
    np.testing.assert_equal(actual, expected)

  def assert_reset(self, env):
    expected = env.next_timestep[0]._replace(
        step_type=environment.StepType.FIRST, reward=0, discount=0)
    timestep = env.reset()
    self.assert_equal(timestep, [expected])

  def assert_first_step(self, env):
    expected = env.next_timestep[0]._replace(
        step_type=environment.StepType.FIRST, reward=0, discount=0)
    timestep = env.step([mock.sentinel.action])
    self.assert_equal(timestep, [expected])

  def assert_mid_step(self, env):
    expected = env.next_timestep[0]._replace(
        step_type=environment.StepType.MID)
    timestep = env.step([mock.sentinel.action])
    self.assert_equal(timestep, [expected])

  def assert_last_step(self, env):
    expected = env.next_timestep[0]._replace(
        step_type=environment.StepType.LAST,
        discount=0.)
    timestep = env.step([mock.sentinel.action])
    self.assert_equal(timestep, [expected])

  def _test_episode(self, env):
    env.next_timestep = [env.next_timestep[0]._replace(
        step_type=environment.StepType.MID)]
    self.assert_first_step(env)

    for step in range(1, 10):
      env.next_timestep = [env.next_timestep[0]._replace(
          reward=step, discount=step / 10)]
      self.assert_mid_step(env)

    env.next_timestep = [env.next_timestep[0]._replace(
        step_type=environment.StepType.LAST, reward=10, discount=0.0)]
    self.assert_last_step(env)

  def _test_episode_length(self, env, length):
    self.assert_reset(env)
    for _ in range(length - 1):
      self.assert_mid_step(env)
    self.assert_last_step(env)

    self.assert_first_step(env)
    for _ in range(length - 1):
      self.assert_mid_step(env)
    self.assert_last_step(env)


class TestTestEnvironment(_TestMixin, absltest.TestCase):

  def setUp(self):
    self._env = mock_sc2_env._TestEnvironment(
        num_agents=1,
        observation_spec=({'mock': [10, 1]},),
        action_spec=(mock.sentinel.action_spec,))

  def test_observation_spec(self):
    self.assertEqual(self._env.observation_spec(), ({'mock': [10, 1]},))

  def test_action_spec(self):
    self.assertEqual(self._env.action_spec(), (mock.sentinel.action_spec,))

  def test_default_observation(self):
    observation = self._env._default_observation(
        self._env.observation_spec()[0], 0)
    self.assert_equal(observation, {'mock': np.zeros([10, 1], dtype=np.int32)})

  def test_episode(self):
    self._env.episode_length = float('inf')
    self._test_episode(self._env)

  def test_two_episodes(self):
    self._env.episode_length = float('inf')
    self._test_episode(self._env)
    self._test_episode(self._env)

  def test_episode_length(self):
    self._env.episode_length = 16
    self._test_episode_length(self._env, length=16)


class TestSC2TestEnv(_TestMixin, absltest.TestCase):

  def test_episode(self):
    env = mock_sc2_env.SC2TestEnv(
        map_name='nonexistant map',
        agent_interface_format=features.AgentInterfaceFormat(
            feature_dimensions=features.Dimensions(screen=64, minimap=32)))
    env.episode_length = float('inf')
    self._test_episode(env)

  def test_episode_length(self):
    env = mock_sc2_env.SC2TestEnv(
        map_name='nonexistant map',
        agent_interface_format=features.AgentInterfaceFormat(
            feature_dimensions=features.Dimensions(screen=64, minimap=32)))
    self.assertEqual(env.episode_length, 10)
    self._test_episode_length(env, length=10)

  def test_screen_minimap_size(self):
    env = mock_sc2_env.SC2TestEnv(
        map_name='nonexistant map',
        agent_interface_format=features.AgentInterfaceFormat(
            feature_dimensions=features.Dimensions(
                screen=(84, 87),
                minimap=(64, 67))))
    timestep = env.reset()
    self.assertLen(timestep, 1)
    self.assert_spec(timestep[0].observation['feature_screen'],
                     [len(features.SCREEN_FEATURES), 87, 84], np.int32)
    self.assert_spec(timestep[0].observation['feature_minimap'],
                     [len(features.MINIMAP_FEATURES), 67, 64], np.int32)

  def test_feature_units_are_supported(self):
    env = mock_sc2_env.SC2TestEnv(
        map_name='nonexistant map',
        agent_interface_format=features.AgentInterfaceFormat(
            feature_dimensions=features.Dimensions(screen=64, minimap=32),
            use_feature_units=True))

    self.assertIn('feature_units', env.observation_spec()[0])


if __name__ == '__main__':
  absltest.main()

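The mock-environment tests above exercise the episode contract documented in `environment.py`: `reset()` returns a `FIRST` step with zero reward and zero discount, intermediate steps are `MID`, the tests expect the `LAST` step to carry a discount of 0, and calling `step` after `LAST` (or before any `reset`) starts a new episode and ignores the action. The following is a minimal, self-contained sketch of that contract. It assumes a hypothetical `ToyEnv` with a fixed episode length, a reward of 1 per step and a discount of 1 on `MID` steps; it is not pysc2's `mock_sc2_env` implementation and uses only the standard library.

```python
import collections
import enum


class StepType(enum.IntEnum):
  FIRST = 0
  MID = 1
  LAST = 2


TimeStep = collections.namedtuple(
    "TimeStep", ["step_type", "reward", "discount", "observation"])


class ToyEnv(object):
  """Hypothetical environment that follows the FIRST/MID/LAST contract."""

  def __init__(self, episode_length):
    self.episode_length = episode_length  # Steps taken after reset().
    self._steps = 0
    self._needs_reset = True  # No episode has started yet.

  def reset(self):
    self._steps = 0
    self._needs_reset = False
    # The first step of an episode has zero reward and zero discount.
    return TimeStep(StepType.FIRST, 0, 0.0, self._steps)

  def step(self, action):
    if self._needs_reset:
      # Before the first reset, or right after a LAST step: start a new
      # episode and ignore the action.
      return self.reset()
    self._steps += 1
    if self._steps >= self.episode_length:
      self._needs_reset = True
      return TimeStep(StepType.LAST, 1, 0.0, self._steps)
    return TimeStep(StepType.MID, 1, 1.0, self._steps)


env = ToyEnv(episode_length=3)
types = [env.step(None).step_type for _ in range(8)]
print([t.name for t in types])
# -> ['FIRST', 'MID', 'MID', 'LAST', 'FIRST', 'MID', 'MID', 'LAST']
```

With `episode_length=3` each episode is one `FIRST` step followed by `length - 1` `MID` steps and one `LAST` step, the same shape `_test_episode_length` checks. Keeping `_needs_reset` as explicit state is one simple way to route both auto-reset cases (before the first `reset()` and after `LAST`) through a single code path.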
--------------------------------------------------------------------------------
/pysc2/lib/protocol.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protocol library to make communication easy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
from absl import logging
import os
import socket
import sys
import time

from absl import flags
import enum
from pysc2.lib import stopwatch
import websocket

from s2clientprotocol import sc2api_pb2 as sc_pb


flags.DEFINE_integer("sc2_verbose_protocol", 0,
                     ("Print the communication packets with SC2. 0 disables. "
                      "-1 means all. >0 will print that many lines per "
                      "packet. 20 is a good starting value."))
FLAGS = flags.FLAGS


sw = stopwatch.sw

# Create a python version of the Status enum in the proto.
Status = enum.Enum("Status", sc_pb.Status.items())  # pylint: disable=invalid-name


class ConnectionError(Exception):
  """Failed to read/write a message, details in the error string."""
  pass


class ProtocolError(Exception):
  """SC2 responded with an error message likely due to a bad request or bug."""
  pass


@contextlib.contextmanager
def catch_websocket_connection_errors():
  """A context manager that translates websocket errors into ConnectionError."""
  try:
    yield
  except websocket.WebSocketConnectionClosedException:
    raise ConnectionError("Connection already closed. SC2 probably crashed. "
                          "Check the error log.")
  except websocket.WebSocketTimeoutException:
    raise ConnectionError("Websocket timed out.")
  except socket.error as e:
    raise ConnectionError("Socket error: %s" % e)


class StarcraftProtocol(object):
  """Defines the protocol for chatting with starcraft."""

  def __init__(self, sock):
    self._status = Status.launched
    self._sock = sock

  @property
  def status(self):
    return self._status

  def close(self):
    if self._sock:
      self._sock.close()
      self._sock = None
    self._status = Status.quit

  @sw.decorate
  def read(self):
    """Read a Response, do some validation, and return it."""
    if FLAGS.sc2_verbose_protocol:
      self._log(" Reading response ".center(60, "-"))
      start = time.time()
    response = self._read()
    if FLAGS.sc2_verbose_protocol:
      self._log(" %0.1f msec\n" % (1000 * (time.time() - start)))
      self._log_packet(response)
    if not response.HasField("status"):
      raise ProtocolError("Got an incomplete response without a status.")
    prev_status = self._status
    self._status = Status(response.status)  # pytype: disable=not-callable
    if response.error:
      err_str = ("Error in RPC response (likely a bug). "
                 "Prev status: %s, new status: %s, error:\n%s" % (
                     prev_status, self._status, "\n".join(response.error)))
      logging.error(err_str)
      raise ProtocolError(err_str)
    return response

  @sw.decorate
  def write(self, request):
    """Write a Request."""
    if FLAGS.sc2_verbose_protocol:
      self._log(" Writing request ".center(60, "-") + "\n")
      self._log_packet(request)
    self._write(request)

  def send_req(self, request):
    """Write a pre-filled Request and return the Response."""
    self.write(request)
    return self.read()

  def send(self, **kwargs):
    """Create and send a specific request, and return the response.

    For example: send(ping=sc_pb.RequestPing()) => sc_pb.ResponsePing

    Args:
      **kwargs: A single kwarg with the name and value to fill in to Request.

    Returns:
      The Response corresponding to your request.
    """
    assert len(kwargs) == 1, "Must make a single request."
    res = self.send_req(sc_pb.Request(**kwargs))
    return getattr(res, list(kwargs.keys())[0])

  def _log_packet(self, packet):
    max_lines = FLAGS.sc2_verbose_protocol
    if max_lines > 0:
      max_width = int(os.getenv("COLUMNS", 200))  # Get your TTY width.
      lines = str(packet).strip().split("\n")
      self._log("".join(line[:max_width] + "\n" for line in lines[:max_lines]))
      if len(lines) > max_lines:
        self._log(("**** %s lines skipped ****\n" % (len(lines) - max_lines)))
    else:
      self._log("%s\n" % packet)

  def _log(self, s):
    r"""Log a string. It flushes but doesn't append \n, so do that yourself."""
    # TODO(tewalds): Should this be using logging.info instead? How to see them
    # outside of google infrastructure?
    sys.stderr.write(s)
    sys.stderr.flush()

  def _read(self):
    """Actually read the response and parse it, returning a Response."""
    with sw("read_response"):
      with catch_websocket_connection_errors():
        response_str = self._sock.recv()
    if not response_str:
      raise ProtocolError("Got an empty response from SC2.")
    response = sc_pb.Response()
    with sw("parse_response"):
      response.ParseFromString(response_str)
    return response

  def _write(self, request):
    """Actually serialize and write the request."""
    with sw("serialize_request"):
      request_str = request.SerializeToString()
    with sw("write_request"):
      with catch_websocket_connection_errors():
        self._sock.send(request_str)
--------------------------------------------------------------------------------
/pysc2/bin/agent.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run an agent."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib
import threading

from absl import app
from absl import flags
from future.builtins import range  # pylint: disable=redefined-builtin

from pysc2 import maps
from pysc2.env import available_actions_printer
from pysc2.env import run_loop
from pysc2.env import sc2_env
from pysc2.lib import point_flag
from pysc2.lib import stopwatch


FLAGS = flags.FLAGS
flags.DEFINE_bool("render", True, "Whether to render with pygame.")
point_flag.DEFINE_point("feature_screen_size", "84",
                        "Resolution for screen feature layers.")
point_flag.DEFINE_point("feature_minimap_size", "64",
                        "Resolution for minimap feature layers.")
point_flag.DEFINE_point("rgb_screen_size", None,
                        "Resolution for rendered screen.")
point_flag.DEFINE_point("rgb_minimap_size", None,
                        "Resolution for rendered minimap.")
flags.DEFINE_enum("action_space", None, sc2_env.ActionSpace._member_names_,  # pylint: disable=protected-access
                  "Which action space to use. Needed if you take both feature "
                  "and rgb observations.")
flags.DEFINE_bool("use_feature_units", False,
                  "Whether to include feature units.")
flags.DEFINE_bool("disable_fog", False, "Whether to disable Fog of War.")

flags.DEFINE_integer("max_agent_steps", 0, "Total agent steps.")
flags.DEFINE_integer("game_steps_per_episode", None, "Game steps per episode.")
flags.DEFINE_integer("max_episodes", 0, "Total episodes.")
flags.DEFINE_integer("step_mul", 8, "Game steps per agent step.")

flags.DEFINE_string("agent", "pysc2.agents.random_agent.RandomAgent",
                    "Which agent to run, as a python path to an Agent class.")
flags.DEFINE_string("agent_name", None,
                    "Name of the agent in replays. Defaults to the class name.")
flags.DEFINE_enum("agent_race", "random", sc2_env.Race._member_names_,  # pylint: disable=protected-access
                  "Agent 1's race.")

flags.DEFINE_string("agent2", "Bot", "Second agent, either Bot or agent class.")
flags.DEFINE_string("agent2_name", None,
                    "Name of the agent in replays. Defaults to the class name.")
flags.DEFINE_enum("agent2_race", "random", sc2_env.Race._member_names_,  # pylint: disable=protected-access
                  "Agent 2's race.")
flags.DEFINE_enum("difficulty", "very_easy", sc2_env.Difficulty._member_names_,  # pylint: disable=protected-access
                  "If agent2 is a built-in Bot, its strength.")

flags.DEFINE_bool("profile", False, "Whether to turn on code profiling.")
flags.DEFINE_bool("trace", False, "Whether to trace the code execution.")
flags.DEFINE_integer("parallel", 1, "How many instances to run in parallel.")

flags.DEFINE_bool("save_replay", True, "Whether to save a replay at the end.")

flags.DEFINE_string("map", None, "Name of a map to use.")
flags.mark_flag_as_required("map")


def run_thread(agent_classes, players, map_name, visualize):
  """Run one thread worth of the environment with agents."""
  with sc2_env.SC2Env(
      map_name=map_name,
      players=players,
      agent_interface_format=sc2_env.parse_agent_interface_format(
          feature_screen=FLAGS.feature_screen_size,
          feature_minimap=FLAGS.feature_minimap_size,
          rgb_screen=FLAGS.rgb_screen_size,
          rgb_minimap=FLAGS.rgb_minimap_size,
          action_space=FLAGS.action_space,
          use_feature_units=FLAGS.use_feature_units),
      step_mul=FLAGS.step_mul,
      game_steps_per_episode=FLAGS.game_steps_per_episode,
      disable_fog=FLAGS.disable_fog,
      visualize=visualize) as env:
    env = available_actions_printer.AvailableActionsPrinter(env)
    agents = [agent_cls() for agent_cls in agent_classes]
    run_loop.run_loop(agents, env, FLAGS.max_agent_steps, FLAGS.max_episodes)
    if FLAGS.save_replay:
      env.save_replay(agent_classes[0].__name__)


def main(unused_argv):
  """Run an agent."""
  stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace
  stopwatch.sw.trace = FLAGS.trace

  map_inst = maps.get(FLAGS.map)

  agent_classes = []
  players = []

  agent_module, agent_name = FLAGS.agent.rsplit(".", 1)
  agent_cls = getattr(importlib.import_module(agent_module), agent_name)
  agent_classes.append(agent_cls)
  players.append(sc2_env.Agent(sc2_env.Race[FLAGS.agent_race],
                               FLAGS.agent_name or agent_name))

  if map_inst.players >= 2:
    if FLAGS.agent2 == "Bot":
      players.append(sc2_env.Bot(sc2_env.Race[FLAGS.agent2_race],
                                 sc2_env.Difficulty[FLAGS.difficulty]))
    else:
      agent_module, agent_name = FLAGS.agent2.rsplit(".", 1)
      agent_cls = getattr(importlib.import_module(agent_module), agent_name)
      agent_classes.append(agent_cls)
      players.append(sc2_env.Agent(sc2_env.Race[FLAGS.agent2_race],
                                   FLAGS.agent2_name or agent_name))

  threads = []
  for _ in range(FLAGS.parallel - 1):
    t = threading.Thread(target=run_thread,
                         args=(agent_classes, players, FLAGS.map, False))
    threads.append(t)
    t.start()

  run_thread(agent_classes, players, FLAGS.map, FLAGS.render)

  for t in threads:
    t.join()

  if FLAGS.profile:
    print(stopwatch.sw)


def entry_point():  # Needed so setup.py scripts work.
  app.run(main)


if __name__ == "__main__":
  app.run(main)
--------------------------------------------------------------------------------
/pysc2/lib/sc_process.py:
--------------------------------------------------------------------------------
# Copyright 2017-2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Launch the game and set up communication."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
import os
import shutil
import subprocess
import tempfile
import time

from absl import flags
from future.builtins import range  # pylint: disable=redefined-builtin

import portpicker
from pysc2.lib import remote_controller
from pysc2.lib import stopwatch

flags.DEFINE_bool("sc2_verbose", False, "Enable SC2 verbose logging.")
FLAGS = flags.FLAGS

sw = stopwatch.sw


class SC2LaunchError(Exception):
  pass


class StarcraftProcess(object):
  """Launch a starcraft server, initialize a controller, and later, clean up.

  This is best used from run_configs, which decides which version to run, and
  where to find it.

  It is important to call `close` or use it as a context manager, otherwise
  you'll likely leak temp files and SC2 processes.
  """

  def __init__(self, run_config, exec_path, version, full_screen=False,
               extra_args=None, verbose=False, host=None, port=None,
               connect=True, timeout_seconds=None, window_size=(640, 480),
               window_loc=(50, 50), **kwargs):
    """Launch the SC2 process.

    Args:
      run_config: `run_configs.lib.RunConfig` object.
      exec_path: Path to the binary to run.
      version: `run_configs.lib.Version` object.
      full_screen: Whether to launch the game window full_screen on win/mac.
      extra_args: List of additional args for the SC2 process.
      verbose: Whether to have the SC2 process do verbose logging.
      host: IP for the game to listen on for its websocket. This is
          usually "127.0.0.1", or "::1", but could be others as well.
      port: Port SC2 should listen on for the websocket.
      connect: Whether to create a RemoteController to connect.
      timeout_seconds: Timeout for the remote controller.
      window_size: Screen size if not full screen.
      window_loc: Screen location if not full screen.
      **kwargs: Extra arguments for _launch (useful for subclasses).
    """
    self._proc = None
    self._controller = None
    self._check_exists(exec_path)
    self._tmp_dir = tempfile.mkdtemp(prefix="sc-", dir=run_config.tmp_dir)
    self._host = host or "127.0.0.1"
    self._port = port or portpicker.pick_unused_port()
    self._version = version

    args = [
        exec_path,
        "-listen", self._host,
        "-port", str(self._port),
        "-dataDir", os.path.join(run_config.data_dir, ""),
        "-tempDir", os.path.join(self._tmp_dir, ""),
    ]
    if ":" in self._host:
      args += ["-ipv6"]
    if full_screen:
      args += ["-displayMode", "1"]
    else:
      args += [
          "-displayMode", "0",
          "-windowwidth", str(window_size[0]),
          "-windowheight", str(window_size[1]),
          "-windowx", str(window_loc[0]),
          "-windowy", str(window_loc[1]),
      ]

    if verbose or FLAGS.sc2_verbose:
      args += ["-verbose"]
    if self._version and self._version.data_version:
      args += ["-dataVersion", self._version.data_version.upper()]
    if extra_args:
      args += extra_args
    logging.info("Launching SC2: %s", " ".join(args))
    try:
      with sw("startup"):
        self._proc = self._launch(run_config, args, **kwargs)
        if connect:
          self._controller = remote_controller.RemoteController(
              self._host, self._port, self, timeout_seconds=timeout_seconds)
    except:
      self.close()
      raise

  @sw.decorate
  def close(self):
    """Shut down the game and clean up."""
    if hasattr(self, "_controller") and self._controller:
      self._controller.quit()
      self._controller.close()
      self._controller = None
    self._shutdown()
    if hasattr(self, "_port") and self._port:
      portpicker.return_port(self._port)
      self._port = None
    if hasattr(self, "_tmp_dir") and os.path.exists(self._tmp_dir):
      shutil.rmtree(self._tmp_dir)

  @property
  def controller(self):
    return self._controller

  @property
  def host(self):
    return self._host

  @property
  def port(self):
    return self._port

  @property
  def version(self):
    return self._version

  def __enter__(self):
    return self.controller

  def __exit__(self, unused_exception_type, unused_exc_value, unused_traceback):
    self.close()

  def __del__(self):
    # Prefer using a context manager, but this cleans most other cases.
    self.close()

  def _check_exists(self, exec_path):
    if not os.path.isfile(exec_path):
      raise RuntimeError("Trying to run '%s', but it doesn't exist" % exec_path)
    if not os.access(exec_path, os.X_OK):
      raise RuntimeError(
          "Trying to run '%s', but it isn't executable." % exec_path)

  def _launch(self, run_config, args, **kwargs):
    """Launch the process and return the process object."""
    del kwargs
    try:
      with sw("popen"):
        return subprocess.Popen(args, cwd=run_config.cwd, env=run_config.env)
    except OSError:
      logging.exception("Failed to launch")
      raise SC2LaunchError("Failed to launch: %s" % args)

  def _shutdown(self):
    """Terminate the sub-process."""
    if self._proc:
      ret = _shutdown_proc(self._proc, 3)
      logging.info("Shutdown with return code: %s", ret)
      self._proc = None

  @property
  def running(self):
    # poll returns None if it's running, otherwise the exit code.
    return self._proc and (self._proc.poll() is None)

  @property
  def pid(self):
    return self._proc.pid if self.running else None


def _shutdown_proc(p, timeout):
  """Wait for a proc to shut down, then terminate or kill it after `timeout`."""
  freq = 10  # how often to check per second
  for _ in range(1 + timeout * freq):
    ret = p.poll()
    if ret is not None:
      logging.info("Shutdown gracefully.")
      return ret
    time.sleep(1 / freq)
  logging.warning("Killing the process.")
  p.kill()
  return p.wait()
--------------------------------------------------------------------------------
/pysc2/env/host_remote_agent.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Creates SC2 processes and games for remote agents to connect into."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pysc2 import maps
from pysc2 import run_configs
from pysc2.lib import portspicker

from s2clientprotocol import common_pb2 as sc_common
from s2clientprotocol import sc2api_pb2 as sc_pb


class VsAgent(object):
  """Host a remote agent vs remote agent game.

  Starts two SC2 processes, one for each of two remote agents to connect to.
  Call create_game, then have the agents connect to their respective port in
  host_ports, specifying lan_ports in the join game request.

  Agents should leave the game once it has finished, then another game can
  be created. Note that failure of either agent to leave prior to creating
  the next game will lead to SC2 crashing.

  Best used as a context manager for simple and timely resource release.

  **NOTE THAT** currently re-connecting to the same SC2 process is flaky.
  If you experience difficulties the workaround is to only create one game
  per instantiation of VsAgent.
  """

  def __init__(self):
    self._num_agents = 2
    self._run_config = run_configs.get()
    self._processes = []
    self._controllers = []
    self._saved_maps = set()

    # Reserve LAN ports.
    self._lan_ports = portspicker.pick_unused_ports(self._num_agents * 2)

    # Start SC2 processes.
    for _ in range(self._num_agents):
      process = self._run_config.start(extra_ports=self._lan_ports)
      self._processes.append(process)
      self._controllers.append(process.controller)

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    self.close()

  def __del__(self):
    self.close()

  def create_game(self, map_name):
    """Create a game for the agents to join.

    Args:
      map_name: The map to use.
    """
    map_inst = maps.get(map_name)
    map_data = map_inst.data(self._run_config)
    if map_name not in self._saved_maps:
      for controller in self._controllers:
        controller.save_map(map_inst.path, map_data)
      self._saved_maps.add(map_name)

    # Form the create game message.
    create = sc_pb.RequestCreateGame(
        local_map=sc_pb.LocalMap(map_path=map_inst.path),
        disable_fog=False)

    # Set up for two agents.
    for _ in range(self._num_agents):
      create.player_setup.add(type=sc_pb.Participant)

    # Create the game.
    self._controllers[0].create_game(create)

  @property
  def hosts(self):
    """The hosts that the remote agents should connect to."""
    return [process.host for process in self._processes]

  @property
  def host_ports(self):
    """The WebSocket ports that the remote agents should connect to."""
    return [process.port for process in self._processes]

  @property
  def lan_ports(self):
    """The LAN ports which the remote agents should specify when joining."""
    return self._lan_ports

  def close(self):
    """Shutdown and free all resources."""
    for controller in self._controllers:
      controller.quit()
    self._controllers = []

    for process in self._processes:
      process.close()
    self._processes = []

    portspicker.return_ports(self._lan_ports)
    self._lan_ports = []


class VsBot(object):
  """Host a remote agent vs bot game.

  Starts a single SC2 process. Call create_game, then have the agent connect
  to host_port.

  The agent should leave the game once it has finished, then another game can
  be created. Note that failure of the agent to leave prior to creating
  the next game will lead to SC2 crashing.

  Best used as a context manager for simple and timely resource release.

  **NOTE THAT** currently re-connecting to the same SC2 process is flaky.
  If you experience difficulties the workaround is to only create one game
  per instantiation of VsBot.
  """

  def __init__(self):
    # Start the SC2 process.
    self._run_config = run_configs.get()
    self._process = self._run_config.start()
    self._controller = self._process.controller
    self._saved_maps = set()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    self.close()

  def __del__(self):
    self.close()

  def create_game(
      self,
      map_name,
      bot_difficulty=sc_pb.VeryEasy,
      bot_race=sc_common.Random,
      bot_first=False):
    """Create a game, one remote agent vs the specified bot.

    Args:
      map_name: The map to use.
      bot_difficulty: The difficulty of the bot to play against.
      bot_race: The race for the bot.
      bot_first: Whether the bot should be player 1 (else is player 2).
    """
    self._controller.ping()

    # Form the create game message.
    map_inst = maps.get(map_name)
    map_data = map_inst.data(self._run_config)
    if map_name not in self._saved_maps:
      self._controller.save_map(map_inst.path, map_data)
      self._saved_maps.add(map_name)

    create = sc_pb.RequestCreateGame(
        local_map=sc_pb.LocalMap(map_path=map_inst.path, map_data=map_data),
        disable_fog=False)

    # Set up for one bot, one agent.
    if not bot_first:
      create.player_setup.add(type=sc_pb.Participant)

    create.player_setup.add(
        type=sc_pb.Computer, race=bot_race, difficulty=bot_difficulty)

    if bot_first:
      create.player_setup.add(type=sc_pb.Participant)

    # Create the game.
    self._controller.create_game(create)

  @property
  def host(self):
    """The host that the remote agent should connect to."""
    return self._process.host

  @property
  def host_port(self):
    """The WebSocket port that the remote agent should connect to."""
    return self._process.port

  def close(self):
    """Shutdown and free all resources."""
    if self._controller is not None:
      self._controller.quit()
      self._controller = None
    if self._process is not None:
      self._process.close()
      self._process = None
--------------------------------------------------------------------------------
/pysc2/lib/units.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define the static list of units for SC2. Generated by bin/gen_units.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum


# pylint: disable=invalid-name
class Neutral(enum.IntEnum):
  """Neutral units."""
  BattleStationMineralField = 886
  BattleStationMineralField750 = 887
  CarrionBird = 322
  CollapsibleRockTowerDebris = 490
  CollapsibleRockTowerDebrisRampLeft = 518
  CollapsibleRockTowerDebrisRampRight = 517
  CollapsibleRockTowerDiagonal = 588
  CollapsibleRockTowerPushUnit = 561
  CollapsibleRockTowerPushUnitRampLeft = 564
  CollapsibleRockTowerPushUnitRampRight = 563
  CollapsibleRockTowerRampLeft = 664
  CollapsibleRockTowerRampRight = 663
  CollapsibleTerranTowerDebris = 485
  CollapsibleTerranTowerDiagonal = 589
  CollapsibleTerranTowerPushUnit = 562
  CollapsibleTerranTowerPushUnitRampLeft = 559
  CollapsibleTerranTowerPushUnitRampRight = 560
  CollapsibleTerranTowerRampLeft = 590
  CollapsibleTerranTowerRampRight = 591
  DebrisRampLeft = 486
  DebrisRampRight = 487
  DestructibleCityDebrisHugeDiagonalBLUR = 630
  DestructibleDebris6x6 = 365
  DestructibleDebrisRampDiagonalHugeBLUR = 377
  DestructibleDebrisRampDiagonalHugeULBR = 376
  DestructibleRampDiagonalHugeBLUR = 373
  DestructibleRock6x6 = 371
  DestructibleRockEx14x4 = 638
  DestructibleRockEx16x6 = 639
  DestructibleRockEx1DiagonalHugeBLUR = 641
  DestructibleRockEx1DiagonalHugeULBR = 640
  DestructibleRockEx1HorizontalHuge = 643
  KarakFemale = 324
  LabBot = 661
  LabMineralField = 665
  LabMineralField750 = 666
  MineralField = 341
  MineralField750 = 483
  ProtossVespeneGeyser = 608
  PurifierMineralField = 884
  PurifierMineralField750 = 885
  PurifierRichMineralField = 796
  PurifierRichMineralField750 = 797
  PurifierVespeneGeyser = 880
  RichMineralField = 146
  RichMineralField750 = 147
  RichVespeneGeyser = 344
  Scantipede = 335
  ShakurasVespeneGeyser = 881
  SpacePlatformGeyser = 343
  UnbuildableBricksDestructible = 473
  UnbuildablePlatesDestructible = 474
  UnbuildableRocksDestructible = 472
  UtilityBot = 330
  VespeneGeyser = 342
  XelNagaDestructibleBlocker8NE = 1904
  XelNagaDestructibleBlocker8SW = 1908
  XelNagaTower = 149


class Protoss(enum.IntEnum):
  """Protoss units."""
  Adept = 311
  AdeptPhaseShift = 801
  Archon = 141
  Assimilator = 61
  Carrier = 79
  Colossus = 4
  CyberneticsCore = 72
  DarkShrine = 69
  DarkTemplar = 76
  Disruptor = 694
  DisruptorPhased = 733
  FleetBeacon = 64
  ForceField = 135
  Forge = 63
  Gateway = 62
  HighTemplar = 75
  Immortal = 83
  Interceptor = 85
  Mothership = 10
  MothershipCore = 488
  Nexus = 59
  Observer = 82
  ObserverSurveillanceMode = 1911
  Oracle = 495
  Phoenix = 78
  PhotonCannon = 66
  Probe = 84
  Pylon = 60
  PylonOvercharged = 894
  RoboticsBay = 70
  RoboticsFacility = 71
  Sentry = 77
  ShieldBattery = 1910
  Stalker = 74
  Stargate = 67
  StasisTrap = 732
  Tempest = 496
  TemplarArchive = 68
  TwilightCouncil = 65
  VoidRay = 80
  WarpGate = 133
  WarpPrism = 81
  WarpPrismPhasing = 136
  Zealot = 73


class Terran(enum.IntEnum):
  """Terran units."""
  Armory = 29
  AutoTurret = 31
  Banshee = 55
  Barracks = 21
  BarracksFlying = 46
  BarracksReactor = 38
  BarracksTechLab = 37
  Battlecruiser = 57
  Bunker = 24
  CommandCenter = 18
  CommandCenterFlying = 36
  Cyclone = 692
  EngineeringBay = 22
  Factory = 27
  FactoryFlying = 43
  FactoryReactor = 40
  FactoryTechLab = 39
  FusionCore = 30
  Ghost = 50
  GhostAcademy = 26
  GhostAlternate = 144
  GhostNova = 145
  Hellion = 53
  Hellbat = 484
  KD8Charge = 830
  Liberator = 689
  LiberatorAG = 734
  MULE = 268
  Marauder = 51
  Marine = 48
  Medivac = 54
  MissileTurret = 23
  Nuke = 58
  OrbitalCommand = 132
  OrbitalCommandFlying = 134
  PlanetaryFortress = 130
  PointDefenseDrone = 11
  Raven = 56
  Reactor = 6
  Reaper = 49
  Refinery = 20
  RepairDrone = 1913
  SCV = 45
  SensorTower = 25
  SiegeTank = 33
  SiegeTankSieged = 32
  Starport = 28
  StarportFlying = 44
  StarportReactor = 42
  StarportTechLab = 41
  SupplyDepot = 19
  SupplyDepotLowered = 47
  TechLab = 5
  Thor = 52
  ThorHighImpactMode = 691
  VikingAssault = 34
  VikingFighter = 35
  WidowMine = 498
  WidowMineBurrowed = 500


class Zerg(enum.IntEnum):
  """Zerg units."""
  Baneling = 9
  BanelingBurrowed = 115
  BanelingCocoon = 8
  BanelingNest = 96
  BroodLord = 114
  BroodLordCocoon = 113
  Broodling = 289
  BroodlingEscort = 143
  Changeling = 12
  ChangelingMarine = 15
  ChangelingMarineShield = 14
  ChangelingZealot = 13
  ChangelingZergling = 17
  ChangelingZerglingWings = 16
  Corruptor = 112
  CreepTumor = 87
  CreepTumorBurrowed = 137
  CreepTumorQueen = 138
  Drone = 104
  DroneBurrowed = 116
  Cocoon = 103
  EvolutionChamber = 90
  Extractor = 88
  GreaterSpire = 102
  Hatchery = 86
  Hive = 101
  Hydralisk = 107
  HydraliskBurrowed = 117
  HydraliskDen = 91
  InfestationPit = 94
  InfestedTerran = 7
  InfestedTerranBurrowed = 120
  InfestedTerranCocoon = 150
  Infestor = 111
  InfestorBurrowed = 127
  Lair = 100
  Larva = 151
  Locust = 489
  LocustFlying = 693
  Lurker = 502
  LurkerBurrowed = 503
  LurkerDen = 504
  LurkerCocoon = 501
  Mutalisk = 108
  NydusCanal = 142
  NydusNetwork = 95
  Overlord = 106
  OverlordTransport = 893
  OverlordTransportCocoon = 892
  Overseer = 129
  OverseerCocoon = 128
  OverseerOversightMode = 1912
  ParasiticBombDummy = 824
  Queen = 126
  QueenBurrowed = 125
  Ravager = 688
  RavagerBurrowed = 690
  RavagerCocoon = 687
  Roach = 110
  RoachBurrowed = 118
  RoachWarren = 97
  SpawningPool = 89
  SpineCrawler = 98
  SpineCrawlerUprooted = 139
  Spire = 92
  SporeCrawler = 99
  SporeCrawlerUprooted = 140
  SwarmHost = 494
  SwarmHostBurrowed = 493
  Ultralisk = 109
  UltraliskBurrowed = 131
  UltraliskCavern = 93
  Viper = 499
  Zergling = 105
  ZerglingBurrowed = 119
--------------------------------------------------------------------------------
/pysc2/lib/point_test.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | """Tests for the point library."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from absl.testing import absltest
22 | from future.builtins import int  # pylint: disable=redefined-builtin
23 | 
24 | from pysc2.lib import point
25 | 
26 | 
27 | class FakePoint(object):
28 | 
29 |   def __init__(self):
30 |     self.x = 5
31 |     self.y = 8
32 | 
33 | 
34 | class PointTest(absltest.TestCase):
35 | 
36 |   def testBuild(self):
37 |     self.assertEqual(point.Point(5, 8), point.Point.build(FakePoint()))
38 | 
39 |   def testAssignTo(self):
40 |     f = FakePoint()
41 |     self.assertEqual(5, f.x)
42 |     self.assertEqual(8, f.y)
43 |     point.Point(1, 2).assign_to(f)
44 |     self.assertEqual(1, f.x)
45 |     self.assertEqual(2, f.y)
46 | 
47 |   def testDist(self):
48 |     a = point.Point(1, 1)
49 |     b = point.Point(4, 5)
50 |     self.assertEqual(5, a.dist(b))
51 | 
52 |   def testDistSq(self):
53 |     a = point.Point(1, 1)
54 |     b = point.Point(4, 5)
55 |     self.assertEqual(25, a.dist_sq(b))
56 | 
57 |   def testLen(self):
58 |     p = point.Point(3, 4)
59 |     self.assertEqual(5, p.len())
60 | 
61 |   def testScale(self):
62 |     p = point.Point(3, 4)
63 |     self.assertAlmostEqual(2, p.scale(2).len())
64 | 
65 |   def testScaleMaxSize(self):
66 |     p = point.Point(3, 4)
67 |     self.assertEqual(p, p.scale_max_size(p))
68 |     self.assertEqual(point.Point(6, 8), p.scale_max_size(point.Point(8, 8)))
69 |     self.assertEqual(point.Point(6, 8), p.scale_max_size(point.Point(100, 8)))
70 |     self.assertEqual(point.Point(6, 8), p.scale_max_size(point.Point(6, 100)))
71 | 
72 |   def testScaleMinSize(self):
73 |     p = point.Point(3, 4)
74 |     self.assertEqual(p, p.scale_min_size(p))
75 |     self.assertEqual(point.Point(6, 8), p.scale_min_size(point.Point(6, 6)))
76 |     self.assertEqual(point.Point(6, 8), p.scale_min_size(point.Point(2, 8)))
77 |     self.assertEqual(point.Point(6, 8), p.scale_min_size(point.Point(6, 2)))
78 | 
79 |   def testMinDim(self):
80 |     self.assertEqual(5, point.Point(5, 10).min_dim())
81 | 
82 |   def testMaxDim(self):
83 |     self.assertEqual(10, point.Point(5, 10).max_dim())
84 | 
85 |   def testTranspose(self):
86 |     self.assertEqual(point.Point(4, 3), point.Point(3, 4).transpose())
87 | 
88 |   def testRound(self):
89 |     p = point.Point(1.3, 2.6).round()
90 |     self.assertEqual(point.Point(1, 3), p)
91 |     self.assertIsInstance(p.x, int)
92 |     self.assertIsInstance(p.y, int)
93 | 
94 |   def testCeil(self):
95 |     p = point.Point(1.3, 2.6).ceil()
96 |     self.assertEqual(point.Point(2, 3), p)
97 |     self.assertIsInstance(p.x, int)
98 |     self.assertIsInstance(p.y, int)
99 | 
100 |   def testFloor(self):
101 |     p = point.Point(1.3, 2.6).floor()
102 |     self.assertEqual(point.Point(1, 2), p)
103 |     self.assertIsInstance(p.x, int)
104 |     self.assertIsInstance(p.y, int)
105 | 
106 |   def testRotate(self):
107 |     p = point.Point(0, 100)
108 |     self.assertEqual(point.Point(-100, 0), p.rotate_deg(90).round())
109 |     self.assertEqual(point.Point(100, 0), p.rotate_deg(-90).round())
110 |     self.assertEqual(point.Point(0, -100), p.rotate_deg(180).round())
111 | 
112 |   def testContainedCircle(self):
113 |     self.assertTrue(point.Point(2, 2).contained_circle(point.Point(1, 1), 2))
114 |     self.assertFalse(point.Point(2, 2).contained_circle(point.Point(1, 1), 0.5))
115 | 
116 |   def testBound(self):
117 |     tl = point.Point(1, 2)
118 |     br = point.Point(3, 4)
119 |     self.assertEqual(tl, point.Point(0, 0).bound(tl, br))
120 |     self.assertEqual(br, point.Point(10, 10).bound(tl, br))
121 |     self.assertEqual(point.Point(1.5, 2), point.Point(1.5, 0).bound(tl, br))
122 | 
123 | 
124 | class RectTest(absltest.TestCase):
125 | 
126 |   def testInit(self):
127 |     r = point.Rect(1, 2, 3, 4)
128 |     self.assertEqual(r.t, 1)
129 |     self.assertEqual(r.l, 2)
130 |     self.assertEqual(r.b, 3)
131 |     self.assertEqual(r.r, 4)
132 |     self.assertEqual(r.tl, point.Point(2, 1))
133 |     self.assertEqual(r.tr, point.Point(4, 1))
134 |     self.assertEqual(r.bl, point.Point(2, 3))
135 |     self.assertEqual(r.br, point.Point(4, 3))
136 | 
137 |   def testInitBad(self):
138 |     with self.assertRaises(TypeError):
139 |       point.Rect(4, 3, 2, 1)  # require t <= b, l <= r
140 |     with self.assertRaises(TypeError):
141 |       point.Rect(1)
142 |     with self.assertRaises(TypeError):
143 |       point.Rect(1, 2, 3)
144 |     with self.assertRaises(TypeError):
145 |       point.Rect()
146 | 
147 |   def testInitOnePoint(self):
148 |     r = point.Rect(point.Point(1, 2))
149 |     self.assertEqual(r.t, 0)
150 |     self.assertEqual(r.l, 0)
151 |     self.assertEqual(r.b, 2)
152 |     self.assertEqual(r.r, 1)
153 |     self.assertEqual(r.tl, point.Point(0, 0))
154 |     self.assertEqual(r.tr, point.Point(1, 0))
155 |     self.assertEqual(r.bl, point.Point(0, 2))
156 |     self.assertEqual(r.br, point.Point(1, 2))
157 |     self.assertEqual(r.size, point.Point(1, 2))
158 |     self.assertEqual(r.center, point.Point(1, 2) / 2)
159 |     self.assertEqual(r.area, 2)
160 | 
161 |   def testInitTwoPoints(self):
162 |     r = point.Rect(point.Point(1, 2), point.Point(3, 4))
163 |     self.assertEqual(r.t, 2)
164 |     self.assertEqual(r.l, 1)
165 |     self.assertEqual(r.b, 4)
166 |     self.assertEqual(r.r, 3)
167 |     self.assertEqual(r.tl, point.Point(1, 2))
168 |     self.assertEqual(r.tr, point.Point(3, 2))
169 |     self.assertEqual(r.bl, point.Point(1, 4))
170 |     self.assertEqual(r.br, point.Point(3, 4))
171 |     self.assertEqual(r.size, point.Point(2, 2))
172 |     self.assertEqual(r.center, point.Point(2, 3))
173 |     self.assertEqual(r.area, 4)
174 | 
175 |   def testInitTwoPointsReversed(self):
176 |     r = point.Rect(point.Point(3, 4), point.Point(1, 2))
177 |     self.assertEqual(r.t, 2)
178 |     self.assertEqual(r.l, 1)
179 |     self.assertEqual(r.b, 4)
180 |     self.assertEqual(r.r, 3)
181 |     self.assertEqual(r.tl, point.Point(1, 2))
182 |     self.assertEqual(r.tr, point.Point(3, 2))
183 |     self.assertEqual(r.bl, point.Point(1, 4))
184 |     self.assertEqual(r.br, point.Point(3, 4))
185 |     self.assertEqual(r.size, point.Point(2, 2))
186 |     self.assertEqual(r.center, point.Point(2, 3))
187 |     self.assertEqual(r.area, 4)
188 | 
189 |   def testArea(self):
190 |     r = point.Rect(point.Point(1, 1), point.Point(3, 4))
191 |     self.assertEqual(r.area, 6)
192 | 
193 |   def testContains(self):
194 |     r = point.Rect(point.Point(1, 1), point.Point(3, 3))
195 |     self.assertTrue(r.contains_point(point.Point(2, 2)))
196 |     self.assertFalse(r.contains_circle(point.Point(2, 2), 5))
197 |     self.assertFalse(r.contains_point(point.Point(4, 4)))
198 |     self.assertFalse(r.contains_circle(point.Point(4, 4), 5))
199 | 
200 |   def testIntersectsCircle(self):
201 |     r = point.Rect(point.Point(1, 1), point.Point(3, 3))
202 |     self.assertFalse(r.intersects_circle(point.Point(0, 0), 0.5))
203 |     self.assertFalse(r.intersects_circle(point.Point(0, 0), 1))
204 |     self.assertTrue(r.intersects_circle(point.Point(0, 0), 1.5))
205 |     self.assertTrue(r.intersects_circle(point.Point(0, 0), 2))
206 | 
207 | 
208 | if __name__ == '__main__':
209 |   absltest.main()
210 | 
--------------------------------------------------------------------------------
/pysc2/run_configs/platforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Configs for how to run SC2 from a normal install on various platforms."""
15 | 
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | import copy
21 | import os
22 | import platform
23 | import subprocess
24 | import sys
25 | 
26 | from absl import flags
27 | from absl import logging
28 | import six
29 | 
30 | from pysc2.lib import sc_process
31 | from pysc2.run_configs import lib
32 | 
33 | 
34 | flags.DEFINE_enum("sc2_version", None, sorted(lib.VERSIONS.keys()),
35 |                   "Which version of the game to use.")
36 | flags.DEFINE_bool("sc2_dev_build", False,
37 |                   "Use a dev build. Mostly useful for testing by Blizzard.")
38 | FLAGS = flags.FLAGS
39 | 
40 | 
41 | def _read_execute_info(path, parents):
42 |   """Read the ExecuteInfo.txt file and return the base directory."""
43 |   path = os.path.join(path, "StarCraft II/ExecuteInfo.txt")
44 |   if os.path.exists(path):
45 |     with open(path, "rb") as f:  # Binary because the game appends a '\0' :(.
46 |       for line in f:
47 |         parts = [p.strip() for p in line.decode("utf-8").split("=")]
48 |         if len(parts) == 2 and parts[0] == "executable":
49 |           exec_path = parts[1].replace("\\", "/")  # For windows compatibility.
50 |           for _ in range(parents):
51 |             exec_path = os.path.dirname(exec_path)
52 |           return exec_path
53 | 
54 | 
55 | class LocalBase(lib.RunConfig):
56 |   """Base run config for public installs."""
57 | 
58 |   def __init__(self, base_dir, exec_name, cwd=None, env=None):
59 |     base_dir = os.path.expanduser(base_dir)
60 |     cwd = cwd and os.path.join(base_dir, cwd)
61 |     super(LocalBase, self).__init__(
62 |         replay_dir=os.path.join(base_dir, "Replays"),
63 |         data_dir=base_dir, tmp_dir=None, cwd=cwd, env=env)
64 |     self._exec_name = exec_name
65 | 
66 |   def start(self, version=None, want_rgb=True, **kwargs):
67 |     """Launch the game."""
68 |     del want_rgb  # Unused
69 |     if not os.path.isdir(self.data_dir):
70 |       raise sc_process.SC2LaunchError(
71 |           "Expected to find StarCraft II installed at '%s'. If it's not "
72 |           "installed, do that and run it once so auto-detection works. If "
73 |           "auto-detection failed repeatedly, then set the SC2PATH environment "
74 |           "variable with the correct location." % self.data_dir)
75 | 
76 |     version = version or FLAGS.sc2_version
77 |     if isinstance(version, lib.Version) and not version.data_version:
78 |       # This is for old replays that don't have the embedded data_version.
79 |       version = self._get_version(version.game_version)
80 |     elif isinstance(version, six.string_types):
81 |       version = self._get_version(version)
82 |     elif not version:
83 |       version = self._get_version("latest")
84 |     if version.build_version < lib.VERSIONS["3.16.1"].build_version:
85 |       raise sc_process.SC2LaunchError(
86 |           "SC2 binaries older than 3.16.1 don't support the API.")
87 |     if FLAGS.sc2_dev_build:
88 |       version = version._replace(build_version=0)
89 |     exec_path = os.path.join(
90 |         self.data_dir, "Versions/Base%05d" % version.build_version,
91 |         self._exec_name)
92 | 
93 |     if not os.path.exists(exec_path):
94 |       raise sc_process.SC2LaunchError("No SC2 binary found at: %s" % exec_path)
95 | 
96 |     return sc_process.StarcraftProcess(
97 |         self, exec_path=exec_path, version=version, **kwargs)
98 | 
99 |   def get_versions(self):
100 |     versions_dir = os.path.join(self.data_dir, "Versions")
101 |     version_prefix = "Base"
102 |     versions_found = sorted(int(v[len(version_prefix):])
103 |                             for v in os.listdir(versions_dir)
104 |                             if v.startswith(version_prefix))
105 |     if not versions_found:
106 |       raise sc_process.SC2LaunchError(
107 |           "No SC2 Versions found in %s" % versions_dir)
108 |     known_versions = [v for v in lib.VERSIONS.values()
109 |                       if v.build_version in versions_found]
110 |     # Add one more with the max version. That one doesn't need a data version
111 |     # since SC2 will find it in the .build.info file. This allows running
112 |     # versions newer than those known to pysc2, and so is the default.
113 |     known_versions.append(
114 |         lib.Version("latest", max(versions_found), None, None))
115 |     return lib.version_dict(known_versions)
116 | 
117 | 
118 | class Windows(LocalBase):
119 |   """Run on Windows."""
120 | 
121 |   def __init__(self):
122 |     exec_path = (os.environ.get("SC2PATH") or
123 |                  _read_execute_info(os.path.expanduser("~/Documents"), 3) or
124 |                  "C:/Program Files (x86)/StarCraft II")
125 |     super(Windows, self).__init__(exec_path, "SC2_x64.exe", "Support64")
126 | 
127 |   @classmethod
128 |   def priority(cls):
129 |     if platform.system() == "Windows":
130 |       return 1
131 | 
132 | 
133 | class Cygwin(LocalBase):
134 |   """Run on Cygwin. This runs the Windows binary within a Cygwin terminal."""
135 | 
136 |   def __init__(self):
137 |     super(Cygwin, self).__init__(
138 |         os.environ.get("SC2PATH",
139 |                        "/cygdrive/c/Program Files (x86)/StarCraft II"),
140 |         "SC2_x64.exe", "Support64")
141 | 
142 |   @classmethod
143 |   def priority(cls):
144 |     if sys.platform == "cygwin":
145 |       return 1
146 | 
147 | 
148 | class MacOS(LocalBase):
149 |   """Run on MacOS."""
150 | 
151 |   def __init__(self):
152 |     exec_path = (os.environ.get("SC2PATH") or
153 |                  _read_execute_info(os.path.expanduser(
154 |                      "~/Library/Application Support/Blizzard"), 6) or
155 |                  "/Applications/StarCraft II")
156 |     super(MacOS, self).__init__(exec_path, "SC2.app/Contents/MacOS/SC2")
157 | 
158 |   @classmethod
159 |   def priority(cls):
160 |     if platform.system() == "Darwin":
161 |       return 1
162 | 
163 | 
164 | class Linux(LocalBase):
165 |   """Config to run on Linux."""
166 | 
167 |   known_mesa = [  # In priority order
168 |       "libOSMesa.so",
169 |       "libOSMesa.so.8",  # Ubuntu 16.04
170 |       "libOSMesa.so.6",  # Ubuntu 14.04
171 |   ]
172 | 
173 |   def __init__(self):
174 |     base_dir = os.environ.get("SC2PATH", "~/StarCraftII")
175 |     base_dir = os.path.expanduser(base_dir)
176 |     env = copy.deepcopy(os.environ)
177 |     env["LD_LIBRARY_PATH"] = ":".join(filter(None, [
178 |         os.environ.get("LD_LIBRARY_PATH"),
179 |         os.path.join(base_dir, "Libs/")]))
180 |     super(Linux, self).__init__(base_dir, "SC2_x64", env=env)
181 | 
182 |   @classmethod
183 |   def priority(cls):
184 |     if platform.system() == "Linux":
185 |       return 1
186 | 
187 |   def start(self, want_rgb=True, **kwargs):
188 |     extra_args = list(kwargs.pop("extra_args", []))  # Don't mutate caller's.
189 | 
190 |     if want_rgb:
191 |       # Figure out whether the various GL libraries exist since SC2 sometimes
192 |       # fails if you ask to use a library that doesn't exist.
193 |       libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
194 |       libs = {line.strip().split()[0] for line in libs.split("\n") if line}
195 |       if "libEGL.so" in libs:  # Prefer hardware rendering.
196 |         extra_args += ["-eglpath", "libEGL.so"]
197 |       else:
198 |         for mesa_lib in self.known_mesa:  # Look for a software renderer.
199 |           if mesa_lib in libs:
200 |             extra_args += ["-osmesapath", mesa_lib]
201 |             break
202 |         else:
203 |           logging.info(
204 |               "No GL library found, so RGB rendering will be disabled. "
205 |               "For software rendering install libosmesa.")
206 | 
207 |     return super(Linux, self).start(
208 |         want_rgb=want_rgb, extra_args=extra_args, **kwargs)
209 | 
--------------------------------------------------------------------------------
/pysc2/run_configs/lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Configs for various ways to run StarCraft II."""
15 | 
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | import collections
21 | import datetime
22 | import os
23 | 
24 | from pysc2.lib import gfile
25 | 
26 | 
27 | class Version(collections.namedtuple("Version", [
28 |     "game_version", "build_version", "data_version", "binary"])):
29 |   """Represents a single version of the game."""
30 |   __slots__ = ()
31 | 
32 | 
33 | def version_dict(versions):
34 |   return {ver.game_version: ver for ver in versions}
35 | 
36 | 
37 | # https://github.com/Blizzard/s2client-proto/blob/master/buildinfo/versions.json
38 | # Generate with bin/gen_versions.py
39 | VERSIONS = version_dict([
40 |     Version("3.13.0", 52910, "8D9FEF2E1CF7C6C9CBE4FBCA830DDE1C", None),
41 |     Version("3.14.0", 53644, "CA275C4D6E213ED30F80BACCDFEDB1F5", None),
42 |     Version("3.15.0", 54518, "BBF619CCDCC80905350F34C2AF0AB4F6", None),
43 |     Version("3.15.1", 54518, "6EB25E687F8637457538F4B005950A5E", None),
44 |     Version("3.16.0", 55505, "60718A7CA50D0DF42987A30CF87BCB80", None),
45 |     Version("3.16.1", 55958, "5BD7C31B44525DAB46E64C4602A81DC2", None),
46 |     Version("3.17.0", 56787, "DFD1F6607F2CF19CB4E1C996B2563D9B", None),
47 |     Version("3.17.1", 56787, "3F2FCED08798D83B873B5543BEFA6C4B", None),
48 |     Version("3.17.2", 56787, "C690FC543082D35EA0AAA876B8362BEA", None),
49 |     Version("3.18.0", 57507, "1659EF34997DA3470FF84A14431E3A86", None),
50 |     Version("3.19.0", 58400, "2B06AEE58017A7DF2A3D452D733F1019", None),
51 |     Version("3.19.1", 58400, "D9B568472880CC4719D1B698C0D86984", None),
52 |     Version("4.0.0", 59587, "9B4FD995C61664831192B7DA46F8C1A1", None),
53 |     Version("4.0.2", 59587, "B43D9EE00A363DAFAD46914E3E4AF362", None),
54 |     Version("4.1.0", 60196, "1B8ACAB0C663D5510941A9871B3E9FBE", None),
55 |     Version("4.1.1", 60321, "5C021D8A549F4A776EE9E9C1748FFBBC", None),
56 |     Version("4.1.2", 60321, "33D9FE28909573253B7FC352CE7AEA40", None),
57 |     Version("4.1.3", 60321, "F486693E00B2CD305B39E0AB254623EB", None),
58 |     Version("4.1.4", 60321, "2E2A3F6E0BAFE5AC659C4D39F13A938C", None),
59 |     Version("4.2.0", 62347, "C0C0E9D37FCDBC437CE386C6BE2D1F93", None),
60 |     Version("4.2.1", 62848, "29BBAC5AFF364B6101B661DB468E3A37", None),
61 |     Version("4.2.2", 63454, "3CB54C86777E78557C984AB1CF3494A0", None),
62 |     Version("4.2.3", 63454, "5E3A8B21E41B987E05EE4917AAD68C69", None),
63 |     Version("4.2.4", 63454, "7C51BC7B0841EACD3535E6FA6FF2116B", None),
64 |     Version("4.3.0", 64469, "C92B3E9683D5A59E08FC011F4BE167FF", None),
65 |     Version("4.3.1", 65094, "E5A21037AA7A25C03AC441515F4E0644", None),
66 |     Version("4.3.2", 65384, "B6D73C85DFB70F5D01DEABB2517BF11C", None),
67 |     Version("4.4.0", 65895, "BF41339C22AE2EDEBEEADC8C75028F7D", None),
68 |     Version("4.4.1", 66668, "C094081D274A39219061182DBFD7840F", None),
69 |     Version("4.5.0", 67188, "2ACF84A7ECBB536F51FC3F734EC3019F", None),
70 |     Version("4.5.1", 67188, "6D239173B8712461E6A7C644A5539369", None),
71 |     Version("4.6.0", 67926, "7DE59231CBF06F1ECE9A25A27964D4AE", None),
72 |     Version("4.6.1", 67926, "BEA99B4A8E7B41E62ADC06D194801BAB", None),
73 |     Version("4.6.2", 69232, "B3E14058F1083913B80C20993AC965DB", None),
74 |     Version("4.7.0", 70154, "8E216E34BC61ABDE16A59A672ACB0F3B", None),
75 |     Version("4.7.1", 70154, "94596A85191583AD2EBFAE28C5D532DB", None),
76 |     Version("4.8.0", 71061, "760581629FC458A1937A05ED8388725B", None),
77 |     Version("4.8.1", 71523, "FCAF3F050B7C0CC7ADCF551B61B9B91E", None),
78 |     Version("4.8.2", 71663, "FE90C92716FC6F8F04B74268EC369FA5", None),
79 | ])
80 | 
81 | 
82 | class RunConfig(object):
83 |   """Base class for different run configs."""
84 | 
85 |   def __init__(self, replay_dir, data_dir, tmp_dir, cwd=None, env=None):
86 |     """Initialize the run config with the various directories needed.
87 | 
88 |     Args:
89 |       replay_dir: Where to find replays. Might not be accessible to SC2.
90 |       data_dir: Where SC2 should find the data and battle.net cache.
91 |       tmp_dir: The temporary directory. None means the system default.
92 |       cwd: Where to set the current working directory.
93 |       env: What to pass as the environment variables.
94 |     """
95 |     self.replay_dir = replay_dir
96 |     self.data_dir = data_dir
97 |     self.tmp_dir = tmp_dir
98 |     self.cwd = cwd
99 |     self.env = env
100 | 
101 |   def map_data(self, map_name):
102 |     """Return the map data for a map by name or path."""
103 |     with gfile.Open(os.path.join(self.data_dir, "Maps", map_name), "rb") as f:
104 |       return f.read()
105 | 
106 |   def abs_replay_path(self, replay_path):
107 |     """Return the absolute path to the replay, outside the sandbox."""
108 |     return os.path.join(self.replay_dir, replay_path)
109 | 
110 |   def replay_data(self, replay_path):
111 |     """Return the replay data given a path to the replay."""
112 |     with gfile.Open(self.abs_replay_path(replay_path), "rb") as f:
113 |       return f.read()
114 | 
115 |   def replay_paths(self, replay_dir):
116 |     """A generator yielding the full path to the replays under `replay_dir`."""
117 |     replay_dir = self.abs_replay_path(replay_dir)
118 |     if replay_dir.lower().endswith(".sc2replay"):
119 |       yield replay_dir
120 |       return
121 |     for f in gfile.ListDir(replay_dir):
122 |       if f.lower().endswith(".sc2replay"):
123 |         yield os.path.join(replay_dir, f)
124 | 
125 |   def save_replay(self, replay_data, replay_dir, prefix=None):
126 |     """Save a replay to a directory, returning the path to the replay.
127 | 
128 |     Args:
129 |       replay_data: The result of controller.save_replay(), i.e. the binary data.
130 |       replay_dir: Where to save the replay. This can be absolute or relative.
131 |       prefix: Optional prefix for the replay filename.
132 | 
133 |     Returns:
134 |       The full path where the replay is saved.
135 | 
136 |     Raises:
137 |       ValueError: If the prefix contains the path separator.
138 |     """
139 |     if not prefix:
140 |       replay_filename = ""
141 |     elif os.path.sep in prefix:
142 |       raise ValueError("Prefix '%s' contains '%s', use replay_dir instead." % (
143 |           prefix, os.path.sep))
144 |     else:
145 |       replay_filename = prefix + "_"
146 |     now = datetime.datetime.utcnow().replace(microsecond=0)
147 |     replay_filename += "%s.SC2Replay" % now.isoformat("-").replace(":", "-")
148 |     replay_dir = self.abs_replay_path(replay_dir)
149 |     if not gfile.Exists(replay_dir):
150 |       gfile.MakeDirs(replay_dir)
151 |     replay_path = os.path.join(replay_dir, replay_filename)
152 |     with gfile.Open(replay_path, "wb") as f:
153 |       f.write(replay_data)
154 |     return replay_path
155 | 
156 |   def start(self, version=None, **kwargs):
157 |     """Launch the game. Find the version and run sc_process.StarcraftProcess."""
158 |     raise NotImplementedError()
159 | 
160 |   @classmethod
161 |   def all_subclasses(cls):
162 |     """An iterator over all subclasses of `cls`."""
163 |     for s in cls.__subclasses__():
164 |       yield s
165 |       for c in s.all_subclasses():
166 |         yield c
167 | 
168 |   @classmethod
169 |   def name(cls):
170 |     return cls.__name__
171 | 
172 |   @classmethod
173 |   def priority(cls):
174 |     """None means this isn't valid. Run the one with the max priority."""
175 |     return None
176 | 
177 |   def get_versions(self):
178 |     """Return a dict of all versions that can be run."""
179 |     return VERSIONS
180 | 
181 |   def _get_version(self, game_version):
182 |     versions = self.get_versions()
183 |     if game_version.count(".") == 1:
184 |       game_version += ".0"
185 |     if game_version not in versions:
186 |       raise ValueError("Unknown game version: %s. Known versions: %s" % (
187 |           game_version, sorted(versions.keys())))
188 |     return versions[game_version]
189 | 
190 | 
--------------------------------------------------------------------------------
/docs/mini_games.md:
--------------------------------------------------------------------------------
1 | # DeepMind Mini Games
2 | 
3 | ## MoveToBeacon
4 | 
5 | #### Description
6 | 
7 | A map with 1 Marine and 1 Beacon. Rewards are earned by moving the Marine to the
8 | Beacon. Whenever the Marine earns a reward for reaching the Beacon, the Beacon
9 | is teleported to a random location (at least 5 units away from the Marine).
10 | 
11 | #### Initial State
12 | 
13 | * 1 Marine at random location (unselected)
14 | * 1 Beacon at random location (at least 4 units away from the Marine)
15 | 
16 | #### Rewards
17 | 
18 | * Marine reaches Beacon: +1
19 | 
20 | #### End Condition
21 | 
22 | * Time elapsed
23 | 
24 | #### Time Limit
25 | 
26 | * 120 seconds
27 | 
28 | #### Additional Notes
29 | 
30 | * Fog of War disabled
31 | * No camera movement required (single-screen)
32 | 
33 | ## CollectMineralShards
34 | 
35 | #### Description
36 | 
37 | A map with 2 Marines and an endless supply of Mineral Shards. Rewards are earned
38 | by moving the Marines to collect the Mineral Shards, with optimal collection
39 | requiring both Marine units to be split up and moved independently. Whenever all
40 | 20 Mineral Shards have been collected, a new set of 20 Mineral Shards is
41 | spawned at random locations (at least 2 units away from all Marines).
42 | 
43 | #### Initial State
44 | 
45 | * 2 Marines at random locations (unselected)
46 | * 20 Mineral Shards at random locations (at least 2 units away from all
47 |     Marines)
48 | 
49 | #### Rewards
50 | 
51 | * Marine collects Mineral Shard: +1
52 | 
53 | #### End Condition
54 | 
55 | * Time elapsed
56 | 
57 | #### Time Limit
58 | 
59 | * 120 seconds
60 | 
61 | #### Additional Notes
62 | 
63 | * Fog of War disabled
64 | * No camera movement required (single-screen)
65 | * This is the only map in the set to require the Liberty (Campaign) mod, which
66 |     is needed for the Mineral Shard unit.
67 | 
68 | ## FindAndDefeatZerglings
69 | 
70 | #### Description
71 | 
72 | A map with 3 Marines and an endless supply of stationary Zerglings. Rewards are
73 | earned by using the Marines to defeat Zerglings, with the optimal strategy
74 | requiring a combination of efficient exploration and combat. Whenever all 25
75 | Zerglings have been defeated, a new set of 25 Zerglings is spawned at random
76 | locations (at least 9 units away from all Marines and at least 5 units away from
77 | all other Zerglings).
78 | 
79 | #### Initial State
80 | 
81 | * 3 Marines at map center (preselected)
82 | * 2 Zerglings spawned at random locations inside the player's vision range
83 |     (between 7.5 and 9.5 units away from map center and at least 5 units away
84 |     from all other Zerglings)
85 | * 23 Zerglings spawned at random locations outside the player's vision range
86 |     (at least 10.5 units away from map center and at least 5 units away from
87 |     all other Zerglings)
88 | 
89 | #### Rewards
90 | 
91 | * Zergling defeated: +1
92 | * Marine defeated: -1
93 | 
94 | #### End Conditions
95 | 
96 | * Time elapsed
97 | * All Marines defeated
98 | 
99 | #### Time Limit
100 | 
101 | * 180 seconds
102 | 
103 | #### Additional Notes
104 | 
105 | * Fog of War enabled
106 | * Camera movement required (map is larger than single-screen)
107 | 
108 | ## DefeatRoaches
109 | 
110 | #### Description
111 | 
112 | A map with 9 Marines and a group of 4 Roaches on opposite sides. Rewards are
113 | earned by using the Marines to defeat Roaches, with optimal combat strategy
114 | requiring the Marines to perform focus fire on the Roaches. Whenever all 4
115 | Roaches have been defeated, a new group of 4 Roaches is spawned and the player
116 | is awarded 5 additional Marines at full health, with all other surviving Marines
117 | retaining their existing health (no restore). Whenever new units are spawned,
118 | all unit positions are reset to opposite sides of the map.
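The focus-fire requirement can be illustrated with a small target-selection sketch. It assumes a hypothetical `Unit` record holding only a tag and current health (not a pysc2 type) and uses a lowest-health-first heuristic, which is one common way to focus fire rather than this map's own logic: every Marine receives the same target, so damage is concentrated and Roaches are removed one at a time.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Unit:
    """Hypothetical minimal unit record (not a pysc2 type)."""
    tag: int
    health: float


def pick_focus_target(enemies: List[Unit]) -> Optional[Unit]:
    """Return the living enemy with the least health, or None if all are dead.

    Ties are broken by the lower tag so the choice stays stable across steps.
    """
    alive = [e for e in enemies if e.health > 0]
    if not alive:
        return None
    return min(alive, key=lambda e: (e.health, e.tag))


def assign_focus_fire(marines: List[Unit],
                      enemies: List[Unit]) -> Dict[int, int]:
    """Map every Marine tag to the same target tag; empty if nothing is alive."""
    target = pick_focus_target(enemies)
    if target is None:
        return {}
    return {m.tag: target.tag for m in marines}


# Example values only: Roach 4 is already dead, Roaches 2 and 3 tie on health.
roaches = [Unit(1, 145), Unit(2, 60), Unit(3, 60), Unit(4, 0)]
marines = [Unit(10 + i, 45) for i in range(9)]
orders = assign_focus_fire(marines, roaches)  # all nine Marines -> tag 2
```

Recomputing the target every step means the group switches only when the current target dies or a weaker one appears; a fuller agent would also account for range and position, which this sketch ignores.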
119 | 
120 | #### Initial State
121 | 
122 | * 9 Marines in a vertical line at a random side of the map (preselected)
123 | * 4 Roaches in a vertical line at the opposite side of the map from the
124 |     Marines
125 | 
126 | #### Rewards
127 | 
128 | * Roach defeated: +10
129 | * Marine defeated: -1
130 | 
131 | #### End Conditions
132 | 
133 | * Time elapsed
134 | * All Marines defeated
135 | 
136 | #### Time Limit
137 | 
138 | * 120 seconds
139 | 
140 | #### Additional Notes
141 | 
142 | * Fog of War disabled
143 | * No camera movement required (single-screen)
144 | * This map and DefeatZerglingsAndBanelings are currently the only maps in the
145 |     set that can include an automatic, mid-episode state change for
146 |     player-controlled units. The Marine units are automatically moved back to a
147 |     neutral position (at a random side of the map opposite the Roaches) when new
148 |     units are spawned, which occurs whenever the current set of Roaches is
149 |     defeated. This is done in order to guarantee that new units do not spawn
150 |     within combat range of one another.
151 | 
152 | ## DefeatZerglingsAndBanelings
153 | 
154 | #### Description
155 | 
156 | A map with 9 Marines on the opposite side from a group of 6 Zerglings and 4
157 | Banelings. Rewards are earned by using the Marines to defeat Zerglings and
158 | Banelings. Whenever all Zerglings and Banelings have been defeated, a new group
159 | of 6 Zerglings and 4 Banelings is spawned and the player is awarded 4 additional
160 | Marines at full health, with all other surviving Marines retaining their
161 | existing health (no restore). Whenever new units are spawned, all unit positions
162 | are reset to opposite sides of the map.
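The wave mechanics above can be tracked with a little bookkeeping. This is a minimal sketch using a hypothetical `EpisodeState` class (not part of pysc2 or the map script); the reward values are the ones this map lists under Rewards (+5 per Zergling or Baneling, -1 per Marine), and the 45.0 full-health figure in the usage lines is only an example value.

```python
from dataclasses import dataclass
from typing import List

# Reward values listed for DefeatZerglingsAndBanelings.
ZERGLING_REWARD = 5
BANELING_REWARD = 5
MARINE_LOSS_PENALTY = -1
# Marines granted each time a wave of Zerglings and Banelings is cleared.
REINFORCEMENTS_PER_WAVE = 4


@dataclass
class EpisodeState:
    """Hypothetical per-episode bookkeeping; not part of pysc2 or the map."""
    marine_health: List[float]
    max_health: float
    score: int = 0
    waves_cleared: int = 0

    def record_kills(self, zerglings: int, banelings: int) -> None:
        self.score += zerglings * ZERGLING_REWARD + banelings * BANELING_REWARD

    def record_marine_losses(self, dead_indices: List[int]) -> None:
        # Drop the dead Marines and apply the per-Marine penalty.
        dead = set(dead_indices)
        self.marine_health = [
            h for i, h in enumerate(self.marine_health) if i not in dead]
        self.score += len(dead) * MARINE_LOSS_PENALTY

    def clear_wave(self) -> None:
        # Survivors keep their current health (no restore); the new Marines
        # arrive at full health.
        self.marine_health = (
            self.marine_health + [self.max_health] * REINFORCEMENTS_PER_WAVE)
        self.waves_cleared += 1


state = EpisodeState(marine_health=[45.0] * 9, max_health=45.0)
state.record_kills(zerglings=6, banelings=4)  # score 50
state.record_marine_losses([0, 1, 2])         # score 47, 6 Marines left
state.clear_wave()                            # 10 Marines, 1 wave cleared
```

Keeping per-Marine health (rather than a single count) reflects the "no restore" rule: a damaged survivor carries its reduced health into the next wave.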
163 | 
164 | #### Initial State
165 | 
166 | * 9 Marines in a vertical line at a random side of the map (preselected)
167 | * 6 Zerglings and 4 Banelings in a group at the opposite side of the map from
168 |     the Marines
169 | 
170 | #### Rewards
171 | 
172 | * Zergling defeated: +5
173 | * Baneling defeated: +5
174 | * Marine defeated: -1
175 | 
176 | #### End Conditions
177 | 
178 | * Time elapsed
179 | * All Marines defeated
180 | 
181 | #### Time Limit
182 | 
183 | * 120 seconds
184 | 
185 | #### Additional Notes
186 | 
187 | * Fog of War disabled
188 | * No camera movement required (single-screen)
189 | * This map and DefeatRoaches are currently the only maps in the set that can
190 |     include an automatic, mid-episode state change for player-controlled units.
191 |     The Marine units are automatically moved back to a neutral position (at a
192 |     random side of the map opposite the Zerglings and Banelings) when new
193 |     units are spawned, which occurs whenever the current set of Zerglings and
194 |     Banelings is defeated. This is done in order to guarantee that new units do
195 |     not spawn within combat range of one another.
196 | 
197 | ## CollectMineralsAndGas
198 | 
199 | #### Description
200 | 
201 | A map with 12 SCVs, 1 Command Center, 16 Mineral Fields and 4 Vespene Geysers.
202 | Rewards are based on the total amount of Minerals and Vespene Gas collected.
203 | Spending Minerals and Vespene Gas to train new units does not decrease your
204 | reward tally. Optimal collection will require expanding your capacity to gather
205 | Minerals and Vespene Gas by constructing additional SCVs and an additional
206 | Command Center.
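Why spending leaves the tally untouched can be shown with a minimal sketch. It assumes you have running totals of minerals and vespene collected so far (a hypothetical input; this page does not say how the map computes its reward). Per-step rewards derived from those cumulative totals never go negative, whereas deltas of the unspent resource bank would penalize training units.

```python
from typing import Iterable, List, Tuple


def step_rewards(collected: Iterable[Tuple[int, int]]) -> List[int]:
    """Per-step reward from running (minerals, vespene) collected totals.

    The totals only grow when workers return resources, so spending never
    produces a negative reward.
    """
    rewards = []
    previous = 0
    for minerals, vespene in collected:
        total = minerals + vespene
        rewards.append(total - previous)
        previous = total
    return rewards


def naive_bank_rewards(start: int, bank: Iterable[int]) -> List[int]:
    """Counter-example: step deltas of the unspent bank drop when you spend."""
    rewards = []
    previous = start
    for amount in bank:
        rewards.append(amount - previous)
        previous = amount
    return rewards


# Step 1: gather 50 minerals. Step 2: spend 50 minerals on an SCV, gather
# nothing. Step 3: gather 60 minerals and 8 vespene.
print(step_rewards([(50, 0), (50, 0), (110, 8)]))  # [50, 0, 68]
print(naive_bank_rewards(50, [100, 50, 110]))      # [50, -50, 60]
```

Summing the per-step rewards recovers the final collected total (118 here), which matches the map's description of the reward as the total amount collected.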
207 | 
208 | #### Initial State
209 | 
210 | * 12 SCVs beside the Command Center (unselected)
211 | * 1 Command Center at a fixed location
212 | * 16 Mineral Fields at fixed locations
213 | * 4 Vespene Geysers at fixed locations
214 | * Player Resources: 50 Minerals, 0 Vespene, 12/15 Supply
215 | 
216 | #### Rewards
217 | 
218 | * Reward total is equal to the total amount of Minerals and Vespene Gas collected
219 | 
220 | #### End Condition
221 | 
222 | * Time elapsed
223 | 
224 | #### Time Limit
225 | 
226 | * 300 seconds
227 | 
228 | ## BuildMarines
229 | 
230 | #### Description
231 | 
232 | A map with 12 SCVs, 1 Command Center, and 8 Mineral Fields. Rewards are earned
233 | by building Marines. This is accomplished by using SCVs to collect minerals,
234 | which are used to build Supply Depots and Barracks, which can then build
235 | Marines.
236 | 
237 | #### Initial State
238 | 
239 | * 12 SCVs beside the Command Center (unselected)
240 | * 1 Command Center at a fixed location
241 | * 8 Mineral Fields at fixed locations
242 | * Player Resources: 50 Minerals, 0 Vespene, 12/15 Supply
243 | 
244 | #### Rewards
245 | 
246 | * Reward total is equal to the total number of Marines built
247 | 
248 | #### End Condition
249 | 
250 | * Time elapsed
251 | 
252 | #### Time Limit
253 | 
254 | * 900 seconds
255 | 
256 | #### Additional Notes
257 | 
258 | * Fog of War disabled
259 | * No camera movement required (single-screen)
260 | * This is the only map in the set that explicitly limits the available actions
261 |     of the units to disallow actions which are not pertinent to the goal of the
262 |     map. Actions that are not required for building Marines have been removed.
263 | 
--------------------------------------------------------------------------------